/* Tuplr Type Inferencing * Copyright (C) 2008-2009 David Robillard * * Tuplr is free software: you can redistribute it and/or modify it under * the terms of the GNU Affero General Public License as published by the * Free Software Foundation, either version 3 of the License, or (at your * option) any later version. * * Tuplr is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General * Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with Tuplr. If not, see . */ #include #include "tuplr.hpp" void Constraints::constrain(TEnv& tenv, const AST* o, AType* t) { assert(!o->to()); push_back(Constraint(tenv.var(o), t, o->loc)); } /*************************************************************************** * AST Type Constraints * ***************************************************************************/ void ASymbol::constrain(TEnv& tenv, Constraints& c) const { addr = tenv.lookup(this); if (!addr) throw Error((format("undefined symbol `%1%'") % cppstr).str(), loc); pair& t = tenv.deref(addr); AType* tvar = tenv.var(t.second); c.push_back(Constraint(tenv.var(this), tvar, loc)); c.push_back(Constraint(t.second, tvar, loc)); } void ATuple::constrain(TEnv& tenv, Constraints& c) const { AType* t = new AType(loc, NULL); FOREACH(const_iterator, p, *this) { (*p)->constrain(tenv, c); t->push_back(tenv.var(*p)); } c.push_back(Constraint(tenv.var(this), t, loc)); } void AClosure::constrain(TEnv& tenv, Constraints& c) const { const AType* genericType; TEnv::GenericTypes::const_iterator gt = tenv.genericTypes.find(this); if (gt != tenv.genericTypes.end()) { genericType = gt->second; } else { set defined; TEnv::Frame frame; // Add parameters to environment frame for (size_t i = 0; i < prot()->size(); ++i) { ASymbol* sym = prot()->at(i)->to(); if (!sym) throw Error("parameter name is not a symbol", prot()->at(i)->loc); if (defined.find(sym) != defined.end()) throw Error((format("duplicate parameter `%1%'") % sym->str()).str(), sym->loc); defined.insert(sym); frame.push_back(make_pair(sym, make_pair((AST*)NULL, (AType*)NULL))); } // Add internal definitions to environment frame size_t e = 2; for (; e < size(); ++e) { AST* exp = at(e); ADefinition* def = exp->to(); if (def) { ASymbol* sym = def->sym(); if (defined.find(sym) != defined.end()) throw Error((format("`%1%' defined twice") % sym->str()).str(), def->loc); defined.insert(def->sym()); frame.push_back(make_pair(def->sym(), make_pair(def->at(2), (AType*)NULL))); } } tenv.push(frame); Constraints cp; cp.push_back(Constraint(tenv.var(this), tenv.var(), loc)); AType* protT = new AType(loc, NULL); for (size_t i = 0; i < prot()->size(); ++i) { AType* tvar = tenv.fresh(prot()->at(i)->to()); protT->push_back(tvar); assert(frame[i].first == prot()->at(i)); frame[i].second.first = prot()->at(i); frame[i].second.second = tvar; } c.push_back(Constraint(tenv.var(at(1)), protT, at(1)->loc)); for (size_t i = 2; i < size(); ++i) at(i)->constrain(tenv, cp); AType* bodyT = tenv.var(at(e-1)); Subst tsubst = TEnv::unify(cp); genericType = new AType(loc, tenv.penv.sym("Fn"), tsubst.apply(protT), tsubst.apply(bodyT), 0); tenv.genericTypes.insert(make_pair(this, genericType)); tenv.pop(); subst = new Subst(tsubst); } c.constrain(tenv, this, new AType(*genericType)); } void ACall::constrain(TEnv& tenv, Constraints& c) const { at(0)->constrain(tenv, c); for (size_t i = 1; i < size(); ++i) at(i)->constrain(tenv, c); AST* callee = tenv.resolve(at(0)); AClosure* closure = callee->to(); if (closure) { if (size() - 1 != closure->prot()->size()) throw Error("incorrect number of arguments", loc); TEnv::GenericTypes::iterator gt = tenv.genericTypes.find(closure); if (gt != tenv.genericTypes.end()) { for (size_t i = 1; i < size(); ++i) c.constrain(tenv, at(i), gt->second->at(1)->as()->at(i-1)->as()); AType* retT = tenv.var(this); c.constrain(tenv, at(0), new AType(at(0)->loc, tenv.penv.sym("Fn"), tenv.var(), retT, 0)); c.constrain(tenv, this, retT); return; } } AType* argsT = new AType(loc, NULL); for (size_t i = 1; i < size(); ++i) argsT->push_back(tenv.var(at(i))); AType* retT = tenv.var(); c.constrain(tenv, at(0), new AType(at(0)->loc, tenv.penv.sym("Fn"), argsT, retT, 0)); c.constrain(tenv, this, retT); } void ADefinition::constrain(TEnv& tenv, Constraints& c) const { if (size() != 3) throw Error("`def' requires exactly 2 arguments", loc); const ASymbol* sym = at(1)->to(); if (!sym) throw Error("`def' name is not a symbol", loc); AType* tvar = tenv.var(at(2)); tenv.def(sym, make_pair(at(2), tvar)); at(2)->constrain(tenv, c); c.constrain(tenv, this, tvar); } void AIf::constrain(TEnv& tenv, Constraints& c) const { if (size() < 3) throw Error("`if' requires exactly 3 arguments", loc); if (size() % 2 != 0) throw Error("`if' missing final else clause", loc); for (size_t i = 1; i < size(); ++i) at(i)->constrain(tenv, c); AType* retT = tenv.var(this); for (size_t i = 1; i < size(); i += 2) { if (i == size() - 1) { c.constrain(tenv, at(i), retT); } else { c.constrain(tenv, at(i), tenv.named("Bool")); c.constrain(tenv, at(i+1), retT); } } } void APrimitive::constrain(TEnv& tenv, Constraints& c) const { const string n = at(0)->to()->str(); enum { ARITHMETIC, BINARY, LOGICAL, COMPARISON } type; if (n == "+" || n == "-" || n == "*" || n == "/") type = ARITHMETIC; else if (n == "%") type = BINARY; else if (n == "and" || n == "or" || n == "xor") type = LOGICAL; else if (n == "=" || n == "!=" || n == ">" || n == ">=" || n == "<" || n == "<=") type = COMPARISON; else throw Error((format("unknown primitive `%1%'") % n).str(), loc); for (size_t i = 1; i < size(); ++i) at(i)->constrain(tenv, c); switch (type) { case ARITHMETIC: if (size() < 3) throw Error((format("`%1%' requires at least 2 arguments") % n).str(), loc); for (size_t i = 1; i < size(); ++i) c.constrain(tenv, at(i), tenv.var(this)); break; case BINARY: if (size() != 3) throw Error((format("`%1%' requires exactly 2 arguments") % n).str(), loc); c.constrain(tenv, at(1), tenv.var(this)); c.constrain(tenv, at(2), tenv.var(this)); break; case LOGICAL: if (size() != 3) throw Error((format("`%1%' requires exactly 2 arguments") % n).str(), loc); c.constrain(tenv, this, tenv.named("Bool")); c.constrain(tenv, at(1), tenv.named("Bool")); c.constrain(tenv, at(2), tenv.named("Bool")); break; case COMPARISON: if (size() != 3) throw Error((format("`%1%' requires exactly 2 arguments") % n).str(), loc); c.constrain(tenv, this, tenv.named("Bool")); c.constrain(tenv, at(1), tenv.var(at(2))); break; default: throw Error((format("unknown primitive `%1%'") % n).str(), loc); } } void AConsCall::constrain(TEnv& tenv, Constraints& c) const { if (size() != 3) throw Error("`cons' requires exactly 2 arguments", loc); AType* t = new AType(loc, tenv.penv.sym("Pair"), 0); for (size_t i = 1; i < size(); ++i) { at(i)->constrain(tenv, c); t->push_back(tenv.var(at(i))); } c.constrain(tenv, this, t); } void ACarCall::constrain(TEnv& tenv, Constraints& c) const { if (size() != 2) throw Error("`car' requires exactly 1 argument", loc); at(1)->constrain(tenv, c); AType* carT = tenv.var(this); AType* pairT = new AType(at(1)->loc, tenv.penv.sym("Pair"), carT, tenv.var(), 0); c.constrain(tenv, at(1), pairT); c.constrain(tenv, this, carT); } void ACdrCall::constrain(TEnv& tenv, Constraints& c) const { if (size() != 2) throw Error("`cdr' requires exactly 1 argument", loc); at(1)->constrain(tenv, c); AType* cdrT = tenv.var(this); AType* pairT = new AType(at(1)->loc, tenv.penv.sym("Pair"), tenv.var(), cdrT, 0); c.constrain(tenv, at(1), pairT); c.constrain(tenv, this, cdrT); } /*************************************************************************** * Type Inferencing/Substitution * ***************************************************************************/ static void substitute(ATuple* tup, const AST* from, AST* to) { if (!tup) return; for (size_t i = 0; i < tup->size(); ++i) if (*tup->at(i) == *from) tup->at(i) = to; else if (tup->at(i) != to) substitute(tup->at(i)->to(), from, to); } Subst Subst::compose(const Subst& delta, const Subst& gamma) // TAPL 22.1.1 { Subst r; for (Subst::const_iterator g = gamma.begin(); g != gamma.end(); ++g) { Subst::const_iterator d = delta.find(g->second); r.insert(make_pair(g->first, ((d != delta.end()) ? d : g)->second)); } for (Subst::const_iterator d = delta.begin(); d != delta.end(); ++d) { if (gamma.find(d->first) == gamma.end()) r.insert(*d); } return r; } void substConstraints(Constraints& constraints, AType* s, AType* t) { for (Constraints::iterator c = constraints.begin(); c != constraints.end();) { Constraints::iterator next = c; ++next; if (*c->first == *s) c->first = t; if (*c->second == *s) c->second = t; substitute(c->first, s, t); substitute(c->second, s, t); c = next; } } Subst TEnv::unify(const Constraints& constraints) // TAPL 22.4 { if (constraints.empty()) return Subst(); AType* s = constraints.begin()->first; AType* t = constraints.begin()->second; Constraints cp = constraints; cp.erase(cp.begin()); if (*s == *t) { return unify(cp); } else if (s->var() && !t->contains(s)) { substConstraints(cp, s, t); return Subst::compose(unify(cp), Subst(s, t)); } else if (t->var() && !s->contains(t)) { substConstraints(cp, t, s); return Subst::compose(unify(cp), Subst(t, s)); } else if (s->kind == AType::EXPR && s->kind == t->kind && s->size() == t->size()) { for (size_t i = 0; i < s->size(); ++i) { AType* si = s->at(i)->to(); AType* ti = t->at(i)->to(); if (si && ti) cp.push_back(Constraint(si, ti, si->loc)); } return unify(cp); } else { throw Error((format("type is `%1%' but should be `%2%'") % s->str() % t->str()).str(), s->loc ? s->loc : t->loc); } }