/* Tuplr Type Inferencing
 * Copyright (C) 2008-2009 David Robillard
 *
 * Tuplr is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Tuplr is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Tuplr.  If not, see <http://www.gnu.org/licenses/>.
 */

#include "tuplr.hpp"

/***************************************************************************
 * AST Type Constraints                                                    *
 ***************************************************************************/

/** Constrain a tuple expression.
 *
 * Every element is constrained first, then this node's type is constrained
 * to be the tuple built from the elements' types, in element order.
 */
void ASTTuple::constrain(TEnv& tenv) const
{
	AType* t = new AType(ASTTuple(), loc);
	FOREACH(const_iterator, p, *this) {
		(*p)->constrain(tenv);
		t->push_back(tenv.type(*p));
	}
	tenv.constrain(this, t);
}

/** Constrain a closure.
 *
 * Element 1 is treated as the prototype and element 2 as the body; the
 * closure's own type is constrained to `(Fn PROTOTYPE-TYPE BODY-TYPE)`.
 * Element 0 is not constrained here.
 */
void ASTClosure::constrain(TEnv& tenv) const
{
	at(1)->constrain(tenv);
	at(2)->constrain(tenv);
	AType* protT = tenv.type(at(1));
	AType* bodyT = tenv.type(at(2));
	tenv.constrain(this, new AType(loc, tenv.penv.sym("Fn"), protT, bodyT, 0));
}

/** Constrain a call `(F ARG...)`.
 *
 * All elements (callee and arguments) are constrained, then the callee is
 * constrained to `(Fn (ARG-TYPE...) CALL-TYPE)`, where CALL-TYPE is this
 * call node's own type.
 */
void ASTCall::constrain(TEnv& tenv) const
{
	FOREACH(const_iterator, p, *this)
		(*p)->constrain(tenv);

	AType* retT  = tenv.type(this);
	AType* argsT = new AType(ASTTuple(), loc);
	for (size_t i = 1; i < size(); ++i)
		argsT->push_back(tenv.type(at(i)));

	tenv.constrain(at(0), new AType(at(0)->loc, tenv.penv.sym("Fn"), argsT, retT, 0));
}

/** Constrain a definition `(def NAME VALUE)`.
 *
 * The name, the value and the definition node itself are all constrained
 * to one shared type.
 *
 * @throw Error if the form does not have exactly two arguments, or if NAME
 *        is not a symbol.
 */
