aboutsummaryrefslogtreecommitdiffstats
path: root/typing.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'typing.cpp')
-rw-r--r--typing.cpp231
1 files changed, 150 insertions, 81 deletions
diff --git a/typing.cpp b/typing.cpp
index fdea255..ef28409 100644
--- a/typing.cpp
+++ b/typing.cpp
@@ -15,78 +15,157 @@
* along with Tuplr. If not, see <http://www.gnu.org/licenses/>.
*/
+#include <set>
#include "tuplr.hpp"
+void
+Constraints::constrain(TEnv& tenv, const AST* o, AType* t)
+{
+ assert(!dynamic_cast<const AType*>(o));
+ push_back(Constraint(tenv.var(o), t, o->loc));
+}
+
/***************************************************************************
* AST Type Constraints *
***************************************************************************/
void
-ATuple::constrain(TEnv& tenv) const
+ASymbol::lookup(TEnv& tenv)
+{
+ addr = tenv.lookup(this);
+ if (!addr)
+ throw Error((format("undefined symbol `%1%'") % cppstr).str(), loc);
+}
+
+void
+ASymbol::constrain(TEnv& tenv, Constraints& c)
+{
+ lookup(tenv);
+ AType* t = tenv.deref(addr);
+ if (!t)
+ throw Error((format("unresolved symbol `%1%'") % cppstr).str(), loc);
+ c.push_back(Constraint(tenv.var(this), tenv.deref(addr), loc));
+}
+
+void
+ATuple::constrain(TEnv& tenv, Constraints& c)
{
AType* t = new AType(ATuple(), loc);
- FOREACH(const_iterator, p, *this) {
- (*p)->constrain(tenv);
- t->push_back(tenv.type(*p));
+ FOREACH(iterator, p, *this) {
+ (*p)->constrain(tenv, c);
+ t->push_back(tenv.var(*p));
}
- tenv.constrain(this, t);
+ c.push_back(Constraint(tenv.var(this), t, loc));
}
void
-AClosure::constrain(TEnv& tenv) const
+AClosure::constrain(TEnv& tenv, Constraints& c)
{
- at(1)->constrain(tenv);
- at(2)->constrain(tenv);
- AType* protT = tenv.type(at(1));
- AType* bodyT = tenv.type(at(2));
- tenv.constrain(this, new AType(loc, tenv.penv.sym("Fn"), protT, bodyT, 0));
+ set<ASymbol*> defined;
+ TEnv::Frame frame;
+
+ // Add parameters to environment frame
+ for (size_t i = 0; i < prot()->size(); ++i) {
+ ASymbol* sym = dynamic_cast<ASymbol*>(prot()->at(i));
+ if (!sym)
+ throw Error("parameter name is not a symbol", prot()->at(i)->loc);
+ if (defined.find(sym) != defined.end())
+ throw Error((format("duplicate parameter `%1%'") % sym->str()).str(), sym->loc);
+ defined.insert(sym);
+ frame.push_back(make_pair(sym, (AType*)NULL));
+ }
+
+ // Add internal definitions to environment frame
+ size_t e = 2;
+ for (; e < size(); ++e) {
+ AST* exp = at(e);
+ ADefinition* def = dynamic_cast<ADefinition*>(exp);
+ if (def) {
+ ASymbol* sym = def->sym();
+ if (defined.find(sym) != defined.end())
+ throw Error((format("`%1%' defined twice") % sym->str()).str(), def->loc);
+ defined.insert(def->sym());
+ frame.push_back(make_pair(def->sym(), (AType*)NULL));
+ }
+ }
+
+ tenv.push(frame);
+
+ Constraints cp;
+ cp.push_back(Constraint(tenv.var(this), tenv.var(), loc));
+
+ AType* protT = new AType(ATuple(), loc);
+ for (size_t i = 0; i < prot()->size(); ++i) {
+ AType* tvar = tenv.fresh(dynamic_cast<ASymbol*>(prot()->at(i)));
+ protT->push_back(tvar);
+ assert(frame[i].first == prot()->at(i));
+ frame[i].second = tvar;
+ }
+ c.push_back(Constraint(tenv.var(at(1)), protT, at(1)->loc));
+
+ for (size_t i = 2; i < size(); ++i)
+ at(i)->constrain(tenv, cp);
+
+ AType* bodyT = tenv.var(at(e-1));
+ Subst tsubst = TEnv::unify(cp);
+ type = new AType(loc, tenv.penv.sym("Fn"), tsubst.apply(protT), tsubst.apply(bodyT), 0);
+
+ tenv.pop();
+
+ c.constrain(tenv, this, type);
+ subst = new Subst(tsubst);
}
void
-ACall::constrain(TEnv& tenv) const
+ACall::constrain(TEnv& tenv, Constraints& c)
{
- FOREACH(const_iterator, p, *this)
- (*p)->constrain(tenv);
- AType* retT = tenv.type(this);
AType* argsT = new AType(ATuple(), loc);
- for (size_t i = 1; i < size(); ++i)
- argsT->push_back(tenv.type(at(i)));
- tenv.constrain(at(0), new AType(at(0)->loc, tenv.penv.sym("Fn"), argsT, retT, 0));
+ TEnv::Frame frame;
+ for (size_t i = 1; i < size(); ++i) {
+ at(i)->constrain(tenv, c);
+ argsT->push_back(tenv.var(at(i)));
+ frame.push_back(make_pair((AST*)NULL, tenv.var(at(i))));
+ }
+ AType* retT = tenv.var();
+
+ at(0)->constrain(tenv, c);
+
+ c.constrain(tenv, at(0), new AType(at(0)->loc, tenv.penv.sym("Fn"), argsT, retT, 0));
+ c.constrain(tenv, this, retT);
}
void
-ADefinition::constrain(TEnv& tenv) const
+ADefinition::constrain(TEnv& tenv, Constraints& c)
{
if (size() != 3) throw Error("`def' requires exactly 2 arguments", loc);
if (!dynamic_cast<const ASymbol*>(at(1)))
throw Error("`def' name is not a symbol", loc);
- FOREACH(const_iterator, p, *this)
- (*p)->constrain(tenv);
- AType* tvar = tenv.type(this);
- tenv.constrain(at(1), tvar);
- tenv.constrain(at(2), tvar);
+ AType* tvar = tenv.var(this);
+ tenv.def(at(1), tvar);
+ at(2)->constrain(tenv, c);
+ c.constrain(tenv, at(2), tvar);
}
void
-AIf::constrain(TEnv& tenv) const
+AIf::constrain(TEnv& tenv, Constraints& c)
{
if (size() < 3) throw Error("`if' requires exactly 3 arguments", loc);
if (size() % 2 != 0) throw Error("`if' missing final else clause", loc);
- FOREACH(const_iterator, p, *this)
- (*p)->constrain(tenv);
- AType* retT = tenv.type(this);
+ for (size_t i = 1; i < size(); ++i)
+ at(i)->constrain(tenv, c);
+ AType* retT = tenv.var(this);
for (size_t i = 1; i < size(); i += 2) {
if (i == size() - 1) {
- tenv.constrain(at(i), retT);
+ c.constrain(tenv, at(i), retT);
} else {
- tenv.constrain(at(i), tenv.named("Bool"));
- tenv.constrain(at(i+1), retT);
+ c.constrain(tenv, at(i), tenv.named("Bool"));
+ c.constrain(tenv, at(i+1), retT);
}
}
}
void
-APrimitive::constrain(TEnv& tenv) const
+APrimitive::constrain(TEnv& tenv, Constraints& c)
{
const string n = dynamic_cast<ASymbol*>(at(0))->str();
enum { ARITHMETIC, BINARY, LOGICAL, COMPARISON } type;
@@ -101,34 +180,34 @@ APrimitive::constrain(TEnv& tenv) const
else
throw Error((format("unknown primitive `%1%'") % n).str(), loc);
- FOREACH(const_iterator, p, *this)
- (*p)->constrain(tenv);
+ for (size_t i = 1; i < size(); ++i)
+ at(i)->constrain(tenv, c);
switch (type) {
case ARITHMETIC:
if (size() < 3)
throw Error((format("`%1%' requires at least 2 arguments") % n).str(), loc);
for (size_t i = 1; i < size(); ++i)
- tenv.constrain(at(i), tenv.type(this));
+ c.constrain(tenv, at(i), tenv.var(this));
break;
case BINARY:
if (size() != 3)
throw Error((format("`%1%' requires exactly 2 arguments") % n).str(), loc);
- tenv.constrain(at(1), tenv.type(this));
- tenv.constrain(at(2), tenv.type(this));
+ c.constrain(tenv, at(1), tenv.var(this));
+ c.constrain(tenv, at(2), tenv.var(this));
break;
case LOGICAL:
if (size() != 3)
throw Error((format("`%1%' requires exactly 2 arguments") % n).str(), loc);
- tenv.constrain(this, tenv.named("Bool"));
- tenv.constrain(at(1), tenv.named("Bool"));
- tenv.constrain(at(2), tenv.named("Bool"));
+ c.constrain(tenv, this, tenv.named("Bool"));
+ c.constrain(tenv, at(1), tenv.named("Bool"));
+ c.constrain(tenv, at(2), tenv.named("Bool"));
break;
case COMPARISON:
if (size() != 3)
throw Error((format("`%1%' requires exactly 2 arguments") % n).str(), loc);
- tenv.constrain(this, tenv.named("Bool"));
- tenv.constrain(at(1), tenv.type(at(2)));
+ c.constrain(tenv, this, tenv.named("Bool"));
+ c.constrain(tenv, at(1), tenv.var(at(2)));
break;
default:
throw Error((format("unknown primitive `%1%'") % n).str(), loc);
@@ -136,37 +215,37 @@ APrimitive::constrain(TEnv& tenv) const
}
void
-AConsCall::constrain(TEnv& tenv) const
+AConsCall::constrain(TEnv& tenv, Constraints& c)
{
if (size() != 3) throw Error("`cons' requires exactly 2 arguments", loc);
AType* t = new AType(loc, tenv.penv.sym("Pair"), 0);
for (size_t i = 1; i < size(); ++i) {
- at(i)->constrain(tenv);
- t->push_back(tenv.type(at(i)));
+ at(i)->constrain(tenv, c);
+ t->push_back(tenv.var(at(i)));
}
- tenv.constrain(this, t);
+ c.constrain(tenv, this, t);
}
void
-ACarCall::constrain(TEnv& tenv) const
+ACarCall::constrain(TEnv& tenv, Constraints& c)
{
if (size() != 2) throw Error("`car' requires exactly 1 argument", loc);
- at(1)->constrain(tenv);
- AType* carT = tenv.type(this);
+ at(1)->constrain(tenv, c);
+ AType* carT = tenv.var(this);
AType* pairT = new AType(at(1)->loc, tenv.penv.sym("Pair"), carT, tenv.var(), 0);
- tenv.constrain(at(1), pairT);
- tenv.constrain(this, carT);
+ c.constrain(tenv, at(1), pairT);
+ c.constrain(tenv, this, carT);
}
void
-ACdrCall::constrain(TEnv& tenv) const
+ACdrCall::constrain(TEnv& tenv, Constraints& c)
{
if (size() != 2) throw Error("`cdr' requires exactly 1 argument", loc);
- at(1)->constrain(tenv);
- AType* cdrT = tenv.type(this);
+ at(1)->constrain(tenv, c);
+ AType* cdrT = tenv.var(this);
AType* pairT = new AType(at(1)->loc, tenv.penv.sym("Pair"), tenv.var(), cdrT, 0);
- tenv.constrain(at(1), pairT);
- tenv.constrain(this, cdrT);
+ c.constrain(tenv, at(1), pairT);
+ c.constrain(tenv, this, cdrT);
}
@@ -175,25 +254,26 @@ ACdrCall::constrain(TEnv& tenv) const
***************************************************************************/
static void
-substitute(ATuple* tup, AST* from, AST* to)
+substitute(ATuple* tup, const AST* from, AST* to)
{
if (!tup) return;
for (size_t i = 0; i < tup->size(); ++i)
if (*tup->at(i) == *from)
tup->at(i) = to;
- else
+ else if (tup->at(i) != to)
substitute(dynamic_cast<ATuple*>(tup->at(i)), from, to);
}
-TEnv::Subst
-compose(const TEnv::Subst& delta, const TEnv::Subst& gamma) // TAPL 22.1.1
+
+Subst
+Subst::compose(const Subst& delta, const Subst& gamma) // TAPL 22.1.1
{
- TEnv::Subst r;
- for (TEnv::Subst::const_iterator g = gamma.begin(); g != gamma.end(); ++g) {
- TEnv::Subst::const_iterator d = delta.find(g->second);
+ Subst r;
+ for (Subst::const_iterator g = gamma.begin(); g != gamma.end(); ++g) {
+ Subst::const_iterator d = delta.find(g->second);
r.insert(make_pair(g->first, ((d != delta.end()) ? d : g)->second));
}
- for (TEnv::Subst::const_iterator d = delta.begin(); d != delta.end(); ++d) {
+ for (Subst::const_iterator d = delta.begin(); d != delta.end(); ++d) {
if (gamma.find(d->first) == gamma.end())
r.insert(*d);
}
@@ -201,10 +281,10 @@ compose(const TEnv::Subst& delta, const TEnv::Subst& gamma) // TAPL 22.1.1
}
void
-substConstraints(TEnv::Constraints& constraints, AType* s, AType* t)
+substConstraints(Constraints& constraints, AType* s, AType* t)
{
- for (TEnv::Constraints::iterator c = constraints.begin(); c != constraints.end();) {
- TEnv::Constraints::iterator next = c; ++next;
+ for (Constraints::iterator c = constraints.begin(); c != constraints.end();) {
+ Constraints::iterator next = c; ++next;
if (*c->first == *s) c->first = t;
if (*c->second == *s) c->second = t;
substitute(c->first, s, t);
@@ -213,7 +293,7 @@ substConstraints(TEnv::Constraints& constraints, AType* s, AType* t)
}
}
-TEnv::Subst
+Subst
TEnv::unify(const Constraints& constraints) // TAPL 22.4
{
if (constraints.empty()) return Subst();
@@ -226,10 +306,10 @@ TEnv::unify(const Constraints& constraints) // TAPL 22.4
return unify(cp);
} else if (s->var() && !t->contains(s)) {
substConstraints(cp, s, t);
- return compose(unify(cp), Subst(s, t));
+ return Subst::compose(unify(cp), Subst(s, t));
} else if (t->var() && !s->contains(t)) {
substConstraints(cp, t, s);
- return compose(unify(cp), Subst(t, s));
+ return Subst::compose(unify(cp), Subst(t, s));
} else if (s->kind == AType::EXPR && s->kind == t->kind && s->size() == t->size()) {
for (size_t i = 0; i < s->size(); ++i) {
AType* si = dynamic_cast<AType*>(s->at(i));
@@ -244,14 +324,3 @@ TEnv::unify(const Constraints& constraints) // TAPL 22.4
}
}
-void
-TEnv::apply(const TEnv::Subst& substs)
-{
- FOREACH(Subst::const_iterator, s, substs)
- FOREACH(Frame::iterator, t, front())
- if (*t->second == *s->first)
- t->second = s->second;
- else
- substitute(t->second, s->first, s->second);
-}
-