/* Tuplr Type Inferencing
 * Copyright (C) 2008-2009 David Robillard
 *
 * Tuplr is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Tuplr is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Tuplr. If not, see <http://www.gnu.org/licenses/>.
 */

// NOTE(review): in this copy of the file, text inside angle brackets appears
// to have been stripped: the first #include has no header name, `set` has no
// element type, and the to()/as() calls have no template arguments (the
// originals were presumably e.g. `#include <set>` and `to<ASymbol*>()`).
// Restore them from the canonical source before building.
#include
#include "tuplr.hpp"

/** Record the constraint `var(o) = t`, tagged with o's source location.
 *
 * @param tenv  Type environment that supplies the type variable for @a o.
 * @param o     Expression being constrained.
 * @param t     Type that @a o's type variable must equal.
 */
void Constraints::constrain(TEnv& tenv, const AST* o, AType* t)
{
	// NOTE(review): the to() target type is missing in this copy; presumably
	// this asserts that @a o is not itself a type node -- confirm.
	assert(!o->to());
	push_back(Constraint(tenv.var(o), t, o->loc));
}


/***************************************************************************
 * AST Type Constraints                                                    *
 ***************************************************************************/

/** Constrain a symbol reference to the type it is bound to in @a tenv.
 *
 * The lookup result is stored in the member `addr` (assigned from this
 * const method, so presumably declared `mutable` -- confirm in tuplr.hpp).
 *
 * @throw Error if the symbol is not bound in @a tenv.
 */
void ASymbol::constrain(TEnv& tenv, Constraints& c) const
{
	addr = tenv.lookup(this);
	if (!addr)
		throw Error(loc, (format("undefined symbol `%1%'") % cppstr).str());
	// Bindings built in this file are (AST*, AType*) pairs, so .second of
	// the dereferenced binding is presumably the bound type.
	c.push_back(Constraint(tenv.var(this), tenv.deref(addr).second, loc));
}

/** Constrain a tuple expression.
 *
 * Each element is constrained first. The tuple's type variable is then
 * equated with the tuple type formed from the elements' type variables,
 * in element order.
 */
void ATuple::constrain(TEnv& tenv, Constraints& c) const
{
	AType* t = tup(loc, NULL);
	FOREACH(const_iterator, p, *this) {
		(*p)->constrain(tenv, c);
		t->push_back(tenv.var(*p));
	}
	c.push_back(Constraint(tenv.var(this), t, loc));
}

/** Constrain a function expression.
 *
 * Layout as used below: at(1) is the parameter list (presumably the same
 * node prot() returns), and at(2) .. at(size()-1) are the body expressions.
 * The type of the last body expression is the function's result type.
 *
 * The first time a given AFn is seen, its body is constrained into a local
 * constraint set `cp`, which is unified on its own. The resulting type
 * `(Fn <param types> <body type>)` is cached in tenv.genericTypes and passed
 * to Object::pool.addRoot. Later calls reuse the cached type. In both cases
 * the function's own type variable is constrained, in @a c, to a copy of
 * that type.
 *
 * @throw Error if a parameter is not a symbol, if a parameter name is
 *        repeated, or if an internal `def' reuses an already-defined name.
 */