void ASTDefinition::constrain(TEnv& tenv) const
{
	if (size() != 3)
		throw Error("`def' requires exactly 2 arguments", exp.loc);
	// Only a symbol may be bound by `def'.
	if (!dynamic_cast<const ASTSymbol*>(at(1)))
		throw Error("`def' name is not a symbol", exp.loc);

	FOREACH(const_iterator, p, *this)
		(*p)->constrain(tenv);

	AType* tvar = tenv.type(this);
	tenv.constrain(at(1), tvar);
	tenv.constrain(at(2), tvar);
}
void ASTIf::constrain(TEnv& tenv) const { if (size() < 3) throw Error("`if' requires exactly 3 arguments", exp.loc); if (size() % 2 != 0) throw Error("`if' missing final else clause", exp.loc); FOREACH(const_iterator, p, *this) (*p)->constrain(tenv); AType* retT = tenv.type(this); for (size_t i = 1; i < size(); i += 2) { if (i == size() - 1) { tenv.constrain(at(i), retT); } else { tenv.constrain(at(i), tenv.named("Bool")); tenv.constrain(at(i+1), retT); } } } void ASTPrimitive::constrain(TEnv& tenv) const { const string n = dynamic_cast(at(0))->str(); enum { ARITHMETIC, BINARY, BITWISE, COMPARISON } type; if (n == "+" || n == "-" || n == "*" || n == "/") type = ARITHMETIC; else if (n == "%") type = BINARY; else if (n == "and" || n == "or" || n == "xor") type = BITWISE; else if (n == "=" || n == "!=" || n == ">" || n == ">=" || n == "<" || n == "<=") type = COMPARISON; else throw Error((format("unknown primitive `%1%'") % n).str(), exp.loc); FOREACH(const_iterator, p, *this) (*p)->constrain(tenv); switch (type) { case ARITHMETIC: if (size() < 3) throw Error((format("`%1%' requires at least 2 arguments") % n).str(), exp.loc); for (size_t i = 1; i < size(); ++i) tenv.constrain(at(i), tenv.type(this)); break; case BINARY: if (size() != 3) throw Error((format("`%1%' requires exactly 2 arguments") % n).str(), exp.loc); tenv.constrain(at(1), tenv.type(this)); tenv.constrain(at(2), tenv.type(this)); break; case BITWISE: if (size() != 3) throw Error((format("`%1%' requires exactly 2 arguments") % n).str(), exp.loc); tenv.constrain(this, tenv.named("Bool")); tenv.constrain(at(1), tenv.named("Bool")); tenv.constrain(at(2), tenv.named("Bool")); break; case COMPARISON: if (size() != 3) throw Error((format("`%1%' requires exactly 2 arguments") % n).str(), exp.loc); tenv.constrain(this, tenv.named("Bool")); tenv.constrain(at(1), tenv.type(at(2))); break; default: throw Error((format("unknown primitive `%1%'") % n).str(), exp.loc); } } void ASTConsCall::constrain(TEnv& tenv) const 
{ if (size() != 3) throw Error("`cons' requires exactly 2 arguments", loc); AType* t = new AType(loc, tenv.penv.sym("Pair"), 0); for (size_t i = 1; i < size(); ++i) { at(i)->constrain(tenv); t->push_back(tenv.type(at(i))); } tenv.constrain(this, t); } void ASTCarCall::constrain(TEnv& tenv) const { if (size() != 2) throw Error("`car' requires exactly 1 argument", loc); at(1)->constrain(tenv); AType* carT = tenv.var(loc); AType* pairT = new AType(at(1)->loc, tenv.penv.sym("Pair"), carT, tenv.var(), 0); tenv.constrain(at(1), pairT); tenv.constrain(this, carT); } void ASTCdrCall::constrain(TEnv& tenv) const { if (size() != 2) throw Error("`cdr' requires exactly 1 argument", loc); at(1)->constrain(tenv); AType* cdrT = tenv.var(loc); AType* pairT = new AType(at(1)->loc, tenv.penv.sym("Pair"), tenv.var(), cdrT, 0); tenv.constrain(at(1), pairT); tenv.constrain(this, cdrT); } /*************************************************************************** * Type Inferencing/Substitution * ***************************************************************************/ static void substitute(ASTTuple* tup, AST* from, AST* to) { if (!tup) return; for (size_t i = 0; i < tup->size(); ++i) if (*tup->at(i) == *from) tup->at(i) = to; else substitute(dynamic_cast(tup->at(i)), from, to); } TSubst compose(const TSubst& delta, const TSubst& gamma) // TAPL 22.1.1 { TSubst r; for (TSubst::const_iterator g = gamma.begin(); g != gamma.end(); ++g) { TSubst::const_iterator d = delta.find(g->second); r.insert(make_pair(g->first, ((d != delta.end()) ? 
d : g)->second)); } for (TSubst::const_iterator d = delta.begin(); d != delta.end(); ++d) { if (gamma.find(d->first) == gamma.end()) r.insert(*d); } return r; } void substConstraints(TEnv::Constraints& constraints, AType* s, AType* t) { for (TEnv::Constraints::iterator c = constraints.begin(); c != constraints.end();) { TEnv::Constraints::iterator next = c; ++next; if (*c->first == *s) c->first = t; if (*c->second == *s) c->second = t; substitute(c->first, s, t); substitute(c->second, s, t); c = next; } } TSubst TEnv::unify(const Constraints& constraints) // TAPL 22.4 { if (constraints.empty()) return TSubst(); AType* s = constraints.begin()->first; AType* t = constraints.begin()->second; Constraints cp = constraints; cp.erase(cp.begin()); if (*s == *t) { return unify(cp); } else if (s->var() && !t->contains(s)) { substConstraints(cp, s, t); return compose(unify(cp), TSubst(s, t)); } else if (t->var() && !s->contains(t)) { substConstraints(cp, t, s); return compose(unify(cp), TSubst(t, s)); } else if (s->kind == AType::EXPR && s->kind == t->kind && s->size() == t->size()) { for (size_t i = 0; i < s->size(); ++i) { AType* si = dynamic_cast(s->at(i)); AType* ti = dynamic_cast(t->at(i)); if (si && ti) cp.push_back(Constraint(si, ti, si->loc)); } return unify(cp); } else { throw Error((format("type is `%1%' but should be `%2%'") % s->str() % t->str()).str(), constraints.begin()->loc); } } void TEnv::apply(const TSubst& substs) { FOREACH(TSubst::const_iterator, s, substs) FOREACH(Frame::iterator, t, front()) if (*t->second == *s->first) t->second = s->second; else substitute(t->second, s->first, s->second); }