aboutsummaryrefslogtreecommitdiffstats
path: root/src/constrain.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/constrain.cpp')
-rw-r--r--src/constrain.cpp224
1 files changed, 224 insertions, 0 deletions
diff --git a/src/constrain.cpp b/src/constrain.cpp
new file mode 100644
index 0000000..a1868c4
--- /dev/null
+++ b/src/constrain.cpp
@@ -0,0 +1,224 @@
+/* Tuplr Type Inferencing
+ * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net>
+ *
+ * Tuplr is free software: you can redistribute it and/or modify it under
+ * the terms of the GNU Affero General Public License as published by the
+ * Free Software Foundation, either version 3 of the License, or (at your
+ * option) any later version.
+ *
+ * Tuplr is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
+ * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General
+ * Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with Tuplr. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+/** @file
+ * @brief Constrain type of AST expressions
+ */
+
+#include <set>
+#include "tuplr.hpp"
+
+void
+ASymbol::constrain(TEnv& tenv, Constraints& c) const
+{
+ addr = tenv.lookup(this);
+ if (!addr)
+ throw Error(loc, (format("undefined symbol `%1%'") % cppstr).str());
+ c.push_back(Constraint(tenv.var(this), tenv.deref(addr).second, loc));
+}
+
+void
+ATuple::constrain(TEnv& tenv, Constraints& c) const
+{
+ AType* t = tup<AType>(loc, NULL);
+ FOREACH(const_iterator, p, *this) {
+ (*p)->constrain(tenv, c);
+ t->push_back(tenv.var(*p));
+ }
+ c.push_back(Constraint(tenv.var(this), t, loc));
+}
+
+void
+AFn::constrain(TEnv& tenv, Constraints& c) const
+{
+ const AType* genericType;
+ TEnv::GenericTypes::const_iterator gt = tenv.genericTypes.find(this);
+ if (gt != tenv.genericTypes.end()) {
+ genericType = gt->second;
+ } else {
+ set<ASymbol*> defined;
+ TEnv::Frame frame;
+
+ // Add parameters to environment frame
+ for (size_t i = 0; i < prot()->size(); ++i) {
+ ASymbol* sym = prot()->at(i)->to<ASymbol*>();
+ if (!sym)
+ throw Error(prot()->at(i)->loc, "parameter name is not a symbol");
+ if (defined.find(sym) != defined.end())
+ throw Error(sym->loc, (format("duplicate parameter `%1%'") % sym->str()).str());
+ defined.insert(sym);
+ frame.push_back(make_pair(sym, make_pair((AST*)NULL, (AType*)NULL)));
+ }
+
+ // Add internal definitions to environment frame
+ size_t e = 2;
+ for (; e < size(); ++e) {
+ AST* exp = at(e);
+ ADef* def = exp->to<ADef*>();
+ if (def) {
+ ASymbol* sym = def->sym();
+ if (defined.find(sym) != defined.end())
+ throw Error(def->loc, (format("`%1%' defined twice") % sym->str()).str());
+ defined.insert(def->sym());
+ frame.push_back(make_pair(def->sym(), make_pair(def->at(2), (AType*)NULL)));
+ }
+ }
+
+ tenv.push(frame);
+
+ Constraints cp;
+ cp.push_back(Constraint(tenv.var(this), tenv.var(), loc));
+
+ AType* protT = tup<AType>(loc, NULL);
+ for (size_t i = 0; i < prot()->size(); ++i) {
+ AType* tvar = tenv.fresh(prot()->at(i)->to<ASymbol*>());
+ protT->push_back(tvar);
+ assert(frame[i].first == prot()->at(i));
+ frame[i].second.first = prot()->at(i);
+ frame[i].second.second = tvar;
+ }
+ c.push_back(Constraint(tenv.var(at(1)), protT, at(1)->loc));
+
+ for (size_t i = 2; i < size(); ++i)
+ at(i)->constrain(tenv, cp);
+
+ AType* bodyT = tenv.var(at(e-1));
+ Subst tsubst = TEnv::unify(cp);
+ genericType = tup<AType>(loc, tenv.penv.sym("Fn"),
+ tsubst.apply(protT), tsubst.apply(bodyT), 0);
+ tenv.genericTypes.insert(make_pair(this, genericType));
+ Object::pool.addRoot(genericType);
+
+ tenv.pop();
+ subst = tsubst;
+ }
+
+ AType* t = new AType(*genericType); // FIXME: deep copy
+ c.constrain(tenv, this, t);
+}
+
+void
+ACall::constrain(TEnv& tenv, Constraints& c) const
+{
+ at(0)->constrain(tenv, c);
+ for (size_t i = 1; i < size(); ++i)
+ at(i)->constrain(tenv, c);
+
+ AST* callee = tenv.resolve(at(0));
+ AFn* closure = callee->to<AFn*>();
+ if (closure) {
+ if (size() - 1 != closure->prot()->size())
+ throw Error(loc, "incorrect number of arguments");
+ TEnv::GenericTypes::iterator gt = tenv.genericTypes.find(closure);
+ if (gt != tenv.genericTypes.end()) {
+ for (size_t i = 1; i < size(); ++i)
+ c.constrain(tenv, at(i), gt->second->at(1)->as<ATuple*>()->at(i-1)->as<AType*>());
+ AType* retT = tenv.var(this);
+ c.constrain(tenv, at(0), tup<AType>(at(0)->loc, tenv.penv.sym("Fn"), tenv.var(), retT, 0));
+ c.constrain(tenv, this, retT);
+ return;
+ }
+ }
+ AType* argsT = tup<AType>(loc, 0);
+ for (size_t i = 1; i < size(); ++i)
+ argsT->push_back(tenv.var(at(i)));
+ AType* retT = tenv.var();
+ c.constrain(tenv, at(0), tup<AType>(at(0)->loc, tenv.penv.sym("Fn"), argsT, retT, 0));
+ c.constrain(tenv, this, retT);
+}
+
+void
+ADef::constrain(TEnv& tenv, Constraints& c) const
+{
+ THROW_IF(size() != 3, loc, "`def' requires exactly 2 arguments");
+ const ASymbol* sym = this->sym();
+ THROW_IF(!sym, loc, "`def' has no symbol")
+
+ AType* tvar = tenv.var(at(2));
+ tenv.def(sym, make_pair(at(2), tvar));
+ at(2)->constrain(tenv, c);
+ c.constrain(tenv, this, tvar);
+}
+
+void
+AIf::constrain(TEnv& tenv, Constraints& c) const
+{
+ THROW_IF(size() < 4, loc, "`if' requires at least 3 arguments");
+ THROW_IF(size() % 2 != 0, loc, "`if' missing final else clause")
+ for (size_t i = 1; i < size(); ++i)
+ at(i)->constrain(tenv, c);
+ AType* retT = tenv.var(this);
+ for (size_t i = 1; i < size(); i += 2) {
+ if (i == size() - 1) {
+ c.constrain(tenv, at(i), retT);
+ } else {
+ c.constrain(tenv, at(i), tenv.named("Bool"));
+ c.constrain(tenv, at(i+1), retT);
+ }
+ }
+}
+
+void
+APrimitive::constrain(TEnv& tenv, Constraints& c) const
+{
+ const string n = at(0)->to<ASymbol*>()->str();
+ enum { ARITHMETIC, BINARY, LOGICAL, COMPARISON } type;
+ if (n == "+" || n == "-" || n == "*" || n == "/")
+ type = ARITHMETIC;
+ else if (n == "%")
+ type = BINARY;
+ else if (n == "and" || n == "or" || n == "xor")
+ type = LOGICAL;
+ else if (n == "=" || n == "!=" || n == ">" || n == ">=" || n == "<" || n == "<=")
+ type = COMPARISON;
+ else
+ throw Error(loc, (format("unknown primitive `%1%'") % n).str());
+
+ for (size_t i = 1; i < size(); ++i)
+ at(i)->constrain(tenv, c);
+
+ switch (type) {
+ case ARITHMETIC:
+ if (size() < 3)
+ throw Error(loc, (format("`%1%' requires at least 2 arguments") % n).str());
+ for (size_t i = 1; i < size(); ++i)
+ c.constrain(tenv, at(i), tenv.var(this));
+ break;
+ case BINARY:
+ if (size() != 3)
+ throw Error(loc, (format("`%1%' requires exactly 2 arguments") % n).str());
+ c.constrain(tenv, at(1), tenv.var(this));
+ c.constrain(tenv, at(2), tenv.var(this));
+ break;
+ case LOGICAL:
+ if (size() != 3)
+ throw Error(loc, (format("`%1%' requires exactly 2 arguments") % n).str());
+ c.constrain(tenv, this, tenv.named("Bool"));
+ c.constrain(tenv, at(1), tenv.named("Bool"));
+ c.constrain(tenv, at(2), tenv.named("Bool"));
+ break;
+ case COMPARISON:
+ if (size() != 3)
+ throw Error(loc, (format("`%1%' requires exactly 2 arguments") % n).str());
+ c.constrain(tenv, this, tenv.named("Bool"));
+ c.constrain(tenv, at(1), tenv.var(at(2)));
+ break;
+ default:
+ throw Error(loc, (format("unknown primitive `%1%'") % n).str());
+ }
+}
+