aboutsummaryrefslogtreecommitdiffstats
path: root/src/constrain.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/constrain.cpp')
-rw-r--r--src/constrain.cpp73
1 files changed, 43 insertions, 30 deletions
diff --git a/src/constrain.cpp b/src/constrain.cpp
index 2fd3fa9..982f195 100644
--- a/src/constrain.cpp
+++ b/src/constrain.cpp
@@ -51,7 +51,7 @@ void
ATuple::constrain(TEnv& tenv, Constraints& c) const
{
AType* t = tup<AType>(loc, NULL);
- FOREACH(const_iterator, p, *this) {
+ FOREACHP(const_iterator, p, this) {
(*p)->constrain(tenv, c);
t->push_back(tenv.var(*p));
}
@@ -76,10 +76,12 @@ AFn::constrain(TEnv& tenv, Constraints& c) const
frame.push_back(make_pair(sym, tvar));
protT->push_back(tvar);
}
- c.constrain(tenv, at(1), protT);
+
+ const_iterator i = begin() + 1;
+ c.constrain(tenv, *i, protT);
// Add internal definitions to environment frame
- for (const_iterator i = begin() + 2; i != end(); ++i) {
+ for (++i; i != end(); ++i) {
const AST* exp = *i;
const ADef* def = exp->to<const ADef*>();
if (def) {
@@ -94,10 +96,11 @@ AFn::constrain(TEnv& tenv, Constraints& c) const
tenv.push(frame);
c.constrain(tenv, this, tenv.var());
- for (size_t i = 2; i < size(); ++i)
- at(i)->constrain(tenv, c);
+ AST* exp = NULL;
+ for (i = begin() + 2; i != end(); ++i)
+ (exp = *i)->constrain(tenv, c);
- AType* bodyT = tenv.var(at(size() - 1));
+ AType* bodyT = tenv.var(exp);
AType* fnT = tup<AType>(loc, tenv.penv.sym("Fn"), protT, bodyT, 0);
Object::pool.addRoot(fnT);
@@ -109,8 +112,8 @@ AFn::constrain(TEnv& tenv, Constraints& c) const
void
ACall::constrain(TEnv& tenv, Constraints& c) const
{
- for (size_t i = 0; i < size(); ++i)
- at(i)->constrain(tenv, c);
+ for (const_iterator i = begin(); i != end(); ++i)
+ (*i)->constrain(tenv, c);
const AType* fnType = tenv.var(head());
if (fnType->kind != AType::VAR) {
@@ -120,15 +123,15 @@ ACall::constrain(TEnv& tenv, Constraints& c) const
|| fnType->head()->to<const ASymbol*>()->cppstr != "Fn")
throw Error(loc, (format("call to non-function `%1%'") % head()->str()).str());
- size_t numArgs = fnType->at(1)->to<const ATuple*>()->size();
+ size_t numArgs = fnType->prot()->size();
THROW_IF(numArgs != size() - 1, loc,
(format("expected %1% arguments, got %2%") % numArgs % (size() - 1)).str());
}
AType* retT = tenv.var();
AType* argsT = tup<AType>(loc, 0);
- for (size_t i = 1; i < size(); ++i)
- argsT->push_back(tenv.var(at(i)));
+ for (const_iterator i = begin() + 1; i != end(); ++i)
+ argsT->push_back(tenv.var(*i));
c.constrain(tenv, head(), tup<AType>(head()->loc, tenv.penv.sym("Fn"), argsT, retT, 0));
c.constrain(tenv, this, retT);
@@ -141,10 +144,10 @@ ADef::constrain(TEnv& tenv, Constraints& c) const
const ASymbol* sym = this->sym();
THROW_IF(!sym, loc, "`def' has no symbol")
- AType* tvar = tenv.var(at(2));
+ AType* tvar = tenv.var(body());
tenv.def(sym, tvar);
- at(2)->constrain(tenv, c);
- c.constrain(tenv, at(1), tvar);
+ body()->constrain(tenv, c);
+ c.constrain(tenv, sym, tvar);
c.constrain(tenv, this, tenv.named("Nothing"));
}
@@ -153,16 +156,20 @@ AIf::constrain(TEnv& tenv, Constraints& c) const
{
THROW_IF(size() < 4, loc, "`if' requires at least 3 arguments");
THROW_IF(size() % 2 != 0, loc, "`if' missing final else clause")
- for (size_t i = 1; i < size(); ++i)
- at(i)->constrain(tenv, c);
+ for (const_iterator i = begin() + 1; i != end(); ++i)
+ (*i)->constrain(tenv, c);
AType* retT = tenv.var(this);
- for (size_t i = 1; i < size(); i += 2) {
- if (i == size() - 1) {
- c.constrain(tenv, at(i), retT);
+ for (const_iterator i = begin() + 1; true; ++i) {
+ const_iterator next = i;
+ ++next;
+ if (next == end()) { // final (else) expression
+ c.constrain(tenv, *i, retT);
+ break;
} else {
- c.constrain(tenv, at(i), tenv.named("Bool"));
- c.constrain(tenv, at(i+1), retT);
+ c.constrain(tenv, *i, tenv.named("Bool"));
+ c.constrain(tenv, *next, retT);
}
+ i = next; // jump 2 each iteration (to the next predicate)
}
}
@@ -182,34 +189,40 @@ APrimitive::constrain(TEnv& tenv, Constraints& c) const
else
throw Error(loc, (format("unknown primitive `%1%'") % n).str());
- for (size_t i = 1; i < size(); ++i)
- at(i)->constrain(tenv, c);
+ const_iterator i = begin();
+
+ for (++i; i != end(); ++i)
+ (*i)->constrain(tenv, c);
+
+ i = begin();
+ AType* var = NULL;
switch (type) {
case ARITHMETIC:
if (size() < 3)
throw Error(loc, (format("`%1%' requires at least 2 arguments") % n).str());
- for (size_t i = 1; i < size(); ++i)
- c.constrain(tenv, at(i), tenv.var(this));
+ for (++i; i != end(); ++i)
+ c.constrain(tenv, *i, tenv.var(this));
break;
case BINARY:
if (size() != 3)
throw Error(loc, (format("`%1%' requires exactly 2 arguments") % n).str());
- c.constrain(tenv, at(1), tenv.var(this));
- c.constrain(tenv, at(2), tenv.var(this));
+ c.constrain(tenv, *++i, tenv.var(this));
+ c.constrain(tenv, *++i, tenv.var(this));
break;
case LOGICAL:
if (size() != 3)
throw Error(loc, (format("`%1%' requires exactly 2 arguments") % n).str());
c.constrain(tenv, this, tenv.named("Bool"));
- c.constrain(tenv, at(1), tenv.named("Bool"));
- c.constrain(tenv, at(2), tenv.named("Bool"));
+ c.constrain(tenv, *++i, tenv.named("Bool"));
+ c.constrain(tenv, *++i, tenv.named("Bool"));
break;
case COMPARISON:
if (size() != 3)
throw Error(loc, (format("`%1%' requires exactly 2 arguments") % n).str());
+ var = tenv.var(*++i);
c.constrain(tenv, this, tenv.named("Bool"));
- c.constrain(tenv, at(1), tenv.var(at(2)));
+ c.constrain(tenv, *++i, var);
break;
default:
throw Error(loc, (format("unknown primitive `%1%'") % n).str());