diff options
Diffstat (limited to 'src/typing.cpp')
-rw-r--r-- | src/typing.cpp | 232 |
1 files changed, 232 insertions, 0 deletions
diff --git a/src/typing.cpp b/src/typing.cpp new file mode 100644 index 0000000..5791fdc --- /dev/null +++ b/src/typing.cpp @@ -0,0 +1,232 @@ +/* Tuplr Type Inferencing + * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net> + * + * Tuplr is free software: you can redistribute it and/or modify it under + * the terms of the GNU Affero General Public License as published by the + * Free Software Foundation, either version 3 of the License, or (at your + * option) any later version. + * + * Tuplr is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY + * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General + * Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with Tuplr. If not, see <http://www.gnu.org/licenses/>. + */ + +#include <set> +#include "tuplr.hpp" + +void +Constraints::constrain(TEnv& tenv, const AST* o, AType* t) +{ + assert(!o->to<const AType*>()); + push_back(Constraint(tenv.var(o), t, o->loc)); +} + + +/*************************************************************************** + * AST Type Constraints * + ***************************************************************************/ + +void +ASymbol::constrain(TEnv& tenv, Constraints& c) const +{ + addr = tenv.lookup(this); + if (!addr) + throw Error(loc, (format("undefined symbol `%1%'") % cppstr).str()); + c.push_back(Constraint(tenv.var(this), tenv.deref(addr).second, loc)); +} + +void +ATuple::constrain(TEnv& tenv, Constraints& c) const +{ + AType* t = tup<AType>(loc, NULL); + FOREACH(const_iterator, p, *this) { + (*p)->constrain(tenv, c); + t->push_back(tenv.var(*p)); + } + c.push_back(Constraint(tenv.var(this), t, loc)); +} + +void +AFn::constrain(TEnv& tenv, Constraints& c) const +{ + const AType* genericType; + TEnv::GenericTypes::const_iterator gt = tenv.genericTypes.find(this); + if (gt != tenv.genericTypes.end()) { + genericType = gt->second; + } else { + set<ASymbol*> defined; + TEnv::Frame frame; + + // Add parameters to environment frame + for (size_t i = 0; i < prot()->size(); ++i) { + ASymbol* sym = prot()->at(i)->to<ASymbol*>(); + if (!sym) + throw Error(prot()->at(i)->loc, "parameter name is not a symbol"); + if (defined.find(sym) != defined.end()) + throw Error(sym->loc, (format("duplicate parameter `%1%'") % sym->str()).str()); + defined.insert(sym); + frame.push_back(make_pair(sym, make_pair((AST*)NULL, (AType*)NULL))); + } + + // Add internal definitions to environment frame + size_t e = 2; + for (; e < size(); ++e) { + AST* exp = at(e); + ADef* def = exp->to<ADef*>(); + if (def) { + ASymbol* sym = def->sym(); + if (defined.find(sym) != defined.end()) + throw Error(def->loc, (format("`%1%' defined twice") % sym->str()).str()); + defined.insert(def->sym()); + frame.push_back(make_pair(def->sym(), make_pair(def->at(2), (AType*)NULL))); + } + } + + tenv.push(frame); + + Constraints cp; + cp.push_back(Constraint(tenv.var(this), tenv.var(), loc)); + + AType* protT = tup<AType>(loc, NULL); + for (size_t i = 0; i < prot()->size(); ++i) { + AType* tvar = tenv.fresh(prot()->at(i)->to<ASymbol*>()); + protT->push_back(tvar); + assert(frame[i].first == prot()->at(i)); + frame[i].second.first = prot()->at(i); + frame[i].second.second = tvar; + } + c.push_back(Constraint(tenv.var(at(1)), protT, at(1)->loc)); + + for (size_t i = 2; i < size(); ++i) + at(i)->constrain(tenv, cp); + + AType* bodyT = tenv.var(at(e-1)); + Subst tsubst = TEnv::unify(cp); + genericType = tup<AType>(loc, tenv.penv.sym("Fn"), + tsubst.apply(protT), tsubst.apply(bodyT), 0); + tenv.genericTypes.insert(make_pair(this, genericType)); + Object::pool.addRoot(genericType); + + tenv.pop(); + subst = tsubst; + } + + AType* t = new AType(*genericType); // FIXME: deep copy + c.constrain(tenv, this, t); +} + +void +ACall::constrain(TEnv& tenv, Constraints& c) const +{ + at(0)->constrain(tenv, c); + for (size_t i = 1; i < size(); ++i) + at(i)->constrain(tenv, c); + + AST* callee = tenv.resolve(at(0)); + AFn* closure = callee->to<AFn*>(); + if (closure) { + if (size() - 1 != closure->prot()->size()) + throw Error(loc, "incorrect number of arguments"); + TEnv::GenericTypes::iterator gt = tenv.genericTypes.find(closure); + if (gt != tenv.genericTypes.end()) { + for (size_t i = 1; i < size(); ++i) + c.constrain(tenv, at(i), gt->second->at(1)->as<ATuple*>()->at(i-1)->as<AType*>()); + AType* retT = tenv.var(this); + c.constrain(tenv, at(0), tup<AType>(at(0)->loc, tenv.penv.sym("Fn"), tenv.var(), retT, 0)); + c.constrain(tenv, this, retT); + return; + } + } + AType* argsT = tup<AType>(loc, 0); + for (size_t i = 1; i < size(); ++i) + argsT->push_back(tenv.var(at(i))); + AType* retT = tenv.var(); + c.constrain(tenv, at(0), tup<AType>(at(0)->loc, tenv.penv.sym("Fn"), argsT, retT, 0)); + c.constrain(tenv, this, retT); +} + +void +ADef::constrain(TEnv& tenv, Constraints& c) const +{ + THROW_IF(size() != 3, loc, "`def' requires exactly 2 arguments"); + const ASymbol* sym = this->sym(); + THROW_IF(!sym, loc, "`def' has no symbol") + + AType* tvar = tenv.var(at(2)); + tenv.def(sym, make_pair(at(2), tvar)); + at(2)->constrain(tenv, c); + c.constrain(tenv, this, tvar); +} + +void +AIf::constrain(TEnv& tenv, Constraints& c) const +{ + THROW_IF(size() < 4, loc, "`if' requires at least 3 arguments"); + THROW_IF(size() % 2 != 0, loc, "`if' missing final else clause") + for (size_t i = 1; i < size(); ++i) + at(i)->constrain(tenv, c); + AType* retT = tenv.var(this); + for (size_t i = 1; i < size(); i += 2) { + if (i == size() - 1) { + c.constrain(tenv, at(i), retT); + } else { + c.constrain(tenv, at(i), tenv.named("Bool")); + c.constrain(tenv, at(i+1), retT); + } + } +} + +void +APrimitive::constrain(TEnv& tenv, Constraints& c) const +{ + const string n = at(0)->to<ASymbol*>()->str(); + enum { ARITHMETIC, BINARY, LOGICAL, COMPARISON } type; + if (n == "+" || n == "-" || n == "*" || n == "/") + type = ARITHMETIC; + else if (n == "%") + type = BINARY; + else if (n == "and" || n == "or" || n == "xor") + type = LOGICAL; + else if (n == "=" || n == "!=" || n == ">" || n == ">=" || n == "<" || n == "<=") + type = COMPARISON; + else + throw Error(loc, (format("unknown primitive `%1%'") % n).str()); + + for (size_t i = 1; i < size(); ++i) + at(i)->constrain(tenv, c); + + switch (type) { + case ARITHMETIC: + if (size() < 3) + throw Error(loc, (format("`%1%' requires at least 2 arguments") % n).str()); + for (size_t i = 1; i < size(); ++i) + c.constrain(tenv, at(i), tenv.var(this)); + break; + case BINARY: + if (size() != 3) + throw Error(loc, (format("`%1%' requires exactly 2 arguments") % n).str()); + c.constrain(tenv, at(1), tenv.var(this)); + c.constrain(tenv, at(2), tenv.var(this)); + break; + case LOGICAL: + if (size() != 3) + throw Error(loc, (format("`%1%' requires exactly 2 arguments") % n).str()); + c.constrain(tenv, this, tenv.named("Bool")); + c.constrain(tenv, at(1), tenv.named("Bool")); + c.constrain(tenv, at(2), tenv.named("Bool")); + break; + case COMPARISON: + if (size() != 3) + throw Error(loc, (format("`%1%' requires exactly 2 arguments") % n).str()); + c.constrain(tenv, this, tenv.named("Bool")); + c.constrain(tenv, at(1), tenv.var(at(2))); + break; + default: + throw Error(loc, (format("unknown primitive `%1%'") % n).str()); + } +} + |