void AFn::constrain(TEnv& tenv, Constraints& c) const
{
	const AType* genericType;
	TEnv::GenericTypes::const_iterator gt = tenv.genericTypes.find(this);
	if (gt != tenv.genericTypes.end()) {
		// Already inferred: reuse the cached type.
		genericType = gt->second;
	} else {
		// Symbols bound so far in this scope, used for duplicate detection
		// (compares ASymbol pointers -- presumably symbols are interned).
		set defined;
		TEnv::Frame frame;

		// Add parameters to environment frame
		for (size_t i = 0; i < prot()->size(); ++i) {
			ASymbol* sym = prot()->at(i)->to();
			if (!sym)
				throw Error(prot()->at(i)->loc, "parameter name is not a symbol");
			if (defined.find(sym) != defined.end())
				throw Error(sym->loc, (format("duplicate parameter `%1%'") % sym->str()).str());
			defined.insert(sym);
			// Placeholder binding: no value and no type yet.
			frame.push_back(make_pair(sym, make_pair((AST*)NULL, (AType*)NULL)));
		}

		// Add internal definitions to environment frame
		// (scans the body from at(2); when the loop ends, e == size()).
		size_t e = 2;
		for (; e < size(); ++e) {
			AST* exp = at(e);
			ADef* def = exp->to();
			if (def) {
				ASymbol* sym = def->sym();
				if (defined.find(sym) != defined.end())
					throw Error(def->loc, (format("`%1%' defined twice") % sym->str()).str());
				defined.insert(def->sym());
				frame.push_back(make_pair(def->sym(), make_pair(def->at(2), (AType*)NULL)));
			}
		}

		tenv.push(frame);

		// The body's constraints are collected separately so they can be
		// unified for this function alone.
		Constraints cp;
		cp.push_back(Constraint(tenv.var(this), tenv.var(), loc));

		// Give each parameter a fresh type variable; protT lists them in order.
		AType* protT = tup(loc, NULL);
		for (size_t i = 0; i < prot()->size(); ++i) {
			AType* tvar = tenv.fresh(prot()->at(i)->to());
			protT->push_back(tvar);
			assert(frame[i].first == prot()->at(i));
			// NOTE(review): `frame` was already passed to tenv.push() above.
			// If push() stores a copy, these two writes only update this
			// local copy and never reach tenv -- confirm.
			frame[i].second.first = prot()->at(i);
			frame[i].second.second = tvar;
		}
		// NOTE(review): this parameter-list constraint is pushed to the
		// caller's set `c`, not `cp`, so it is not part of the unification
		// below -- confirm this is intended.
		c.push_back(Constraint(tenv.var(at(1)), protT, at(1)->loc));

		for (size_t i = 2; i < size(); ++i)
			at(i)->constrain(tenv, cp);

		// Result type is the type of the last body expression.
		// NOTE(review): with no body (size() == 2), e - 1 == 1, so this picks
		// the parameter list and no error is raised.
		AType* bodyT = tenv.var(at(e-1));
		Subst tsubst = TEnv::unify(cp);
		genericType = tup(loc, tenv.penv.sym("Fn"), tsubst.apply(protT), tsubst.apply(bodyT), 0);
		tenv.genericTypes.insert(make_pair(this, genericType));
		// Presumably registers the cached type as a GC root so it outlives
		// this call.
		Object::pool.addRoot(genericType);
		tenv.pop();
		// Member assigned from a const method (presumably `mutable`).
		subst = tsubst;
	}

	// Constrain this expression to a copy of the cached type.
	AType* t = new AType(*genericType); // FIXME: deep copy
	c.constrain(tenv, this, t);
}

/** Constrain a call `(f arg1 ... argN)`.
 *
 * The callee and all arguments are constrained first. If the callee
 * resolves to an AFn, the argument count is checked against that AFn's
 * parameter list. If the AFn also has a cached type in tenv.genericTypes,
 * each argument is constrained to the matching parameter type, and the
 * callee to `(Fn <fresh> <type of this call>)`. Otherwise the callee is
 * constrained to `(Fn (<arg types>) ret)` with a fresh `ret`. Either way,
 * the call's type is the callee's return type.
 *
 * @throw Error if the callee is a known function with a different arity.
 */
void ACall::constrain(TEnv& tenv, Constraints& c) const
{
	at(0)->constrain(tenv, c);
	for (size_t i = 1; i < size(); ++i)
		at(i)->constrain(tenv, c);

	// NOTE(review): the result of resolve() is used without a null check --
	// confirm it never returns NULL.
	AST* callee = tenv.resolve(at(0));
	AFn* closure = callee->to();
	if (closure) {
		if (size() - 1 != closure->prot()->size())
			throw Error(loc, "incorrect number of arguments");
		TEnv::GenericTypes::iterator gt = tenv.genericTypes.find(closure);
		if (gt != tenv.genericTypes.end()) {
			// The cached type is built by AFn::constrain as
			// (Fn <param tuple> <body type>). Its at(1) is the parameter
			// tuple, whose (i-1)th element is the type of argument i.
			for (size_t i = 1; i < size(); ++i)
				c.constrain(tenv, at(i), gt->second->at(1)->as()->at(i-1)->as());
			AType* retT = tenv.var(this);
			c.constrain(tenv, at(0), tup(at(0)->loc, tenv.penv.sym("Fn"), tenv.var(), retT, 0));
			// NOTE(review): retT is var(this) here, so this equates var(this)
			// with itself and appears redundant.
			c.constrain(tenv, this, retT);
			return;
		}
	}

	// Callee type unknown here: require it to be a function from the
	// argument types to a fresh return type.
	AType* argsT = tup(loc, 0);
	for (size_t i = 1; i < size(); ++i)
		argsT->push_back(tenv.var(at(i)));
	AType* retT = tenv.var();
	c.constrain(tenv, at(0), tup(at(0)->loc, tenv.penv.sym("Fn"), argsT, retT, 0));
	c.constrain(tenv, this, retT);
}

