aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDavid Robillard <d@drobilla.net>2009-01-27 04:17:22 +0000
committerDavid Robillard <d@drobilla.net>2009-01-27 04:17:22 +0000
commit73c2efb353b45e29df71268668e4ad1c8c2fa93c (patch)
tree295dc2ed834965c85f6a00994d61b8343f0ac6e1
parentbeb36974ecaafa30296b90f7f61c24235450c2f3 (diff)
downloadresp-73c2efb353b45e29df71268668e4ad1c8c2fa93c.tar.gz
resp-73c2efb353b45e29df71268668e4ad1c8c2fa93c.tar.bz2
resp-73c2efb353b45e29df71268668e4ad1c8c2fa93c.zip
Real type inference (classic Hindley-Milner straight from TAPL).
git-svn-id: http://svn.drobilla.net/resp/llvm-lisp@24 ad02d1e2-f140-0410-9f75-f8b11f17cedd
-rw-r--r--ll.cpp246
1 files changed, 143 insertions, 103 deletions
diff --git a/ll.cpp b/ll.cpp
index a906bdc..824f8a7 100644
--- a/ll.cpp
+++ b/ll.cpp
@@ -69,22 +69,20 @@ typedef Exp<string> SExp;
static SExp
readExpression(std::istream& in)
{
+#define PUSH(s, t) { if (t != "") { s.top().list.push_back(t); t = ""; } }
+#define YIELD(s, t) { if (s.empty()) return t; else PUSH(s, t) }
stack<SExp> stk;
string tok;
-
-#define PUSH(t) { if (t != "") { stk.top().list.push_back(t); t = ""; } }
-#define YIELD(t) { if (stk.empty()) return t; else PUSH(t) }
-
while (char ch = in.get()) {
switch (ch) {
case EOF:
return SExp();
case ' ': case '\t': case '\n':
- if (tok != "") YIELD(tok);
+ if (tok != "") YIELD(stk, tok);
break;
case '"':
do { tok.push_back(ch); } while ((ch = in.get()) != '"');
- YIELD(tok + '"');
+ YIELD(stk, tok + '"');
break;
case '(':
stk.push(SExp());
@@ -94,10 +92,10 @@ readExpression(std::istream& in)
case 0:
throw SyntaxError("Unexpected ')'");
case 1:
- PUSH(tok);
+ PUSH(stk, tok);
return stk.top();
default:
- PUSH(tok);
+ PUSH(stk, tok);
SExp l = stk.top();
stk.pop();
stk.top().list.push_back(l);
@@ -108,9 +106,9 @@ readExpression(std::istream& in)
}
}
switch (stk.size()) {
- case 0: return tok; break;
- case 1: return stk.top(); break;
- default: throw SyntaxError("Missing ')'");
+ case 0: return tok;
+ case 1: return stk.top();
+ default: throw SyntaxError("Missing ')'");
}
return SExp();
}
@@ -127,11 +125,13 @@ struct AType; ///< Abstract Type
/// Base class for all AST nodes
struct AST {
virtual ~AST() {}
- virtual bool operator==(const AST& rhs) const = 0;
- virtual string str() const = 0;
- virtual void constrain(TEnv& tenv) const {}
- virtual void lift(CEnv& cenv) {}
- virtual Value* compile(CEnv& cenv) = 0;
+ virtual bool contains(AST* child) const { return false; }
+ virtual bool operator!=(const AST& o) const { return !operator==(o); }
+ virtual bool operator==(const AST& o) const = 0;
+ virtual string str() const = 0;
+ virtual void constrain(TEnv& tenv) const {}
+ virtual void lift(CEnv& cenv) {}
+ virtual Value* compile(CEnv& cenv) = 0;
};
/// Literal
@@ -170,17 +170,24 @@ struct ASTTuple : public AST {
return ret + ")";
}
bool operator==(const AST& rhs) const {
+ const ASTTuple* rt = dynamic_cast<const ASTTuple*>(&rhs);
+ if (!rt) return false;
+ if (rt->tup.size() != tup.size()) return false;
TupV::const_iterator l = tup.begin();
- FOREACH(TupV::const_iterator, r, tup)
- if ((*l++) != (*r))
+ FOREACH(TupV::const_iterator, r, rt->tup) {
+ AST* mine = *l++;
+ AST* other = *r;
+ if (!(*mine == *other))
return false;
+ }
return true;
}
- bool operator!=(const ASTTuple& t) const { return ! operator==(t); }
void lift(CEnv& cenv) {
FOREACH(TupV::iterator, t, tup)
(*t)->lift(cenv);
}
+ bool isForm(const string& f) { return !tup.empty() && tup[0]->str() == f; }
+ bool contains(AST* child) const;
void constrain(TEnv& tenv) const;
Value* compile(CEnv& cenv) { return NULL; }
TupV tup;
@@ -223,6 +230,16 @@ struct AType : public ASTTuple {
}
return true;
}
+ bool operator==(const AST& rhs) const {
+ const AType* rt = dynamic_cast<const AType*>(&rhs);
+ if (!rt)
+ return false;
+ else if (var && rt->var)
+ return id == rt->id;
+ else if (!var && !rt->var)
+ return ASTTuple::operator==(rhs);
+ return false;
+ }
bool var;
const Type* ctype;
unsigned id;
@@ -390,21 +407,19 @@ struct Env : public list< map<K,V> > {
struct TypeError : public Error { TypeError (const char* m) : Error(m) {} };
+struct TSubst : public map<AType*, AType*> {
+ TSubst(AType* s=0, AType* t=0) { if (s && t) insert(make_pair(s, t)); }
+};
+
/// Type-Time Environment
struct TEnv {
TEnv(PEnv& p) : penv(p), varID(1) {}
- typedef map<const AST*, AType*> Types;
- typedef multimap<const AST*, AType*> Constraints;
+ typedef map<const AST*, AType*> Types;
+ typedef list< pair<AType*, AType*> > Constraints;
AType* var() { return new AType(varID++); }
AType* type(const AST* ast) {
Types::iterator t = types.find(ast);
- if (t != types.end()) {
- return t->second;
- } else {
- AType* tvar = var();
- constrain(ast, tvar);
- return tvar;
- }
+ return (t != types.end()) ? t->second : (types[ast] = var());
}
AType* named(const string& name) const {
Types::const_iterator i = namedTypes.find(penv.sym(name));
@@ -415,10 +430,12 @@ struct TEnv {
ASTSymbol* sym = penv.sym(name);
namedTypes[sym] = new AType(penv.sym(name), type);
}
- void constrain(const AST* ast, AType* type) {
- constraints.insert(make_pair(ast, type));
+ void constrain(const AST* o, AType* t) {
+ constraints.push_back(make_pair(type(o), t));
}
- AType* unify(AST* root);
+ void solve() { apply(unify(constraints)); }
+ void apply(const TSubst& substs);
+ static TSubst unify(const Constraints& c);
PEnv& penv;
Types types;
Types namedTypes;
@@ -436,7 +453,9 @@ ASTTuple::constrain(TEnv& tenv) const
(*p)->constrain(tenv);
texp.push_back(tenv.type(*p));
}
- tenv.constrain(this, new AType(texp));
+ AType* t = tenv.type(this);
+ t->var = false;
+ t->tup = texp;
}
void
@@ -499,71 +518,96 @@ ASTPrimitive::constrain(TEnv& tenv) const
}
}
-static bool
+static void
substitute(ASTTuple* tup, AST* from, AST* to)
{
- bool progress = false;
- for (size_t i = 0; i < tup->tup.size(); ++i) {
- if (*tup->tup[i] == *from) {
+ if (!tup) return;
+ for (size_t i = 0; i < tup->tup.size(); ++i)
+ if (*tup->tup[i] == *from)
tup->tup[i] = to;
- progress = true;
- }
+ else
+ substitute(dynamic_cast<ASTTuple*>(tup->tup[i]), from, to);
+}
+
+bool
+ASTTuple::contains(AST* child) const
+{
+ if (*this == *child) return true;
+ FOREACH(TupV::const_iterator, p, tup)
+ if (**p == *child || (*p)->contains(child))
+ return true;
+ return false;
+}
+
+TSubst
+compose(const TSubst& delta, const TSubst& gamma) // TAPL 22.1.1
+{
+ TSubst r;
+ for (TSubst::const_iterator g = gamma.begin(); g != gamma.end(); ++g) {
+ TSubst::const_iterator i = delta.find(g->second);
+ if (i != delta.end())
+ r.insert(make_pair(g->first, ((i != delta.end()) ? i : g)->second));
+ else
+ r.insert(make_pair(g->first, g->second));
+ }
+ for (TSubst::const_iterator d = delta.begin(); d != delta.end(); ++d) {
+ if (gamma.find(d->first) == gamma.end())
+ r.insert(make_pair(d->first, d->second));
}
- return progress;
+ return r;
}
-AType*
-TEnv::unify(AST* root)
+void
+substConstraints(TEnv::Constraints& constraints, AType* s, AType* t)
{
- root->constrain(*this);
- assert(constraints.find(root) != constraints.end());
- //constrain(root, var());
- typedef map<AType*, AType*> Substitutions;
- Substitutions subst;
- for (bool progress = true; progress; progress = false) {
- //std::cout << "==== " << constraints.size() << endl;
- for (Constraints::iterator c = constraints.begin(); c != constraints.end();) {
- Constraints::iterator next = c;
- ++next;
- const AST* o = c->first;
- AType* t = c->second;
- //std::cout << "Constr : " << o->str() << " = " << t->str() << endl;
- if (t->concrete()) {
- Types::iterator ot = types.find(o);
- if (ot == types.end()) {
- //std::cout << "Resolv : " << o->str() << endl;
- types.insert(make_pair(o, t));
- //constraints.erase(c);
- progress = true;
- }
- } else {
- Types::iterator ot = types.find(o);
- if (ot != types.end()) {
- subst[t] = ot->second;
- progress = true;
- }
- }
- c = next;
- }
- for (Substitutions::iterator s = subst.begin(); s != subst.end(); ++s) {
- //std::cout << "Subst : " << s->first->str() << " => " << s->second->str() << endl;
- for (Constraints::iterator c = constraints.begin(); c != constraints.end(); ++c) {
- AType* objT = c->second;
- if (objT == s->first && objT != s->second) {
- c->second = s->second;
- progress = true;
- }
- progress = substitute(c->second, s->first, s->second) || progress;
- }
- }
+ for (TEnv::Constraints::iterator c = constraints.begin(); c != constraints.end();) {
+ TEnv::Constraints::iterator next = c; ++next;
+ if (*c->first == *s) c->first = t;
+ if (*c->second == *s) c->second = t;
+ substitute(c->first, s, t);
+ substitute(c->second, s, t);
+ c = next;
}
- //std::cout << "======== Done Unifying Types " << constraints.size() << endl;
- Constraints::iterator i = constraints.find(root);
- if (i != constraints.end()) {
- types[root] = i->second;
- return i->second;
+}
+
+TSubst
+TEnv::unify(const Constraints& constraints) // TAPL 22.4
+{
+ if (constraints.empty()) return TSubst();
+ AType* s = constraints.begin()->first;
+ AType* t = constraints.begin()->second;
+ Constraints cp = constraints;
+ cp.erase(cp.begin());
+
+ if (*s == *t) {
+ return unify(cp);
+ } else if (s->var && !t->contains(s)) {
+ substConstraints(cp, s, t);
+ return compose(unify(cp), TSubst(s, t));
+ } else if (t->var && !s->contains(t)) {
+ substConstraints(cp, t, s);
+ return compose(unify(cp), TSubst(t, s));
+ } else if (s->isForm("Fn") && t->isForm("Fn")) {
+ AType* s1 = dynamic_cast<AType*>(s->tup[1]);
+ AType* t1 = dynamic_cast<AType*>(t->tup[1]);
+ AType* s2 = dynamic_cast<AType*>(s->tup[2]);
+ AType* t2 = dynamic_cast<AType*>(t->tup[2]);
+ assert(s1 && t1 && s2 && t2);
+ cp.push_back(make_pair(s1, t1));
+ cp.push_back(make_pair(s2, t2));
+ return unify(cp);
+ } else {
+ throw TypeError("Type unification failed");
}
- return NULL;
+}
+
+void
+TEnv::apply(const TSubst& substs)
+{
+ FOREACH(TSubst::const_iterator, s, substs)
+ FOREACH(Types::iterator, t, types)
+ if (*t->second == *s->first)
+ t->second = s->second;
}
@@ -648,8 +692,7 @@ Value*
ASTSymbol::compile(CEnv& cenv)
{
Value** v = cenv.vals.ref(this);
- if (v)
- return *v;
+ if (v) return *v;
AST** c = cenv.code.ref(this);
if (c) {
@@ -879,16 +922,18 @@ main()
cenv.code.def(penv.sym("false"), new ASTLiteral<bool>(false));
while (1) {
- std::cout << "(= ";
+ std::cout << "() ";
std::cout.flush();
SExp exp = readExpression(std::cin);
if (exp.type == SExp::LIST && exp.list.empty())
break;
try {
- AST* body = parseExpression(penv, exp);
- AType* bodyT = cenv.tenv.unify(body);
+ AST* body = parseExpression(penv, exp); // Parse input
+ body->constrain(cenv.tenv); // Constrain types
+ cenv.tenv.solve(); // Solve and apply type constraints
+ AType* bodyT = cenv.tenv.type(body);
if (!bodyT) throw TypeError("REPL call to untyped body");
if (bodyT->var) throw TypeError("REPL call to variable typed body");
@@ -900,7 +945,6 @@ main()
Function* f = compileFunction(cenv, cenv.gensym("_repl"), *prot, bodyT->ctype);
BasicBlock* bb = BasicBlock::Create("entry", f);
cenv.builder.SetInsertPoint(bb);
-
try {
Value* retVal = body->compile(cenv);
cenv.builder.CreateRet(retVal); // Finish function
@@ -910,22 +954,18 @@ main()
f->eraseFromParent(); // Error reading body, remove function
throw e;
}
-
void* fp = engine->getPointerToFunction(f);
if (bodyT->ctype == Type::Int32Ty)
- std::cout << " " << ((int32_t (*)())fp)();
+ std::cout << "; " << ((int32_t (*)())fp)();
else if (bodyT->ctype == Type::FloatTy)
- std::cout << " " << ((float (*)())fp)();
+ std::cout << "; " << ((float (*)())fp)();
else if (bodyT->ctype == Type::Int1Ty)
- std::cout << " " << ((bool (*)())fp)();
- else
- std::cout << " ?";
-
+ std::cout << "; " << ((bool (*)())fp)();
} else {
Value* val = body->compile(cenv);
- std::cout << " " << val;
+ std::cout << "; " << val;
}
- std::cout << " : " << cenv.tenv.type(body)->str() << ")" << endl;
+ std::cout << " : " << cenv.tenv.type(body)->str() << endl;
} catch (Error e) {
std::cerr << "Error: " << e.what() << endl;