aboutsummaryrefslogtreecommitdiffstats
path: root/src/constrain.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/constrain.cpp')
-rw-r--r--src/constrain.cpp131
1 files changed, 84 insertions, 47 deletions
diff --git a/src/constrain.cpp b/src/constrain.cpp
index 78accf6..83a027a 100644
--- a/src/constrain.cpp
+++ b/src/constrain.cpp
@@ -35,26 +35,57 @@ constrain_symbol(TEnv& tenv, Constraints& c, const ASymbol* sym) throw(Error)
static void
constrain_cons(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error)
{
- const ASymbol* sym = (*call->begin())->as_symbol();
- const AST* type = NULL;
+ const ASymbol* name = (*call->begin())->as_symbol();
+ // Constrain each argument
for (ATuple::const_iterator i = call->iter_at(1); i != call->end(); ++i)
- resp_constrain(tenv, c, *i);
+ resp_constrain(tenv, c, *i); // ::= ?Targi
- if (!strcmp(sym->sym(), "Tup")) {
- List tupT(new ATuple(tenv.Tup, NULL, call->loc));
+ if (!strcmp(name->sym(), "Tup")) {
+ // Build a type expression like (Tup ?Targ1 ...)
+ List tupT(new ATuple(name, NULL, call->loc));
for (ATuple::const_iterator i = call->iter_at(1); i != call->end(); ++i) {
tupT.push_back(tenv.var(*i));
}
- type = tupT;
+ c.constrain(tenv, call, tupT);
} else {
- const AST** consTRef = tenv.ref(sym);
- THROW_IF(!consTRef, call->loc,
- (format("call to undefined constructor `%1%'") % sym->sym()).str());
- const AST* consT = *consTRef;
- type = new ATuple(consT->as_tuple()->fst(), 0, call->loc);
+ // Look up constructor and use its type
+ TEnv::Tags::const_iterator tag = tenv.tags.find(name->str());
+ THROW_IF(tag == tenv.tags.end(), name->loc,
+ (format("undefined constructor `%1%'") % name->sym()).str());
+
+ // Build a substitution for every tvar in the constructor pattern
+ Subst subst;
+ const ATuple* expr = tag->second.expr->as_tuple();
+ ATuple::const_iterator e = expr->iter_at(1);
+ for (ATuple::const_iterator i = call->iter_at(1); i != call->end(); ++i, ++e) {
+ const ASymbol* sym = (*e)->to_symbol();
+ if (sym && !isupper(sym->str()[0])) {
+ // Argument corresponds to type variable in constructor pattern
+ subst.add(*e, tenv.var(*i));
+ }
+ }
+
+ // Substitute tvar symbols with the tvar for the corresponding argument
+ const AST* pattern = subst.apply(tag->second.type);
+
+ // Replace remaining tvar symbols with a free tvar
+ for (ATuple::const_iterator i = pattern->as_tuple()->iter_at(1);
+ i != pattern->as_tuple()->end(); ++i) {
+ const ASymbol* sym = (*i)->to_symbol();
+ if (sym && islower(sym->str()[0])) {
+ subst.add(sym, tenv.var());
+ }
+ }
+
+ // Constrain every argument to the corresponding pattern element
+ e = expr->iter_at(1);
+ for (ATuple::const_iterator i = call->iter_at(1); i != call->end(); ++i, ++e) {
+ c.constrain(tenv, *i, subst.apply(*e));
+ }
+
+ c.constrain(tenv, call, subst.apply(pattern));
}
- c.constrain(tenv, call, type);
}
static void
@@ -105,18 +136,12 @@ constrain_def_type(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error)
const ASymbol* sym = (*prot->begin())->as_symbol();
THROW_IF(!sym, (*prot->begin())->loc, "type name is not a symbol");
THROW_IF(tenv.ref(sym), call->loc, "type redefinition");
- List type(new ATuple(tenv.U, NULL, call->loc));
+ List type(call->loc, tenv.penv.sym("Lambda"), prot->rst(), NULL);
for (ATuple::const_iterator i = call->iter_at(2); i != call->end(); ++i) {
- const ATuple* exp = (*i)->as_tuple();
- const ASymbol* tag = (*exp->begin())->as_symbol();
- List consT;
- consT.push_back(sym);
- for (ATuple::const_iterator i = exp->begin(); i != exp->end(); ++i) {
- consT.push_back(*i); // FIXME: ensure symbol, or list of symbol
- }
- consT.head->loc = exp->loc;
- type.push_back(consT);
- tenv.def(tag, consT);
+ const ATuple* exp = (*i)->as_tuple();
+ const ASymbol* tag = (*exp->begin())->as_symbol();
+ tenv.tags.insert(std::make_pair(tag->str(), TEnv::Constructor(exp, prot)));
+ type.push_back(exp);
}
tenv.def(sym, type);
}
@@ -233,35 +258,42 @@ constrain_match(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error)
THROW_IF(call->list_len() < 5, call->loc, "`match' requires at least 4 arguments");
const AST* matchee = call->list_ref(1);
const AST* retT = tenv.var();
- const AST* matcheeT = NULL;
+ const AST* matcheeT = tenv.var();
resp_constrain(tenv, c, matchee);
for (ATuple::const_iterator i = call->iter_at(2); i != call->end();) {
const AST* exp = *i++;
const ATuple* pattern = exp->to_tuple();
- THROW_IF(!pattern, exp->loc, "pattern expression expected");
+ THROW_IF(!pattern, exp->loc, "missing pattern");
+ THROW_IF(i == call->end(), pattern->loc, "missing expression");
+
+ const AST* body = *i++;
const ASymbol* name = (*pattern->begin())->to_symbol();
THROW_IF(!name, (*pattern->begin())->loc, "pattern does not start with a symbol");
- THROW_IF(!tenv.ref(name), name->loc,
- (format("undefined constructor `%1%'") % name->sym()).str());
- const AST* consT = *tenv.ref(name);
+ TEnv::Tags::const_iterator tag = tenv.tags.find(name->str());
+ THROW_IF(tag == tenv.tags.end(), name->loc,
+ (format("undefined constructor `%1%'") % name->sym()).str());
- if (!matcheeT) {
- const AST* headT = consT->as_tuple()->fst();
- matcheeT = new ATuple(headT, 0, call->loc);
+ const TEnv::Constructor& constructor = tag->second;
+ TEnv::Frame frame;
+ ATuple::const_iterator ei = constructor.expr->as_tuple()->iter_at(1);
+ for (ATuple::const_iterator pi = pattern->iter_at(1); pi != pattern->end(); ++pi, ++ei) {
+ const AST* tvar = tenv.var(*pi);
+ frame.push_back(make_pair((*pi)->as_symbol()->sym(), tvar));
}
- THROW_IF(i == call->end(), pattern->loc, "missing pattern body");
- const AST* body = *i++;
-
- TEnv::Frame frame;
- ATuple::const_iterator ti = consT->as_tuple()->iter_at(2);
- for (ATuple::const_iterator pi = pattern->iter_at(1); pi != pattern->end(); ++pi)
- frame.push_back(make_pair((*pi)->as_symbol()->sym(), *ti++));
-
tenv.push(frame);
resp_constrain(tenv, c, body);
c.constrain(tenv, body, retT);
+
+ // Copy the type's prototype replacing symbols with real type variables
+ List type(matchee->loc, constructor.type->as_tuple()->fst(), NULL);
+ for (ATuple::const_iterator t = constructor.type->as_tuple()->iter_at(1);
+ t != constructor.type->as_tuple()->end(); ++t) {
+ type.push_back(tenv.var());
+ }
+
+ c.constrain(tenv, matchee, type);
tenv.pop();
}
c.constrain(tenv, call, retT);
@@ -271,16 +303,21 @@ constrain_match(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error)
static void
resp_constrain_quoted(TEnv& tenv, Constraints& c, const AST* ast) throw(Error)
{
- switch (ast->tag()) {
- case T_SYMBOL:
+ if (ast->tag() == T_SYMBOL) {
c.constrain(tenv, ast, tenv.named("Symbol"));
- return;
- case T_TUPLE:
- c.constrain(tenv, ast, tenv.named("List"));
- FOREACHP(ATuple::const_iterator, i, ast->as_tuple())
+ } else if (ast->tag() == T_TUPLE) {
+ List tupT(new ATuple(tenv.List, NULL, ast->loc));
+ const ATuple* tup = ast->as_tuple();
+ const AST* fstT = tenv.var(tup->fst());
+
+ tupT.push_back(tenv.penv.sym("Expr"));
+ c.constrain(tenv, ast, tupT);
+ c.constrain(tenv, tup->fst(), fstT);
+ FOREACHP(ATuple::const_iterator, i, ast->as_tuple()) {
resp_constrain_quoted(tenv, c, *i);
- return;
- default:
+ }
+
+ } else {
resp_constrain(tenv, c, ast);
}
}