/** Constrain a definition `(def sym value)`.
 *
 * The symbol is bound in @a tenv to (value, type variable of value) before
 * the value itself is constrained. A reference to the symbol inside the
 * value (e.g. recursion) therefore presumably sees that binding. The def's
 * own type is the value's type.
 *
 * @throw Error if the form does not have exactly 2 arguments or has no symbol.
 */
void ADef::constrain(TEnv& tenv, Constraints& c) const
{
	THROW_IF(size() != 3, loc, "`def' requires exactly 2 arguments");
	const ASymbol* sym = this->sym();
	// NOTE(review): no trailing ';' after the next THROW_IF; presumably the
	// macro expands to a complete statement -- confirm.
	THROW_IF(!sym, loc, "`def' has no symbol")
	AType* tvar = tenv.var(at(2));
	tenv.def(sym, make_pair(at(2), tvar));
	at(2)->constrain(tenv, c);
	c.constrain(tenv, this, tvar);
}

/** Constrain a conditional `(if c1 e1 c2 e2 ... else)`.
 *
 * Every condition must be Bool. Every branch, and the final else, must
 * have the type of the whole if-expression.
 *
 * @throw Error if there are fewer than 3 arguments or the final else is missing.
 */
void AIf::constrain(TEnv& tenv, Constraints& c) const
{
	THROW_IF(size() < 4, loc, "`if' requires at least 3 arguments");
	THROW_IF(size() % 2 != 0, loc, "`if' missing final else clause")
	for (size_t i = 1; i < size(); ++i)
		at(i)->constrain(tenv, c);

	AType* retT = tenv.var(this);
	// Walk (condition, branch) pairs. Because size() is even, the last index
	// visited is size() - 1, which is the else expression.
	for (size_t i = 1; i < size(); i += 2) {
		if (i == size() - 1) {
			c.constrain(tenv, at(i), retT);
		} else {
			c.constrain(tenv, at(i), tenv.named("Bool"));
			c.constrain(tenv, at(i+1), retT);
		}
	}
}

/** Constrain a primitive operator application `(op arg...)`.
 *
 * Operators are classified by the name in at(0):
 *  - ARITHMETIC (`+`, `-`, `*`, `/`): 2 or more arguments, each with the
 *    result's type.
 *  - BINARY (`%`): exactly 2 arguments, both with the result's type.
 *  - LOGICAL (`and`, `or`, `xor`): exactly 2 Bool arguments, Bool result.
 *  - COMPARISON (`=`, `!=`, `>`, `>=`, `<`, `<=`): exactly 2 arguments of
 *    the same type, Bool result.
 *
 * @throw Error for an unknown operator name or a wrong argument count.
 */
void APrimitive::constrain(TEnv& tenv, Constraints& c) const
{
	// NOTE(review): no null check -- assumes at(0) is a symbol.
	const string n = at(0)->to()->str();
	enum { ARITHMETIC, BINARY, LOGICAL, COMPARISON } type;
	if (n == "+" || n == "-" || n == "*" || n == "/")
		type = ARITHMETIC;
	else if (n == "%")
		type = BINARY;
	else if (n == "and" || n == "or" || n == "xor")
		type = LOGICAL;
	else if (n == "=" || n == "!=" || n == ">" || n == ">=" || n == "<" || n == "<=")
		type = COMPARISON;
	else
		throw Error(loc, (format("unknown primitive `%1%'") % n).str());

	for (size_t i = 1; i < size(); ++i)
		at(i)->constrain(tenv, c);

	switch (type) {
	case ARITHMETIC:
		if (size() < 3)
			throw Error(loc, (format("`%1%' requires at least 2 arguments") % n).str());
		for (size_t i = 1; i < size(); ++i)
			c.constrain(tenv, at(i), tenv.var(this));
		break;
	case BINARY:
		if (size() != 3)
			throw Error(loc, (format("`%1%' requires exactly 2 arguments") % n).str());
		c.constrain(tenv, at(1), tenv.var(this));
		c.constrain(tenv, at(2), tenv.var(this));
		break;
	case LOGICAL:
		if (size() != 3)
			throw Error(loc, (format("`%1%' requires exactly 2 arguments") % n).str());
		c.constrain(tenv, this, tenv.named("Bool"));
		c.constrain(tenv, at(1), tenv.named("Bool"));
		c.constrain(tenv, at(2), tenv.named("Bool"));
		break;
	case COMPARISON:
		if (size() != 3)
			throw Error(loc, (format("`%1%' requires exactly 2 arguments") % n).str());
		c.constrain(tenv, this, tenv.named("Bool"));
		c.constrain(tenv, at(1), tenv.var(at(2)));
		break;
	default:
		// Unreachable: unknown names have already thrown above.
		throw Error(loc, (format("unknown primitive `%1%'") % n).str());
	}
}