diff options
Diffstat (limited to 'src/constrain.cpp')
-rw-r--r-- | src/constrain.cpp | 248 |
1 files changed, 140 insertions, 108 deletions
diff --git a/src/constrain.cpp b/src/constrain.cpp index bd43b8e..4e507c9 100644 --- a/src/constrain.cpp +++ b/src/constrain.cpp @@ -46,13 +46,6 @@ ALexeme::constrain(TEnv& tenv, Constraints& c) const throw(Error) } void -AQuote::constrain(TEnv& tenv, Constraints& c) const throw(Error) -{ - c.constrain(tenv, this, tenv.named("Quote")); - list_ref(1)->constrain(tenv, c); -} - -void ASymbol::constrain(TEnv& tenv, Constraints& c) const throw(Error) { const AType** ref = tenv.ref(this); @@ -125,78 +118,58 @@ AFn::constrain(TEnv& tenv, Constraints& c) const throw(Error) c.constrain(tenv, this, fnT); } -void -ACall::constrain(TEnv& tenv, Constraints& c) const throw(Error) -{ - for (const_iterator i = begin(); i != end(); ++i) - (*i)->constrain(tenv, c); - - const AType* fnType = tenv.var(head()); - if (fnType->kind != AType::VAR) { - if (fnType->kind == AType::PRIM - || fnType->list_len() < 2 - || fnType->head()->str() != "Fn") - throw Error(loc, (format("call to non-function `%1%'") % head()->str()).str()); - - size_t numArgs = fnType->prot()->list_len(); - THROW_IF(numArgs != list_len() - 1, loc, - (format("expected %1% arguments, got %2%") % numArgs % (list_len() - 1)).str()); - } - - const AType* retT = tenv.var(this); - TList argsT; - for (const_iterator i = iter_at(1); i != end(); ++i) - argsT.push_back(const_cast<AType*>(tenv.var(*i))); - argsT.head->loc = loc; - c.constrain(tenv, head(), tup<AType>(head()->loc, tenv.Fn, argsT.head, retT, 0)); - c.constrain(tenv, this, retT); -} - -void -ADef::constrain(TEnv& tenv, Constraints& c) const throw(Error) +static void +constrain_def(TEnv& tenv, Constraints& c, const ACall* call) throw(Error) { - THROW_IF(list_len() != 3, loc, "`def' requires exactly 2 arguments"); - const ASymbol* sym = this->list_ref(1)->as<const ASymbol*>(); - THROW_IF(!sym, loc, "`def' has no symbol") + THROW_IF(call->list_len() != 3, call->loc, "`def' requires exactly 2 arguments"); + const ASymbol* const sym = call->list_ref(1)->as<const ASymbol*>(); + THROW_IF(!sym, call->loc, "`def' has no symbol") + const AST* const body = call->list_ref(2); - const AType* tvar = tenv.var(body()); + const AType* tvar = tenv.var(body); tenv.def(sym, tvar); - body()->constrain(tenv, c); + body->constrain(tenv, c); c.constrain(tenv, sym, tvar); - c.constrain(tenv, this, tenv.named("Nothing")); + c.constrain(tenv, call, tenv.named("Nothing")); } -void -AIf::constrain(TEnv& tenv, Constraints& c) const throw(Error) +static void +constrain_def_type(TEnv& tenv, Constraints& c, const ACall* call) throw(Error) { - THROW_IF(list_len() < 4, loc, "`if' requires at least 3 arguments"); - THROW_IF(list_len() % 2 != 0, loc, "`if' missing final else clause"); - for (const_iterator i = iter_at(1); i != end(); ++i) - (*i)->constrain(tenv, c); - const AType* retT = tenv.var(this); - for (const_iterator i = iter_at(1); true; ++i) { - const_iterator next = i; - ++next; - if (next == end()) { // final (else) expression - c.constrain(tenv, *i, retT); - break; - } else { - c.constrain(tenv, *i, tenv.named("Bool")); - c.constrain(tenv, *next, retT); + THROW_IF(call->list_len() < 3, call->loc, "`def-type' requires at least 2 arguments"); + ATuple::const_iterator i = call->iter_at(1); + const ATuple* prot = (*i)->to<const ATuple*>(); + THROW_IF(!prot, (*i)->loc, "first argument of `def-type' is not a tuple"); + const ASymbol* sym = (*prot->begin())->as<const ASymbol*>(); + THROW_IF(!sym, (*prot->begin())->loc, "type name is not a symbol"); + THROW_IF(tenv.ref(sym), call->loc, "type redefinition"); + TList type(new AType(tenv.U, NULL, call->loc)); + for (ATuple::const_iterator i = call->iter_at(2); i != call->end(); ++i) { + const ATuple* exp = (*i)->as<const ATuple*>(); + const ASymbol* tag = (*exp->begin())->as<const ASymbol*>(); + TList consT; + consT.push_back(new AType(const_cast<ASymbol*>(sym), AType::NAME)); + for (ATuple::const_iterator i = exp->begin(); i != exp->end(); ++i) { + const ASymbol* sym = (*i)->to<const ASymbol*>(); + THROW_IF(!sym, (*i)->loc, "type expression element is not a symbol"); + consT.push_back(new AType(const_cast<ASymbol*>(sym), AType::NAME)); } - i = next; // jump 2 each iteration (to the next predicate) + consT.head->loc = exp->loc; + type.push_back(consT); + tenv.def(tag, consT); } + tenv.def(sym, type); } -void -AMatch::constrain(TEnv& tenv, Constraints& c) const throw(Error) +static void +constrain_match(TEnv& tenv, Constraints& c, const ACall* call) throw(Error) { - THROW_IF(list_len() < 5, loc, "`match' requires at least 4 arguments"); - const AST* matchee = list_ref(1); + THROW_IF(call->list_len() < 5, call->loc, "`match' requires at least 4 arguments"); + const AST* matchee = call->list_ref(1); const AType* retT = tenv.var(); const AType* matcheeT = NULL;// = tup<AType>(loc, tenv.U, 0); matchee->constrain(tenv, c); - for (const_iterator i = iter_at(2); i != end();) { + for (ATuple::const_iterator i = call->iter_at(2); i != call->end();) { const AST* exp = *i++; const ATuple* pattern = exp->to<const ATuple*>(); THROW_IF(!pattern, exp->loc, "pattern expression expected"); @@ -207,84 +180,79 @@ AMatch::constrain(TEnv& tenv, Constraints& c) const throw(Error) if (!matcheeT) { const AType* headT = consT->head()->as<const AType*>(); - matcheeT = tup<AType>(loc, const_cast<AType*>(headT), 0); + matcheeT = tup<AType>(call->loc, const_cast<AType*>(headT), 0); } - THROW_IF(i == end(), pattern->loc, "missing pattern body"); + THROW_IF(i == call->end(), pattern->loc, "missing pattern body"); const AST* body = *i++; body->constrain(tenv, c); c.constrain(tenv, body, retT); } - c.constrain(tenv, this, retT); + c.constrain(tenv, call, retT); c.constrain(tenv, matchee, matcheeT); } -void -ADefType::constrain(TEnv& tenv, Constraints& c) const throw(Error) +static void +constrain_if(TEnv& tenv, Constraints& c, const ACall* call) throw(Error) { - THROW_IF(list_len() < 3, loc, "`def-type' requires at least 2 arguments"); - const_iterator i = iter_at(1); - const ATuple* prot = (*i)->to<const ATuple*>(); - THROW_IF(!prot, (*i)->loc, "first argument of `def-type' is not a tuple"); - const ASymbol* sym = (*prot->begin())->as<const ASymbol*>(); - THROW_IF(!sym, (*prot->begin())->loc, "type name is not a symbol"); - THROW_IF(tenv.ref(sym), loc, "type redefinition"); - TList type(new AType(tenv.U, NULL, loc)); - for (const_iterator i = iter_at(2); i != end(); ++i) { - const ATuple* exp = (*i)->as<const ATuple*>(); - const ASymbol* tag = (*exp->begin())->as<const ASymbol*>(); - TList consT; - consT.push_back(new AType(const_cast<ASymbol*>(sym), AType::NAME)); - for (ATuple::const_iterator i = exp->begin(); i != exp->end(); ++i) { - const ASymbol* sym = (*i)->to<const ASymbol*>(); - THROW_IF(!sym, (*i)->loc, "type expression element is not a symbol"); - consT.push_back(new AType(const_cast<ASymbol*>(sym), AType::NAME)); + THROW_IF(call->list_len() < 4, call->loc, "`if' requires at least 3 arguments"); + THROW_IF(call->list_len() % 2 != 0, call->loc, "`if' missing final else clause"); + for (ATuple::const_iterator i = call->iter_at(1); i != call->end(); ++i) + (*i)->constrain(tenv, c); + const AType* retT = tenv.var(call); + for (ATuple::const_iterator i = call->iter_at(1); true; ++i) { + ATuple::const_iterator next = i; + ++next; + if (next == call->end()) { // final (else) expression + c.constrain(tenv, *i, retT); + break; + } else { + c.constrain(tenv, *i, tenv.named("Bool")); + c.constrain(tenv, *next, retT); } - consT.head->loc = exp->loc; - type.push_back(consT); - tenv.def(tag, consT); + i = next; // jump 2 each iteration (to the next predicate) } - tenv.def(sym, type); } -void -ACons::constrain(TEnv& tenv, Constraints& c) const throw(Error) +static void +constrain_cons(TEnv& tenv, Constraints& c, const ACall* call) throw(Error) { - const ASymbol* sym = (*begin())->as<const ASymbol*>(); + const ASymbol* sym = (*call->begin())->as<const ASymbol*>(); const AType* type = NULL; - for (const_iterator i = iter_at(1); i != end(); ++i) + for (ATuple::const_iterator i = call->iter_at(1); i != call->end(); ++i) (*i)->constrain(tenv, c); if (sym->cppstr == "Tup") { - TList tupT(new AType(tenv.Tup, NULL, loc)); - for (const_iterator i = iter_at(1); i != end(); ++i) { + TList tupT(new AType(tenv.Tup, NULL, call->loc)); + for (ATuple::const_iterator i = call->iter_at(1); i != call->end(); ++i) { tupT.push_back(const_cast<AType*>(tenv.var(*i))); } type = tupT; } else { const AType** consTRef = tenv.ref(sym); - THROW_IF(!consTRef, loc, (format("call to undefined constructor `%1%'") % sym->cppstr).str()); + THROW_IF(!consTRef, call->loc, + (format("call to undefined constructor `%1%'") % sym->cppstr).str()); const AType* consT = *consTRef; - type = tup<AType>(loc, const_cast<AType*>(consT->head()->as<const AType*>()), 0); + type = tup<AType>(call->loc, const_cast<AType*>(consT->head()->as<const AType*>()), 0); } - c.constrain(tenv, this, type); + c.constrain(tenv, call, type); } -void -ADot::constrain(TEnv& tenv, Constraints& c) const throw(Error) +static void +constrain_dot(TEnv& tenv, Constraints& c, const ACall* call) throw(Error) { - THROW_IF(list_len() != 3, loc, "`.' requires exactly 2 arguments"); - const_iterator i = begin(); + THROW_IF(call->list_len() != 3, call->loc, "`.' requires exactly 2 arguments"); + ATuple::const_iterator i = call->begin(); const AST* obj = *++i; const ALiteral<int32_t>* idx = (*++i)->to<const ALiteral<int32_t>*>(); - THROW_IF(!idx, loc, "the 2nd argument to `.' must be a literal integer"); + THROW_IF(!idx, call->loc, "the 2nd argument to `.' must be a literal integer"); obj->constrain(tenv, c); - const AType* retT = tenv.var(this); - c.constrain(tenv, this, retT); + const AType* retT = tenv.var(call); + c.constrain(tenv, call, retT); - TList objT(new AType(tenv.Tup, NULL, loc)); + TList objT(new AType(tenv.Tup, NULL, call->loc)); for (int i = 0; i < idx->val; ++i) objT.push_back(const_cast<AType*>(tenv.var())); objT.push_back(const_cast<AType*>(retT)); @@ -292,6 +260,70 @@ ADot::constrain(TEnv& tenv, Constraints& c) const throw(Error) c.constrain(tenv, obj, objT); } +static void +constrain_quote(TEnv& tenv, Constraints& c, const ACall* call) throw(Error) +{ + c.constrain(tenv, call, tenv.named("Quote")); + call->list_ref(1)->constrain(tenv, c); +} + +static void +constrain_call(TEnv& tenv, Constraints& c, const ACall* call) throw(Error) +{ + const AST* const head = call->head(); + + for (ATuple::const_iterator i = call->begin(); i != call->end(); ++i) + (*i)->constrain(tenv, c); + + const AType* fnType = tenv.var(head); + if (fnType->kind != AType::VAR) { + if (fnType->kind == AType::PRIM + || fnType->list_len() < 2 + || fnType->head()->str() != "Fn") + throw Error(call->loc, (format("call to non-function `%1%'") % head->str()).str()); + + size_t numArgs = fnType->prot()->list_len(); + THROW_IF(numArgs != call->list_len() - 1, call->loc, + (format("expected %1% arguments, got %2%") % numArgs % (call->list_len() - 1)).str()); + } + + const AType* retT = tenv.var(call); + TList argsT; + for (ATuple::const_iterator i = call->iter_at(1); i != call->end(); ++i) + argsT.push_back(const_cast<AType*>(tenv.var(*i))); + argsT.head->loc = call->loc; + c.constrain(tenv, head, tup<AType>(head->loc, tenv.Fn, argsT.head, retT, 0)); + c.constrain(tenv, call, retT); +} + +void +ACall::constrain(TEnv& tenv, Constraints& c) const throw(Error) +{ + const ASymbol* const sym = head()->to<const ASymbol*>(); + if (!sym) { + constrain_call(tenv, c, this); + return; + } + + const std::string form = sym->cppstr; + if (form == "def") + constrain_def(tenv, c, this); + else if (form == "def-type") + constrain_def_type(tenv, c, this); + else if (form == "match") + constrain_match(tenv, c, this); + else if (form == "if") + constrain_if(tenv, c, this); + else if (form == "cons" || isupper(form[0])) + constrain_cons(tenv, c, this); + else if (form == ".") + constrain_dot(tenv, c, this); + else if (form == "quote") + constrain_quote(tenv, c, this); + else + constrain_call(tenv, c, this); +} + void APrimitive::constrain(TEnv& tenv, Constraints& c) const throw(Error) { @@ -308,7 +340,7 @@ APrimitive::constrain(TEnv& tenv, Constraints& c) const throw(Error) else throw Error(loc, (format("unknown primitive `%1%'") % n).str()); - const_iterator i = begin(); + ATuple::const_iterator i = begin(); for (++i; i != end(); ++i) (*i)->constrain(tenv, c); |