/* Resp Type Inferencing * Copyright (C) 2008-2009 David Robillard * * Resp is free software: you can redistribute it and/or modify it under * the terms of the GNU Affero General Public License as published by the * Free Software Foundation, either version 3 of the License, or (at your * option) any later version. * * Resp is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General * Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with Resp. If not, see . */ /** @file * @brief Constrain type of AST expressions */ #include #include "resp.hpp" static void constrain_symbol(TEnv& tenv, Constraints& c, const ASymbol* sym) throw(Error) { const AType** ref = tenv.ref(sym); THROW_IF(!ref, sym->loc, (format("undefined symbol `%1%'") % sym->cppstr).str()); c.constrain(tenv, sym, *ref); } static void constrain_fn(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error) { set defs; TEnv::Frame frame; const ATuple* const prot = call->prot(); // Add parameters to environment frame TList protT; for (ATuple::const_iterator i = prot->begin(); i != prot->end(); ++i) { const ASymbol* sym = (*i)->to_symbol(); THROW_IF(!sym, (*i)->loc, "parameter name is not a symbol"); THROW_IF(defs.count(sym) != 0, sym->loc, (format("duplicate parameter `%1%'") % sym->str()).str()); defs.insert(sym); const AType* tvar = tenv.fresh(sym); frame.push_back(make_pair(sym, tvar)); protT.push_back(tvar); } protT.head->loc = call->loc; ATuple::const_iterator i = call->iter_at(1); c.constrain(tenv, *i, protT); // Add internal definitions to environment frame for (++i; i != call->end(); ++i) { const AST* exp = *i; const ATuple* call = exp->to_tuple(); if (call && is_form(call, "def")) { const ASymbol* sym = call->list_ref(1)->as_symbol(); THROW_IF(defs.count(sym) != 0, call->loc, (format("`%1%' defined twice") % sym->str()).str()); defs.insert(sym); frame.push_back(make_pair(sym, (AType*)NULL)); } } tenv.push(frame); const AST* exp = NULL; for (i = call->iter_at(2); i != call->end(); ++i) { exp = *i; resp_constrain(tenv, c, exp); } const AType* bodyT = tenv.var(exp); const AType* fnT = tup(call->loc, tenv.Fn, protT.head, bodyT, 0); Object::pool.addRoot(fnT); tenv.pop(); c.constrain(tenv, call, fnT); } static void constrain_def(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error) { THROW_IF(call->list_len() != 3, call->loc, "`def' requires exactly 2 arguments"); const ASymbol* const sym = call->list_ref(1)->as_symbol(); THROW_IF(!sym, call->loc, "`def' has no symbol") const AST* const body = call->list_ref(2); const AType* tvar = tenv.var(body); tenv.def(sym, tvar); resp_constrain(tenv, c, body); c.constrain(tenv, sym, tvar); c.constrain(tenv, call, tenv.named("Nothing")); } static void constrain_def_type(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error) { THROW_IF(call->list_len() < 3, call->loc, "`def-type' requires at least 2 arguments"); ATuple::const_iterator i = call->iter_at(1); const ATuple* prot = (*i)->to_tuple(); THROW_IF(!prot, (*i)->loc, "first argument of `def-type' is not a tuple"); const ASymbol* sym = (*prot->begin())->as_symbol(); THROW_IF(!sym, (*prot->begin())->loc, "type name is not a symbol"); THROW_IF(tenv.ref(sym), call->loc, "type redefinition"); TList type(new AType(tenv.U, NULL, call->loc)); for (ATuple::const_iterator i = call->iter_at(2); i != call->end(); ++i) { const ATuple* exp = (*i)->as_tuple(); const ASymbol* tag = (*exp->begin())->as_symbol(); TList consT; consT.push_back(new AType(sym, AType::NAME)); for (ATuple::const_iterator i = exp->begin(); i != exp->end(); ++i) { const ASymbol* sym = (*i)->to_symbol(); THROW_IF(!sym, (*i)->loc, "type expression element is not a symbol"); consT.push_back(new AType(sym, AType::NAME)); } consT.head->loc = exp->loc; type.push_back(consT); tenv.def(tag, consT); } tenv.def(sym, type); } static void constrain_match(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error) { THROW_IF(call->list_len() < 5, call->loc, "`match' requires at least 4 arguments"); const AST* matchee = call->list_ref(1); const AType* retT = tenv.var(); const AType* matcheeT = NULL;// = tup(loc, tenv.U, 0); resp_constrain(tenv, c, matchee); for (ATuple::const_iterator i = call->iter_at(2); i != call->end();) { const AST* exp = *i++; const ATuple* pattern = exp->to_tuple(); THROW_IF(!pattern, exp->loc, "pattern expression expected"); const ASymbol* name = (*pattern->begin())->to_symbol(); THROW_IF(!name, (*pattern->begin())->loc, "pattern does not start with a symbol"); const AType* consT = *tenv.ref(name); if (!matcheeT) { const AType* headT = consT->head()->as_type(); matcheeT = new AType(headT, 0, call->loc); } THROW_IF(i == call->end(), pattern->loc, "missing pattern body"); const AST* body = *i++; resp_constrain(tenv, c, body); c.constrain(tenv, body, retT); } c.constrain(tenv, call, retT); c.constrain(tenv, matchee, matcheeT); } static void constrain_if(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error) { THROW_IF(call->list_len() < 4, call->loc, "`if' requires at least 3 arguments"); THROW_IF(call->list_len() % 2 != 0, call->loc, "`if' missing final else clause"); for (ATuple::const_iterator i = call->iter_at(1); i != call->end(); ++i) resp_constrain(tenv, c, *i); const AType* retT = tenv.var(call); for (ATuple::const_iterator i = call->iter_at(1); true; ++i) { ATuple::const_iterator next = i; ++next; if (next == call->end()) { // final (else) expression c.constrain(tenv, *i, retT); break; } else { c.constrain(tenv, *i, tenv.named("Bool")); c.constrain(tenv, *next, retT); } i = next; // jump 2 each iteration (to the next predicate) } } static void constrain_cons(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error) { const ASymbol* sym = (*call->begin())->as_symbol(); const AType* type = NULL; for (ATuple::const_iterator i = call->iter_at(1); i != call->end(); ++i) resp_constrain(tenv, c, *i); if (sym->cppstr == "Tup") { TList tupT(new AType(tenv.Tup, NULL, call->loc)); for (ATuple::const_iterator i = call->iter_at(1); i != call->end(); ++i) { tupT.push_back(tenv.var(*i)); } type = tupT; } else { const AType** consTRef = tenv.ref(sym); THROW_IF(!consTRef, call->loc, (format("call to undefined constructor `%1%'") % sym->cppstr).str()); const AType* consT = *consTRef; type = new AType(consT->head()->as_type(), 0, call->loc); } c.constrain(tenv, call, type); } static void constrain_dot(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error) { THROW_IF(call->list_len() != 3, call->loc, "`.' requires exactly 2 arguments"); ATuple::const_iterator i = call->begin(); const AST* obj = *++i; const AST* idx_ast = *++i; THROW_IF(idx_ast->tag() != T_INT32, call->loc, "the 2nd argument to `.' must be a literal integer"); const ALiteral* idx = (ALiteral*)idx_ast; resp_constrain(tenv, c, obj); const AType* retT = tenv.var(call); c.constrain(tenv, call, retT); TList objT(new AType(tenv.Tup, NULL, call->loc)); for (int i = 0; i < idx->val; ++i) objT.push_back(tenv.var()); objT.push_back(retT); objT.push_back(new AType(obj->loc, AType::DOTS)); c.constrain(tenv, obj, objT); } static void constrain_call(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error) { const AST* const head = call->head(); for (ATuple::const_iterator i = call->begin(); i != call->end(); ++i) resp_constrain(tenv, c, *i); const AType* fnType = tenv.var(head); if (fnType->kind != AType::VAR) { if (fnType->kind == AType::PRIM || fnType->list_len() < 2 || fnType->head()->str() != "Fn") throw Error(call->loc, (format("call to non-function `%1%'") % head->str()).str()); size_t numArgs = fnType->prot()->list_len(); THROW_IF(numArgs != call->list_len() - 1, call->loc, (format("expected %1% arguments, got %2%") % numArgs % (call->list_len() - 1)).str()); } const AType* retT = tenv.var(call); TList argsT; for (ATuple::const_iterator i = call->iter_at(1); i != call->end(); ++i) argsT.push_back(tenv.var(*i)); argsT.head->loc = call->loc; c.constrain(tenv, head, tup(head->loc, tenv.Fn, argsT.head, retT, 0)); c.constrain(tenv, call, retT); } static void constrain_primitive(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error) { const string n = call->head()->to_symbol()->str(); enum { ARITHMETIC, BINARY, LOGICAL, COMPARISON } type; if (n == "+" || n == "-" || n == "*" || n == "/") type = ARITHMETIC; else if (n == "%") type = BINARY; else if (n == "and" || n == "or" || n == "xor") type = LOGICAL; else if (n == "=" || n == "!=" || n == ">" || n == ">=" || n == "<" || n == "<=") type = COMPARISON; else throw Error(call->loc, (format("unknown primitive `%1%'") % n).str()); ATuple::const_iterator i = call->begin(); for (++i; i != call->end(); ++i) resp_constrain(tenv, c, *i); i = call->begin(); const AType* var = NULL; switch (type) { case ARITHMETIC: if (call->list_len() < 3) throw Error(call->loc, (format("`%1%' requires at least 2 arguments") % n).str()); for (++i; i != call->end(); ++i) c.constrain(tenv, *i, tenv.var(call)); break; case BINARY: if (call->list_len() != 3) throw Error(call->loc, (format("`%1%' requires exactly 2 arguments") % n).str()); c.constrain(tenv, *++i, tenv.var(call)); c.constrain(tenv, *++i, tenv.var(call)); break; case LOGICAL: if (call->list_len() != 3) throw Error(call->loc, (format("`%1%' requires exactly 2 arguments") % n).str()); c.constrain(tenv, call, tenv.named("Bool")); c.constrain(tenv, *++i, tenv.named("Bool")); c.constrain(tenv, *++i, tenv.named("Bool")); break; case COMPARISON: if (call->list_len() != 3) throw Error(call->loc, (format("`%1%' requires exactly 2 arguments") % n).str()); var = tenv.var(*++i); c.constrain(tenv, call, tenv.named("Bool")); c.constrain(tenv, *++i, var); break; default: throw Error(call->loc, (format("unknown primitive `%1%'") % n).str()); } } static void constrain_tuple(TEnv& tenv, Constraints& c, const ATuple* tup) throw(Error) { const ASymbol* const sym = tup->head()->to_symbol(); if (!sym) { constrain_call(tenv, c, tup); return; } const std::string form = sym->cppstr; if (is_primitive(tenv.penv, tup)) constrain_primitive(tenv, c, tup); else if (form == "fn") constrain_fn(tenv, c, tup); else if (form == "def") constrain_def(tenv, c, tup); else if (form == "def-type") constrain_def_type(tenv, c, tup); else if (form == "match") constrain_match(tenv, c, tup); else if (form == "if") constrain_if(tenv, c, tup); else if (form == "cons" || isupper(form[0])) constrain_cons(tenv, c, tup); else if (form == ".") constrain_dot(tenv, c, tup); else constrain_call(tenv, c, tup); } void resp_constrain(TEnv& tenv, Constraints& c, const AST* ast) throw(Error) { switch (ast->tag()) { case T_UNKNOWN: case T_TYPE: break; case T_BOOL: c.constrain(tenv, ast, tenv.named("Bool")); break; case T_FLOAT: c.constrain(tenv, ast, tenv.named("Float")); break; case T_INT32: c.constrain(tenv, ast, tenv.named("Int")); break; case T_STRING: c.constrain(tenv, ast, tenv.named("String")); break; case T_SYMBOL: constrain_symbol(tenv, c, ast->as_symbol()); break; case T_TUPLE: constrain_tuple(tenv, c, ast->as_tuple()); break; } }