/* Resp Type Inferencing * Copyright (C) 2008-2009 David Robillard * * Resp is free software: you can redistribute it and/or modify it under * the terms of the GNU Affero General Public License as published by the * Free Software Foundation, either version 3 of the License, or (at your * option) any later version. * * Resp is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General * Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with Resp. If not, see . */ /** @file * @brief Constrain type of AST expressions */ #include #include "resp.hpp" #define CONSTRAIN_LITERAL(CT, NAME) \ template<> void \ ALiteral::constrain(TEnv& tenv, Constraints& c) const throw(Error) { \ c.constrain(tenv, this, tenv.named(NAME)); \ } // Literal template instantiations CONSTRAIN_LITERAL(int32_t, "Int") CONSTRAIN_LITERAL(float, "Float") CONSTRAIN_LITERAL(bool, "Bool") void AString::constrain(TEnv& tenv, Constraints& c) const throw(Error) { c.constrain(tenv, this, tenv.named("String")); } void ALexeme::constrain(TEnv& tenv, Constraints& c) const throw(Error) { c.constrain(tenv, this, tenv.named("Lexeme")); } void AQuote::constrain(TEnv& tenv, Constraints& c) const throw(Error) { c.constrain(tenv, this, tenv.named("Quote")); (*(begin() + 1))->constrain(tenv, c); } void ASymbol::constrain(TEnv& tenv, Constraints& c) const throw(Error) { const AType** ref = tenv.ref(this); THROW_IF(!ref, loc, (format("undefined symbol `%1%'") % cppstr).str()); c.constrain(tenv, this, *ref); } void ATuple::constrain(TEnv& tenv, Constraints& c) const throw(Error) { AType* t = tup(loc, NULL); FOREACHP(const_iterator, p, this) { (*p)->constrain(tenv, c); t->push_back(const_cast(tenv.var(*p))); } c.constrain(tenv, this, t); } void AFn::constrain(TEnv& tenv, Constraints& c) const throw(Error) { set defs; TEnv::Frame frame; // Add parameters to environment frame AType* protT = tup(loc, NULL); for (ATuple::const_iterator i = prot()->begin(); i != prot()->end(); ++i) { const ASymbol* sym = (*i)->to(); THROW_IF(!sym, (*i)->loc, "parameter name is not a symbol"); THROW_IF(defs.count(sym) != 0, sym->loc, (format("duplicate parameter `%1%'") % sym->str()).str()); defs.insert(sym); const AType* tvar = tenv.fresh(sym); frame.push_back(make_pair(sym, tvar)); protT->push_back(const_cast(tvar)); } const_iterator i = begin() + 1; c.constrain(tenv, *i, protT); // Add internal definitions to environment frame for (++i; i != end(); ++i) { const AST* exp = *i; const ADef* def = exp->to(); if (def) { const ASymbol* sym = def->sym(); THROW_IF(defs.count(sym) != 0, def->loc, (format("`%1%' defined twice") % sym->str()).str()); defs.insert(def->sym()); frame.push_back(make_pair(def->sym(), (AType*)NULL)); } } tenv.push(frame); AST* exp = NULL; for (i = begin() + 2; i != end(); ++i) (exp = *i)->constrain(tenv, c); const AType* bodyT = tenv.var(exp); const AType* fnT = tup(loc, tenv.Fn, protT, bodyT, 0); Object::pool.addRoot(fnT); tenv.pop(); c.constrain(tenv, this, fnT); } void ACall::constrain(TEnv& tenv, Constraints& c) const throw(Error) { for (const_iterator i = begin(); i != end(); ++i) (*i)->constrain(tenv, c); const AType* fnType = tenv.var(head()); if (fnType->kind != AType::VAR) { if (fnType->kind == AType::PRIM || fnType->size() < 2 || fnType->head()->str() != "Fn") throw Error(loc, (format("call to non-function `%1%'") % head()->str()).str()); size_t numArgs = fnType->prot()->size(); THROW_IF(numArgs != size() - 1, loc, (format("expected %1% arguments, got %2%") % numArgs % (size() - 1)).str()); } const AType* retT = tenv.var(this); AType* argsT = tup(loc, 0); for (const_iterator i = begin() + 1; i != end(); ++i) argsT->push_back(const_cast(tenv.var(*i))); c.constrain(tenv, head(), tup(head()->loc, tenv.Fn, argsT, retT, 0)); c.constrain(tenv, this, retT); } void ADef::constrain(TEnv& tenv, Constraints& c) const throw(Error) { THROW_IF(size() != 3, loc, "`def' requires exactly 2 arguments"); const ASymbol* sym = this->sym(); THROW_IF(!sym, loc, "`def' has no symbol") const AType* tvar = tenv.var(body()); tenv.def(sym, tvar); body()->constrain(tenv, c); c.constrain(tenv, sym, tvar); c.constrain(tenv, this, tenv.named("Nothing")); } void AIf::constrain(TEnv& tenv, Constraints& c) const throw(Error) { THROW_IF(size() < 4, loc, "`if' requires at least 3 arguments"); THROW_IF(size() % 2 != 0, loc, "`if' missing final else clause") for (const_iterator i = begin() + 1; i != end(); ++i) (*i)->constrain(tenv, c); const AType* retT = tenv.var(this); for (const_iterator i = begin() + 1; true; ++i) { const_iterator next = i; ++next; if (next == end()) { // final (else) expression c.constrain(tenv, *i, retT); break; } else { c.constrain(tenv, *i, tenv.named("Bool")); c.constrain(tenv, *next, retT); } i = next; // jump 2 each iteration (to the next predicate) } } void AMatch::constrain(TEnv& tenv, Constraints& c) const throw(Error) { THROW_IF(size() < 5, loc, "`match' requires at least 4 arguments"); const AST* matchee = (*(begin() + 1)); const AType* retT = tenv.var(); const AType* matcheeT = NULL;// = tup(loc, tenv.U, 0); matchee->constrain(tenv, c); for (const_iterator i = begin() + 2; i != end();) { const AST* exp = *i++; const ATuple* pattern = exp->to(); THROW_IF(!pattern, exp->loc, "pattern expression expected"); const ASymbol* name = (*pattern->begin())->to(); THROW_IF(!name, (*pattern->begin())->loc, "pattern does not start with a symbol"); const AType* consT = *tenv.ref(name); if (!matcheeT) { const AType* headT = consT->head()->as(); matcheeT = tup(loc, const_cast(headT), 0); } THROW_IF(i == end(), pattern->loc, "missing pattern body"); const AST* body = *i++; body->constrain(tenv, c); c.constrain(tenv, body, retT); } c.constrain(tenv, this, retT); c.constrain(tenv, matchee, matcheeT); } void ADefType::constrain(TEnv& tenv, Constraints& c) const throw(Error) { THROW_IF(size() < 3, loc, "`def-type' requires at least 2 arguments"); const_iterator i = begin() + 1; const ATuple* prot = (*i)->to(); THROW_IF(!prot, (*i)->loc, "first argument of `def-type' is not a tuple"); const ASymbol* sym = (*prot->begin())->as(); THROW_IF(!sym, (*prot->begin())->loc, "type name is not a symbol"); THROW_IF(tenv.ref(sym), loc, "type redefinition"); AType* type = tup(loc, tenv.U, 0); for (const_iterator i = begin() + 2; i != end(); ++i) { const ATuple* exp = (*i)->as(); const ASymbol* tag = (*exp->begin())->as(); AType* consT = new AType(exp->loc, AType::EXPR); consT->push_back(new AType(const_cast(sym), AType::NAME)); for (ATuple::const_iterator i = exp->begin(); i != exp->end(); ++i) { const ASymbol* sym = (*i)->to(); THROW_IF(!sym, (*i)->loc, "type expression element is not a symbol"); consT->push_back(new AType(const_cast(sym), AType::NAME)); } type->push_back(consT); tenv.def(tag, consT); } tenv.def(sym, type); } void ACons::constrain(TEnv& tenv, Constraints& c) const throw(Error) { const ASymbol* sym = (*begin())->as(); const AType* type = NULL; for (const_iterator i = begin() + 1; i != end(); ++i) (*i)->constrain(tenv, c); if (sym->cppstr == "Tup") { AType* tupT = tup(loc, tenv.Tup, 0); for (const_iterator i = begin() + 1; i != end(); ++i) { tupT->push_back(const_cast(tenv.var(*i))); } type = tupT; } else { const AType** consTRef = tenv.ref(sym); THROW_IF(!consTRef, loc, (format("call to undefined constructor `%1%'") % sym->cppstr).str()); const AType* consT = *consTRef; type = tup(loc, const_cast(consT->head()->as()), 0); } c.constrain(tenv, this, type); } void ADot::constrain(TEnv& tenv, Constraints& c) const throw(Error) { THROW_IF(size() != 3, loc, "`.' requires exactly 2 arguments"); const_iterator i = begin(); AST* obj = *++i; ALiteral* idx = (*++i)->to*>(); THROW_IF(!idx, loc, "the 2nd argument to `.' must be a literal integer"); obj->constrain(tenv, c); const AType* retT = tenv.var(this); c.constrain(tenv, this, retT); AType* objT = tup(loc, tenv.Tup, 0); for (int i = 0; i < idx->val; ++i) objT->push_back(const_cast(tenv.var())); objT->push_back(const_cast(retT)); objT->push_back(new AType(obj->loc, AType::DOTS)); c.constrain(tenv, obj, objT); } void APrimitive::constrain(TEnv& tenv, Constraints& c) const throw(Error) { const string n = head()->to()->str(); enum { ARITHMETIC, BINARY, LOGICAL, COMPARISON } type; if (n == "+" || n == "-" || n == "*" || n == "/") type = ARITHMETIC; else if (n == "%") type = BINARY; else if (n == "and" || n == "or" || n == "xor") type = LOGICAL; else if (n == "=" || n == "!=" || n == ">" || n == ">=" || n == "<" || n == "<=") type = COMPARISON; else throw Error(loc, (format("unknown primitive `%1%'") % n).str()); const_iterator i = begin(); for (++i; i != end(); ++i) (*i)->constrain(tenv, c); i = begin(); const AType* var = NULL; switch (type) { case ARITHMETIC: if (size() < 3) throw Error(loc, (format("`%1%' requires at least 2 arguments") % n).str()); for (++i; i != end(); ++i) c.constrain(tenv, *i, tenv.var(this)); break; case BINARY: if (size() != 3) throw Error(loc, (format("`%1%' requires exactly 2 arguments") % n).str()); c.constrain(tenv, *++i, tenv.var(this)); c.constrain(tenv, *++i, tenv.var(this)); break; case LOGICAL: if (size() != 3) throw Error(loc, (format("`%1%' requires exactly 2 arguments") % n).str()); c.constrain(tenv, this, tenv.named("Bool")); c.constrain(tenv, *++i, tenv.named("Bool")); c.constrain(tenv, *++i, tenv.named("Bool")); break; case COMPARISON: if (size() != 3) throw Error(loc, (format("`%1%' requires exactly 2 arguments") % n).str()); var = tenv.var(*++i); c.constrain(tenv, this, tenv.named("Bool")); c.constrain(tenv, *++i, var); break; default: throw Error(loc, (format("unknown primitive `%1%'") % n).str()); } }