diff options
Diffstat (limited to 'src/constrain.cpp')
-rw-r--r-- | src/constrain.cpp | 73 |
1 files changed, 43 insertions, 30 deletions
diff --git a/src/constrain.cpp b/src/constrain.cpp index 2fd3fa9..982f195 100644 --- a/src/constrain.cpp +++ b/src/constrain.cpp @@ -51,7 +51,7 @@ void ATuple::constrain(TEnv& tenv, Constraints& c) const { AType* t = tup<AType>(loc, NULL); - FOREACH(const_iterator, p, *this) { + FOREACHP(const_iterator, p, this) { (*p)->constrain(tenv, c); t->push_back(tenv.var(*p)); } @@ -76,10 +76,12 @@ AFn::constrain(TEnv& tenv, Constraints& c) const frame.push_back(make_pair(sym, tvar)); protT->push_back(tvar); } - c.constrain(tenv, at(1), protT); + + const_iterator i = begin() + 1; + c.constrain(tenv, *i, protT); // Add internal definitions to environment frame - for (const_iterator i = begin() + 2; i != end(); ++i) { + for (++i; i != end(); ++i) { const AST* exp = *i; const ADef* def = exp->to<const ADef*>(); if (def) { @@ -94,10 +96,11 @@ AFn::constrain(TEnv& tenv, Constraints& c) const tenv.push(frame); c.constrain(tenv, this, tenv.var()); - for (size_t i = 2; i < size(); ++i) - at(i)->constrain(tenv, c); + AST* exp = NULL; + for (i = begin() + 2; i != end(); ++i) + (exp = *i)->constrain(tenv, c); - AType* bodyT = tenv.var(at(size() - 1)); + AType* bodyT = tenv.var(exp); AType* fnT = tup<AType>(loc, tenv.penv.sym("Fn"), protT, bodyT, 0); Object::pool.addRoot(fnT); @@ -109,8 +112,8 @@ AFn::constrain(TEnv& tenv, Constraints& c) const void ACall::constrain(TEnv& tenv, Constraints& c) const { - for (size_t i = 0; i < size(); ++i) - at(i)->constrain(tenv, c); + for (const_iterator i = begin(); i != end(); ++i) + (*i)->constrain(tenv, c); const AType* fnType = tenv.var(head()); if (fnType->kind != AType::VAR) { @@ -120,15 +123,15 @@ ACall::constrain(TEnv& tenv, Constraints& c) const || fnType->head()->to<const ASymbol*>()->cppstr != "Fn") throw Error(loc, (format("call to non-function `%1%'") % head()->str()).str()); - size_t numArgs = fnType->at(1)->to<const ATuple*>()->size(); + size_t numArgs = fnType->prot()->size(); THROW_IF(numArgs != size() - 1, loc, (format("expected %1% arguments, got %2%") % numArgs % (size() - 1)).str()); } AType* retT = tenv.var(); AType* argsT = tup<AType>(loc, 0); - for (size_t i = 1; i < size(); ++i) - argsT->push_back(tenv.var(at(i))); + for (const_iterator i = begin() + 1; i != end(); ++i) + argsT->push_back(tenv.var(*i)); c.constrain(tenv, head(), tup<AType>(head()->loc, tenv.penv.sym("Fn"), argsT, retT, 0)); c.constrain(tenv, this, retT); @@ -141,10 +144,10 @@ ADef::constrain(TEnv& tenv, Constraints& c) const const ASymbol* sym = this->sym(); THROW_IF(!sym, loc, "`def' has no symbol") - AType* tvar = tenv.var(at(2)); + AType* tvar = tenv.var(body()); tenv.def(sym, tvar); - at(2)->constrain(tenv, c); - c.constrain(tenv, at(1), tvar); + body()->constrain(tenv, c); + c.constrain(tenv, sym, tvar); c.constrain(tenv, this, tenv.named("Nothing")); } @@ -153,16 +156,20 @@ AIf::constrain(TEnv& tenv, Constraints& c) const { THROW_IF(size() < 4, loc, "`if' requires at least 3 arguments"); THROW_IF(size() % 2 != 0, loc, "`if' missing final else clause") - for (size_t i = 1; i < size(); ++i) - at(i)->constrain(tenv, c); + for (const_iterator i = begin() + 1; i != end(); ++i) + (*i)->constrain(tenv, c); AType* retT = tenv.var(this); - for (size_t i = 1; i < size(); i += 2) { - if (i == size() - 1) { - c.constrain(tenv, at(i), retT); + for (const_iterator i = begin() + 1; true; ++i) { + const_iterator next = i; + ++next; + if (next == end()) { // final (else) expression + c.constrain(tenv, *i, retT); + break; } else { - c.constrain(tenv, at(i), tenv.named("Bool")); - c.constrain(tenv, at(i+1), retT); + c.constrain(tenv, *i, tenv.named("Bool")); + c.constrain(tenv, *next, retT); } + i = next; // jump 2 each iteration (to the next predicate) } } @@ -182,34 +189,40 @@ APrimitive::constrain(TEnv& tenv, Constraints& c) const else throw Error(loc, (format("unknown primitive `%1%'") % n).str()); - for (size_t i = 1; i < size(); ++i) - at(i)->constrain(tenv, c); + const_iterator i = begin(); + + for (++i; i != end(); ++i) + (*i)->constrain(tenv, c); + + i = begin(); + AType* var = NULL; switch (type) { case ARITHMETIC: if (size() < 3) throw Error(loc, (format("`%1%' requires at least 2 arguments") % n).str()); - for (size_t i = 1; i < size(); ++i) - c.constrain(tenv, at(i), tenv.var(this)); + for (++i; i != end(); ++i) + c.constrain(tenv, *i, tenv.var(this)); break; case BINARY: if (size() != 3) throw Error(loc, (format("`%1%' requires exactly 2 arguments") % n).str()); - c.constrain(tenv, at(1), tenv.var(this)); - c.constrain(tenv, at(2), tenv.var(this)); + c.constrain(tenv, *++i, tenv.var(this)); + c.constrain(tenv, *++i, tenv.var(this)); break; case LOGICAL: if (size() != 3) throw Error(loc, (format("`%1%' requires exactly 2 arguments") % n).str()); c.constrain(tenv, this, tenv.named("Bool")); - c.constrain(tenv, at(1), tenv.named("Bool")); - c.constrain(tenv, at(2), tenv.named("Bool")); + c.constrain(tenv, *++i, tenv.named("Bool")); + c.constrain(tenv, *++i, tenv.named("Bool")); break; case COMPARISON: if (size() != 3) throw Error(loc, (format("`%1%' requires exactly 2 arguments") % n).str()); + var = tenv.var(*++i); c.constrain(tenv, this, tenv.named("Bool")); - c.constrain(tenv, at(1), tenv.var(at(2))); + c.constrain(tenv, *++i, var); break; default: throw Error(loc, (format("unknown primitive `%1%'") % n).str()); |