diff options
Diffstat (limited to 'src/constrain.cpp')
-rw-r--r-- | src/constrain.cpp | 131 |
1 files changed, 84 insertions, 47 deletions
diff --git a/src/constrain.cpp b/src/constrain.cpp index 78accf6..83a027a 100644 --- a/src/constrain.cpp +++ b/src/constrain.cpp @@ -35,26 +35,57 @@ constrain_symbol(TEnv& tenv, Constraints& c, const ASymbol* sym) throw(Error) static void constrain_cons(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error) { - const ASymbol* sym = (*call->begin())->as_symbol(); - const AST* type = NULL; + const ASymbol* name = (*call->begin())->as_symbol(); + // Constrain each argument for (ATuple::const_iterator i = call->iter_at(1); i != call->end(); ++i) - resp_constrain(tenv, c, *i); + resp_constrain(tenv, c, *i); // ::= ?Targi - if (!strcmp(sym->sym(), "Tup")) { - List tupT(new ATuple(tenv.Tup, NULL, call->loc)); + if (!strcmp(name->sym(), "Tup")) { + // Build a type expression like (Tup ?Targ1 ...) + List tupT(new ATuple(name, NULL, call->loc)); for (ATuple::const_iterator i = call->iter_at(1); i != call->end(); ++i) { tupT.push_back(tenv.var(*i)); } - type = tupT; + c.constrain(tenv, call, tupT); } else { - const AST** consTRef = tenv.ref(sym); - THROW_IF(!consTRef, call->loc, - (format("call to undefined constructor `%1%'") % sym->sym()).str()); - const AST* consT = *consTRef; - type = new ATuple(consT->as_tuple()->fst(), 0, call->loc); + // Look up constructor and use its type + TEnv::Tags::const_iterator tag = tenv.tags.find(name->str()); + THROW_IF(tag == tenv.tags.end(), name->loc, + (format("undefined constructor `%1%'") % name->sym()).str()); + + // Build a substitution for every tvar in the constructor pattern + Subst subst; + const ATuple* expr = tag->second.expr->as_tuple(); + ATuple::const_iterator e = expr->iter_at(1); + for (ATuple::const_iterator i = call->iter_at(1); i != call->end(); ++i, ++e) { + const ASymbol* sym = (*e)->to_symbol(); + if (sym && !isupper(sym->str()[0])) { + // Argument corresponds to type variable in constructor pattern + subst.add(*e, tenv.var(*i)); + } + } + + // Substitute tvar symbols with the tvar for the corresponding argument + const AST* pattern = subst.apply(tag->second.type); + + // Replace remaining tvar symbols with a free tvar + for (ATuple::const_iterator i = pattern->as_tuple()->iter_at(1); + i != pattern->as_tuple()->end(); ++i) { + const ASymbol* sym = (*i)->to_symbol(); + if (sym && islower(sym->str()[0])) { + subst.add(sym, tenv.var()); + } + } + + // Constrain every argument to the corresponding pattern element + e = expr->iter_at(1); + for (ATuple::const_iterator i = call->iter_at(1); i != call->end(); ++i, ++e) { + c.constrain(tenv, *i, subst.apply(*e)); + } + + c.constrain(tenv, call, subst.apply(pattern)); } - c.constrain(tenv, call, type); } static void @@ -105,18 +136,12 @@ constrain_def_type(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error) const ASymbol* sym = (*prot->begin())->as_symbol(); THROW_IF(!sym, (*prot->begin())->loc, "type name is not a symbol"); THROW_IF(tenv.ref(sym), call->loc, "type redefinition"); - List type(new ATuple(tenv.U, NULL, call->loc)); + List type(call->loc, tenv.penv.sym("Lambda"), prot->rst(), NULL); for (ATuple::const_iterator i = call->iter_at(2); i != call->end(); ++i) { - const ATuple* exp = (*i)->as_tuple(); - const ASymbol* tag = (*exp->begin())->as_symbol(); - List consT; - consT.push_back(sym); - for (ATuple::const_iterator i = exp->begin(); i != exp->end(); ++i) { - consT.push_back(*i); // FIXME: ensure symbol, or list of symbol - } - consT.head->loc = exp->loc; - type.push_back(consT); - tenv.def(tag, consT); + const ATuple* exp = (*i)->as_tuple(); + const ASymbol* tag = (*exp->begin())->as_symbol(); + tenv.tags.insert(std::make_pair(tag->str(), TEnv::Constructor(exp, prot))); + type.push_back(exp); } tenv.def(sym, type); } @@ -233,35 +258,42 @@ constrain_match(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error) THROW_IF(call->list_len() < 5, call->loc, "`match' requires at least 4 arguments"); const AST* matchee = call->list_ref(1); const AST* retT = tenv.var(); - const AST* matcheeT = NULL; + const AST* matcheeT = tenv.var(); resp_constrain(tenv, c, matchee); for (ATuple::const_iterator i = call->iter_at(2); i != call->end();) { const AST* exp = *i++; const ATuple* pattern = exp->to_tuple(); - THROW_IF(!pattern, exp->loc, "pattern expression expected"); + THROW_IF(!pattern, exp->loc, "missing pattern"); + THROW_IF(i == call->end(), pattern->loc, "missing expression"); + + const AST* body = *i++; const ASymbol* name = (*pattern->begin())->to_symbol(); THROW_IF(!name, (*pattern->begin())->loc, "pattern does not start with a symbol"); - THROW_IF(!tenv.ref(name), name->loc, - (format("undefined constructor `%1%'") % name->sym()).str()); - const AST* consT = *tenv.ref(name); + TEnv::Tags::const_iterator tag = tenv.tags.find(name->str()); + THROW_IF(tag == tenv.tags.end(), name->loc, + (format("undefined constructor `%1%'") % name->sym()).str()); - if (!matcheeT) { - const AST* headT = consT->as_tuple()->fst(); - matcheeT = new ATuple(headT, 0, call->loc); + const TEnv::Constructor& constructor = tag->second; + TEnv::Frame frame; + ATuple::const_iterator ei = constructor.expr->as_tuple()->iter_at(1); + for (ATuple::const_iterator pi = pattern->iter_at(1); pi != pattern->end(); ++pi, ++ei) { + const AST* tvar = tenv.var(*pi); + frame.push_back(make_pair((*pi)->as_symbol()->sym(), tvar)); } - THROW_IF(i == call->end(), pattern->loc, "missing pattern body"); - const AST* body = *i++; - - TEnv::Frame frame; - ATuple::const_iterator ti = consT->as_tuple()->iter_at(2); - for (ATuple::const_iterator pi = pattern->iter_at(1); pi != pattern->end(); ++pi) - frame.push_back(make_pair((*pi)->as_symbol()->sym(), *ti++)); - tenv.push(frame); resp_constrain(tenv, c, body); c.constrain(tenv, body, retT); + + // Copy the type's prototype replacing symbols with real type variables + List type(matchee->loc, constructor.type->as_tuple()->fst(), NULL); + for (ATuple::const_iterator t = constructor.type->as_tuple()->iter_at(1); + t != constructor.type->as_tuple()->end(); ++t) { + type.push_back(tenv.var()); + } + + c.constrain(tenv, matchee, type); tenv.pop(); } c.constrain(tenv, call, retT); @@ -271,16 +303,21 @@ constrain_match(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error) static void resp_constrain_quoted(TEnv& tenv, Constraints& c, const AST* ast) throw(Error) { - switch (ast->tag()) { - case T_SYMBOL: + if (ast->tag() == T_SYMBOL) { c.constrain(tenv, ast, tenv.named("Symbol")); - return; - case T_TUPLE: - c.constrain(tenv, ast, tenv.named("List")); - FOREACHP(ATuple::const_iterator, i, ast->as_tuple()) + } else if (ast->tag() == T_TUPLE) { + List tupT(new ATuple(tenv.List, NULL, ast->loc)); + const ATuple* tup = ast->as_tuple(); + const AST* fstT = tenv.var(tup->fst()); + + tupT.push_back(tenv.penv.sym("Expr")); + c.constrain(tenv, ast, tupT); + c.constrain(tenv, tup->fst(), fstT); + FOREACHP(ATuple::const_iterator, i, ast->as_tuple()) { resp_constrain_quoted(tenv, c, *i); - return; - default: + } + + } else { resp_constrain(tenv, c, ast); } } |