/* Tuplr Type Inferencing * Copyright (C) 2008-2009 David Robillard * * Tuplr is free software: you can redistribute it and/or modify it under * the terms of the GNU Affero General Public License as published by the * Free Software Foundation, either version 3 of the License, or (at your * option) any later version. * * Tuplr is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General * Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with Tuplr. If not, see . */ /** @file * @brief Constrain type of AST expressions */ #include #include "tuplr.hpp" #define CONSTRAIN_LITERAL(CT, NAME) \ template<> void \ ALiteral::constrain(TEnv& tenv, Constraints& c) const { \ c.constrain(tenv, this, tenv.named(NAME)); \ } // Literal template instantiations CONSTRAIN_LITERAL(int32_t, "Int") CONSTRAIN_LITERAL(float, "Float") CONSTRAIN_LITERAL(bool, "Bool") void AString::constrain(TEnv& tenv, Constraints& c) const { c.constrain(tenv, this, tenv.named("String")); } void ASymbol::constrain(TEnv& tenv, Constraints& c) const { AType** ref = tenv.ref(this); THROW_IF(!ref, loc, (format("undefined symbol `%1%'") % cppstr).str()); c.constrain(tenv, this, *ref); } void ATuple::constrain(TEnv& tenv, Constraints& c) const { AType* t = tup(loc, NULL); FOREACHP(const_iterator, p, this) { (*p)->constrain(tenv, c); t->push_back(tenv.var(*p)); } c.constrain(tenv, this, t); } void AFn::constrain(TEnv& tenv, Constraints& c) const { set defs; TEnv::Frame frame; // Add parameters to environment frame AType* protT = tup(loc, NULL); for (ATuple::const_iterator i = prot()->begin(); i != prot()->end(); ++i) { const ASymbol* sym = (*i)->to(); THROW_IF(!sym, (*i)->loc, "parameter name is not a symbol"); THROW_IF(defs.count(sym) != 0, sym->loc, (format("duplicate parameter `%1%'") % sym->str()).str()); defs.insert(sym); AType* tvar = tenv.fresh(sym); frame.push_back(make_pair(sym, tvar)); protT->push_back(tvar); } const_iterator i = begin() + 1; c.constrain(tenv, *i, protT); // Add internal definitions to environment frame for (++i; i != end(); ++i) { const AST* exp = *i; const ADef* def = exp->to(); if (def) { const ASymbol* sym = def->sym(); THROW_IF(defs.count(sym) != 0, def->loc, (format("`%1%' defined twice") % sym->str()).str()); defs.insert(def->sym()); frame.push_back(make_pair(def->sym(), (AType*)NULL)); } } tenv.push(frame); c.constrain(tenv, this, tenv.var()); AST* exp = NULL; for (i = begin() + 2; i != end(); ++i) (exp = *i)->constrain(tenv, c); AType* bodyT = tenv.var(exp); AType* fnT = tup(loc, tenv.Fn, protT, bodyT, 0); Object::pool.addRoot(fnT); tenv.pop(); c.constrain(tenv, this, fnT); } void ACall::constrain(TEnv& tenv, Constraints& c) const { for (const_iterator i = begin(); i != end(); ++i) (*i)->constrain(tenv, c); const AType* fnType = tenv.var(head()); if (fnType->kind != AType::VAR) { if (fnType->kind == AType::PRIM || fnType->size() < 2 || fnType->head()->str() != "Fn") throw Error(loc, (format("call to non-function `%1%'") % head()->str()).str()); size_t numArgs = fnType->prot()->size(); THROW_IF(numArgs != size() - 1, loc, (format("expected %1% arguments, got %2%") % numArgs % (size() - 1)).str()); } AType* retT = tenv.var(); AType* argsT = tup(loc, 0); for (const_iterator i = begin() + 1; i != end(); ++i) argsT->push_back(tenv.var(*i)); c.constrain(tenv, head(), tup(head()->loc, tenv.Fn, argsT, retT, 0)); c.constrain(tenv, this, retT); } void ADef::constrain(TEnv& tenv, Constraints& c) const { THROW_IF(size() != 3, loc, "`def' requires exactly 2 arguments"); const ASymbol* sym = this->sym(); THROW_IF(!sym, loc, "`def' has no symbol") AType* tvar = tenv.var(body()); tenv.def(sym, tvar); body()->constrain(tenv, c); c.constrain(tenv, sym, tvar); c.constrain(tenv, this, tenv.named("Nothing")); } void AIf::constrain(TEnv& tenv, Constraints& c) const { THROW_IF(size() < 4, loc, "`if' requires at least 3 arguments"); THROW_IF(size() % 2 != 0, loc, "`if' missing final else clause") for (const_iterator i = begin() + 1; i != end(); ++i) (*i)->constrain(tenv, c); AType* retT = tenv.var(this); for (const_iterator i = begin() + 1; true; ++i) { const_iterator next = i; ++next; if (next == end()) { // final (else) expression c.constrain(tenv, *i, retT); break; } else { c.constrain(tenv, *i, tenv.named("Bool")); c.constrain(tenv, *next, retT); } i = next; // jump 2 each iteration (to the next predicate) } } void ACons::constrain(TEnv& tenv, Constraints& c) const { AType* type = tup(loc, tenv.Tup, 0); for (const_iterator i = begin() + 1; i != end(); ++i) { (*i)->constrain(tenv, c); type->push_back(tenv.var(*i)); } c.constrain(tenv, this, type); } void ADot::constrain(TEnv& tenv, Constraints& c) const { THROW_IF(size() != 3, loc, "`.' requires exactly 2 arguments"); const_iterator i = begin(); AST* obj = *++i; ALiteral* idx = (*++i)->to*>(); THROW_IF(!idx, loc, "the 2nd argument to `.' must be a literal integer"); obj->constrain(tenv, c); AType* retT = tenv.var(this); c.constrain(tenv, this, retT); AType* objT = tup(loc, tenv.Tup, 0); for (int i = 0; i < idx->val; ++i) objT->push_back(tenv.var()); objT->push_back(retT); objT->push_back(new AType(obj->loc, AType::DOTS)); c.constrain(tenv, obj, objT); } void APrimitive::constrain(TEnv& tenv, Constraints& c) const { const string n = head()->to()->str(); enum { ARITHMETIC, BINARY, LOGICAL, COMPARISON } type; if (n == "+" || n == "-" || n == "*" || n == "/") type = ARITHMETIC; else if (n == "%") type = BINARY; else if (n == "and" || n == "or" || n == "xor") type = LOGICAL; else if (n == "=" || n == "!=" || n == ">" || n == ">=" || n == "<" || n == "<=") type = COMPARISON; else throw Error(loc, (format("unknown primitive `%1%'") % n).str()); const_iterator i = begin(); for (++i; i != end(); ++i) (*i)->constrain(tenv, c); i = begin(); AType* var = NULL; switch (type) { case ARITHMETIC: if (size() < 3) throw Error(loc, (format("`%1%' requires at least 2 arguments") % n).str()); for (++i; i != end(); ++i) c.constrain(tenv, *i, tenv.var(this)); break; case BINARY: if (size() != 3) throw Error(loc, (format("`%1%' requires exactly 2 arguments") % n).str()); c.constrain(tenv, *++i, tenv.var(this)); c.constrain(tenv, *++i, tenv.var(this)); break; case LOGICAL: if (size() != 3) throw Error(loc, (format("`%1%' requires exactly 2 arguments") % n).str()); c.constrain(tenv, this, tenv.named("Bool")); c.constrain(tenv, *++i, tenv.named("Bool")); c.constrain(tenv, *++i, tenv.named("Bool")); break; case COMPARISON: if (size() != 3) throw Error(loc, (format("`%1%' requires exactly 2 arguments") % n).str()); var = tenv.var(*++i); c.constrain(tenv, this, tenv.named("Bool")); c.constrain(tenv, *++i, var); break; default: throw Error(loc, (format("unknown primitive `%1%'") % n).str()); } }