/* Resp: A programming language * Copyright (C) 2008-2009 David Robillard * * Resp is free software: you can redistribute it and/or modify it under * the terms of the GNU Affero General Public License as published by the * Free Software Foundation, either version 3 of the License, or (at your * option) any later version. * * Resp is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General * Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with Resp. If not, see . */ /** @file * @brief Constrain type of AST expressions */ #include #include #include "resp.hpp" static void constrain_symbol(TEnv& tenv, Constraints& c, const ASymbol* sym) throw(Error) { const AST** ref = tenv.ref(sym); THROW_IF(!ref, sym->loc, (format("undefined symbol `%1%'") % sym->sym()).str()); c.constrain(tenv, sym, *ref); } static void constrain_cons(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error) { const ASymbol* name = (*call->begin())->as_symbol(); // Constrain each argument for (ATuple::const_iterator i = call->iter_at(1); i != call->end(); ++i) resp_constrain(tenv, c, *i); // ::= ?Targi if (!strcmp(name->sym(), "Tup")) { // Build a type expression like (Tup ?Targ1 ...) List tupT(new ATuple(name, NULL, call->loc)); for (ATuple::const_iterator i = call->iter_at(1); i != call->end(); ++i) { tupT.push_back(tenv.var(*i)); } c.constrain(tenv, call, tupT); } else { // Look up constructor and use its type TEnv::Tags::const_iterator tag = tenv.tags.find(name->str()); THROW_IF(tag == tenv.tags.end(), name->loc, (format("undefined constructor `%1%'") % name->sym()).str()); // Build a substitution for every tvar in the constructor pattern Subst subst; const ATuple* expr = tag->second.expr->as_tuple(); ATuple::const_iterator e = expr->iter_at(1); for (ATuple::const_iterator i = call->iter_at(1); i != call->end(); ++i, ++e) { const ASymbol* sym = (*e)->to_symbol(); if (sym && !isupper(sym->str()[0])) { // Argument corresponds to type variable in constructor pattern subst.add(*e, tenv.var(*i)); } } // Substitute tvar symbols with the tvar for the corresponding argument const AST* pattern = subst.apply(tag->second.type); // Replace remaining tvar symbols with a free tvar for (ATuple::const_iterator i = pattern->as_tuple()->iter_at(1); i != pattern->as_tuple()->end(); ++i) { const ASymbol* sym = (*i)->to_symbol(); if (sym && islower(sym->str()[0])) { subst.add(sym, tenv.var()); } } // Constrain every argument to the corresponding pattern element e = expr->iter_at(1); for (ATuple::const_iterator i = call->iter_at(1); i != call->end(); ++i, ++e) { c.constrain(tenv, *i, subst.apply(*e)); } c.constrain(tenv, call, subst.apply(pattern)); } } static void constrain_dot(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error) { THROW_IF(call->list_len() != 3, call->loc, "`.' requires exactly 2 arguments"); ATuple::const_iterator i = call->begin(); const AST* obj = *++i; const AST* idx = *++i; THROW_IF(idx->tag() != T_INT32, call->loc, "the 2nd argument to `.' must be a literal integer"); resp_constrain(tenv, c, obj); const AST* retT = tenv.var(call); c.constrain(tenv, call, retT); List objT(new ATuple(tenv.Tup, NULL, call->loc)); for (int i = 0; i < ((ALiteral*)idx)->val; ++i) objT.push_back(tenv.var()); objT.push_back(retT); objT.push_back(tenv.Dots); c.constrain(tenv, obj, objT); } static void constrain_def(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error) { THROW_IF(call->list_len() != 3, call->loc, "`def' requires exactly 2 arguments"); THROW_IF(!call->frst()->to_symbol(), call->frst()->loc, "`def' name is not a symbol"); const ASymbol* const sym = call->list_ref(1)->as_symbol(); THROW_IF(!sym, call->loc, "`def' has no symbol") const AST* const body = call->list_ref(2); const AST* tvar = tenv.var(body); tenv.def(sym, tvar); resp_constrain(tenv, c, body); c.constrain(tenv, sym, tvar); c.constrain(tenv, call, tenv.named("Nothing")); } static void constrain_def_type(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error) { THROW_IF(call->list_len() < 3, call->loc, "`def-type' requires at least 2 arguments"); ATuple::const_iterator i = call->iter_at(1); const ATuple* prot = (*i)->to_tuple(); THROW_IF(!prot, (*i)->loc, "first argument of `def-type' is not a tuple"); const ASymbol* sym = (*prot->begin())->as_symbol(); THROW_IF(!sym, (*prot->begin())->loc, "type name is not a symbol"); //THROW_IF(tenv.ref(sym), call->loc, "type redefinition"); List type(call->loc, tenv.penv.sym("Lambda"), prot->rst(), NULL); for (ATuple::const_iterator i = call->iter_at(2); i != call->end(); ++i) { const ATuple* exp = (*i)->as_tuple(); const ASymbol* tag = (*exp->begin())->as_symbol(); tenv.tags.insert(std::make_pair(tag->str(), TEnv::Constructor(exp, prot))); type.push_back(exp); } tenv.def(sym, type); } static void constrain_fn(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error) { set defs; TEnv::Frame frame; // Add parameters to environment frame List protT; for (const auto& i : *call->prot()) { const ASymbol* sym = i->to_symbol(); THROW_IF(!sym, i->loc, "parameter name is not a symbol"); THROW_IF(defs.count(sym) != 0, sym->loc, (format("duplicate parameter `%1%'") % sym->str()).str()); defs.insert(sym); const AST* tvar = tenv.fresh(sym); frame.push_back(make_pair(sym->sym(), tvar)); protT.push_back(tvar); } if (!protT.head) { protT.head = new ATuple(NULL, NULL, call->loc); } else { protT.head->loc = call->loc; } ATuple::const_iterator i = call->iter_at(1); c.constrain(tenv, *i, protT); // Add internal definitions to environment frame for (++i; i != call->end(); ++i) { const AST* exp = *i; const ATuple* call = exp->to_tuple(); if (call && is_form(call, "define")) { const ASymbol* sym = call->list_ref(1)->as_symbol(); THROW_IF(defs.count(sym) != 0, call->loc, (format("`%1%' defined twice") % sym->str()).str()); defs.insert(sym); frame.push_back(make_pair(sym->sym(), (AST*)NULL)); } } tenv.push(frame); const AST* exp = NULL; for (i = call->iter_at(2); i != call->end(); ++i) { exp = *i; resp_constrain(tenv, c, exp); } const AST* bodyT = tenv.var(exp); const ATuple* fnT = tup(call->loc, tenv.Fn, protT.head, bodyT, 0); Object::pool.addRoot(fnT); tenv.pop(); c.constrain(tenv, call, fnT); } static void constrain_if(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error) { THROW_IF(call->list_len() < 4, call->loc, "`if' requires at least 3 arguments"); THROW_IF(call->list_len() % 2 != 0, call->loc, "`if' missing final else clause"); for (ATuple::const_iterator i = call->iter_at(1); i != call->end(); ++i) resp_constrain(tenv, c, *i); const AST* retT = tenv.var(call); for (ATuple::const_iterator i = call->iter_at(1); true; ++i) { ATuple::const_iterator next = i; ++next; if (next == call->end()) { // final (else) expression c.constrain(tenv, *i, retT); break; } else { c.constrain(tenv, *i, tenv.named("Bool")); c.constrain(tenv, *next, retT); } i = next; // jump 2 each iteration (to the next predicate) } } static void constrain_let(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error) { THROW_IF(call->list_len() < 3, call->loc, "`let' requires at least 2 arguments"); const ATuple* vars = call->list_ref(1)->to_tuple(); THROW_IF(!vars, call->list_ref(1)->loc, "first argument of `let' is not a list"); TEnv::Frame frame; for (ATuple::const_iterator i = vars->begin(); i != vars->end(); ++i) { const ASymbol* sym = (*i)->to_symbol(); THROW_IF(!sym, (*i)->loc, "`let' binding name is not a symbol"); ATuple::const_iterator val = ++i; THROW_IF(val == vars->end(), sym->loc, "`let' variable missing value"); resp_constrain(tenv, c, *val); const AST* tvar = tenv.var(*val); frame.push_back(make_pair(sym->sym(), tvar)); c.constrain(tenv, sym, tvar); //c.constrain(tenv, *val, tvar); } tenv.push(frame); for (ATuple::const_iterator i = call->iter_at(2); i != call->end(); ++i) resp_constrain(tenv, c, *i); c.constrain(tenv, call, tenv.var(call->list_last())); tenv.pop(); } static void constrain_match(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error) { THROW_IF(call->list_len() < 5, call->loc, "`match' requires at least 4 arguments"); const AST* matchee = call->list_ref(1); const AST* retT = tenv.var(); const AST* matcheeT = tenv.var(); resp_constrain(tenv, c, matchee); for (ATuple::const_iterator i = call->iter_at(2); i != call->end();) { const AST* exp = *i++; const ATuple* pattern = exp->to_tuple(); THROW_IF(!pattern, exp->loc, "missing pattern"); THROW_IF(i == call->end(), pattern->loc, "missing expression"); const AST* body = *i++; const ASymbol* name = (*pattern->begin())->to_symbol(); THROW_IF(!name, (*pattern->begin())->loc, "pattern does not start with a symbol"); TEnv::Tags::const_iterator tag = tenv.tags.find(name->str()); THROW_IF(tag == tenv.tags.end(), name->loc, (format("undefined constructor `%1%'") % name->sym()).str()); const TEnv::Constructor& constructor = tag->second; TEnv::Frame frame; ATuple::const_iterator ei = constructor.expr->as_tuple()->iter_at(1); for (ATuple::const_iterator pi = pattern->iter_at(1); pi != pattern->end(); ++pi, ++ei) { const AST* tvar = tenv.var(*pi); frame.push_back(make_pair((*pi)->as_symbol()->sym(), tvar)); } tenv.push(frame); resp_constrain(tenv, c, body); c.constrain(tenv, body, retT); // Copy the type's prototype replacing symbols with real type variables List type(matchee->loc, constructor.type->as_tuple()->fst(), NULL); for (ATuple::const_iterator t = constructor.type->as_tuple()->iter_at(1); t != constructor.type->as_tuple()->end(); ++t) { type.push_back(tenv.var()); } c.constrain(tenv, matchee, type); tenv.pop(); } c.constrain(tenv, call, retT); c.constrain(tenv, matchee, matcheeT); } static void resp_constrain_quoted(TEnv& tenv, Constraints& c, const AST* ast) throw(Error) { if (ast->tag() == T_SYMBOL) { c.constrain(tenv, ast, tenv.named("Symbol")); } else if (ast->tag() == T_TUPLE) { List tupT(new ATuple(tenv.Expr, NULL, ast->loc)); const ATuple* tup = ast->as_tuple(); const AST* fstT = tenv.var(tup->fst()); c.constrain(tenv, ast, tupT); c.constrain(tenv, tup->fst(), fstT); for (const auto& i : *ast->as_tuple()) { resp_constrain_quoted(tenv, c, i); } } else { resp_constrain(tenv, c, ast); } } static void constrain_quote(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error) { THROW_IF(call->list_len() != 2, call->loc, "`quote' requires exactly 1 argument"); resp_constrain_quoted(tenv, c, call->frst()); List type(call->loc, tenv.Expr, NULL); c.constrain(tenv, call, type); } static void constrain_call(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error) { const AST* const head = call->fst(); for (auto i : *call) resp_constrain(tenv, c, i); const AST* fnType = tenv.var(head); if (!AType::is_var(fnType)) { if (!is_form(fnType, "Fn")) throw Error(call->loc, (format("call to non-function `%1%'") % head->str()).str()); size_t numArgs = fnType->as_tuple()->prot()->list_len(); THROW_IF(numArgs != call->list_len() - 1, call->loc, (format("expected %1% arguments, got %2%") % numArgs % (call->list_len() - 1)).str()); } const AST* retT = tenv.var(call); List argsT; for (ATuple::const_iterator i = call->iter_at(1); i != call->end(); ++i) argsT.push_back(tenv.var(*i)); if (argsT.head) argsT.head->loc = call->loc; c.constrain(tenv, head, tup(head->loc, tenv.Fn, argsT.head, retT, 0)); c.constrain(tenv, call, retT); } static void constrain_primitive(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error) { const string n = call->fst()->to_symbol()->str(); enum { ARITHMETIC, BINARY, LOGICAL, COMPARISON } type; if (n == "+" || n == "-" || n == "*" || n == "/") type = ARITHMETIC; else if (n == "%") type = BINARY; else if (n == "and" || n == "or" || n == "xor") type = LOGICAL; else if (n == "=" || n == "!=" || n == ">" || n == ">=" || n == "<" || n == "<=") type = COMPARISON; else throw Error(call->loc, (format("unknown primitive `%1%'") % n).str()); ATuple::const_iterator i = call->begin(); for (++i; i != call->end(); ++i) resp_constrain(tenv, c, *i); i = call->begin(); const AST* var = NULL; switch (type) { case ARITHMETIC: if (call->list_len() < 3) throw Error(call->loc, (format("`%1%' requires at least 2 arguments") % n).str()); for (++i; i != call->end(); ++i) c.constrain(tenv, *i, tenv.var(call)); break; case BINARY: if (call->list_len() != 3) throw Error(call->loc, (format("`%1%' requires exactly 2 arguments") % n).str()); c.constrain(tenv, *++i, tenv.var(call)); c.constrain(tenv, *++i, tenv.var(call)); break; case LOGICAL: if (call->list_len() != 3) throw Error(call->loc, (format("`%1%' requires exactly 2 arguments") % n).str()); c.constrain(tenv, call, tenv.named("Bool")); c.constrain(tenv, *++i, tenv.named("Bool")); c.constrain(tenv, *++i, tenv.named("Bool")); break; case COMPARISON: if (call->list_len() != 3) throw Error(call->loc, (format("`%1%' requires exactly 2 arguments") % n).str()); var = tenv.var(*++i); c.constrain(tenv, call, tenv.named("Bool")); c.constrain(tenv, *++i, var); break; default: throw Error(call->loc, (format("unknown primitive `%1%'") % n).str()); } } static void constrain_list(TEnv& tenv, Constraints& c, const ATuple* tup) throw(Error) { const ASymbol* const sym = tup->fst()->to_symbol(); if (!sym) { constrain_call(tenv, c, tup); return; } const std::string form = sym->sym(); if (is_primitive(tenv.penv, tup)) constrain_primitive(tenv, c, tup); else if (form == "cons" || isupper(form[0])) constrain_cons(tenv, c, tup); else if (form == ".") constrain_dot(tenv, c, tup); else if (form == "define") constrain_def(tenv, c, tup); else if (form == "def-type") constrain_def_type(tenv, c, tup); else if (form == "lambda") constrain_fn(tenv, c, tup); else if (form == "if") constrain_if(tenv, c, tup); else if (form == "let") constrain_let(tenv, c, tup); else if (form == "match") constrain_match(tenv, c, tup); else if (form == "quote") constrain_quote(tenv, c, tup); else constrain_call(tenv, c, tup); } void resp_constrain(TEnv& tenv, Constraints& c, const AST* ast) throw(Error) { switch (ast->tag()) { case T_UNKNOWN: case T_TVAR: break; case T_BOOL: c.constrain(tenv, ast, tenv.named("Bool")); break; case T_FLOAT: c.constrain(tenv, ast, tenv.named("Float")); break; case T_INT32: c.constrain(tenv, ast, tenv.named("Int")); break; case T_STRING: c.constrain(tenv, ast, tenv.named("String")); break; case T_SYMBOL: case T_LITSYM: constrain_symbol(tenv, c, ast->as_symbol()); break; case T_TUPLE: constrain_list(tenv, c, ast->as_tuple()); break; case T_ELLIPSIS: throw Error(ast->loc, "ellipsis present after expand stage"); } }