diff options
author | David Robillard <d@drobilla.net> | 2009-01-27 04:17:22 +0000 |
---|---|---|
committer | David Robillard <d@drobilla.net> | 2009-01-27 04:17:22 +0000 |
commit | 73c2efb353b45e29df71268668e4ad1c8c2fa93c (patch) | |
tree | 295dc2ed834965c85f6a00994d61b8343f0ac6e1 | |
parent | beb36974ecaafa30296b90f7f61c24235450c2f3 (diff) | |
download | resp-73c2efb353b45e29df71268668e4ad1c8c2fa93c.tar.gz resp-73c2efb353b45e29df71268668e4ad1c8c2fa93c.tar.bz2 resp-73c2efb353b45e29df71268668e4ad1c8c2fa93c.zip |
Real type inference (classic Hindley-Milner straight from TAPL).
git-svn-id: http://svn.drobilla.net/resp/llvm-lisp@24 ad02d1e2-f140-0410-9f75-f8b11f17cedd
-rw-r--r-- | ll.cpp | 246 |
1 files changed, 143 insertions, 103 deletions
@@ -69,22 +69,20 @@ typedef Exp<string> SExp; static SExp readExpression(std::istream& in) { +#define PUSH(s, t) { if (t != "") { s.top().list.push_back(t); t = ""; } } +#define YIELD(s, t) { if (s.empty()) return t; else PUSH(s, t) } stack<SExp> stk; string tok; - -#define PUSH(t) { if (t != "") { stk.top().list.push_back(t); t = ""; } } -#define YIELD(t) { if (stk.empty()) return t; else PUSH(t) } - while (char ch = in.get()) { switch (ch) { case EOF: return SExp(); case ' ': case '\t': case '\n': - if (tok != "") YIELD(tok); + if (tok != "") YIELD(stk, tok); break; case '"': do { tok.push_back(ch); } while ((ch = in.get()) != '"'); - YIELD(tok + '"'); + YIELD(stk, tok + '"'); break; case '(': stk.push(SExp()); @@ -94,10 +92,10 @@ readExpression(std::istream& in) case 0: throw SyntaxError("Unexpected ')'"); case 1: - PUSH(tok); + PUSH(stk, tok); return stk.top(); default: - PUSH(tok); + PUSH(stk, tok); SExp l = stk.top(); stk.pop(); stk.top().list.push_back(l); @@ -108,9 +106,9 @@ readExpression(std::istream& in) } } switch (stk.size()) { - case 0: return tok; break; - case 1: return stk.top(); break; - default: throw SyntaxError("Missing ')'"); + case 0: return tok; + case 1: return stk.top(); + default: throw SyntaxError("Missing ')'"); } return SExp(); } @@ -127,11 +125,13 @@ struct AType; ///< Abstract Type /// Base class for all AST nodes struct AST { virtual ~AST() {} - virtual bool operator==(const AST& rhs) const = 0; - virtual string str() const = 0; - virtual void constrain(TEnv& tenv) const {} - virtual void lift(CEnv& cenv) {} - virtual Value* compile(CEnv& cenv) = 0; + virtual bool contains(AST* child) const { return false; } + virtual bool operator!=(const AST& o) const { return !operator==(o); } + virtual bool operator==(const AST& o) const = 0; + virtual string str() const = 0; + virtual void constrain(TEnv& tenv) const {} + virtual void lift(CEnv& cenv) {} + virtual Value* compile(CEnv& cenv) = 0; }; /// Literal @@ -170,17 +170,24 @@ struct ASTTuple : public AST { return ret + ")"; } bool operator==(const AST& rhs) const { + const ASTTuple* rt = dynamic_cast<const ASTTuple*>(&rhs); + if (!rt) return false; + if (rt->tup.size() != tup.size()) return false; TupV::const_iterator l = tup.begin(); - FOREACH(TupV::const_iterator, r, tup) - if ((*l++) != (*r)) + FOREACH(TupV::const_iterator, r, rt->tup) { + AST* mine = *l++; + AST* other = *r; + if (!(*mine == *other)) return false; + } return true; } - bool operator!=(const ASTTuple& t) const { return ! operator==(t); } void lift(CEnv& cenv) { FOREACH(TupV::iterator, t, tup) (*t)->lift(cenv); } + bool isForm(const string& f) { return !tup.empty() && tup[0]->str() == f; } + bool contains(AST* child) const; void constrain(TEnv& tenv) const; Value* compile(CEnv& cenv) { return NULL; } TupV tup; @@ -223,6 +230,16 @@ struct AType : public ASTTuple { } return true; } + bool operator==(const AST& rhs) const { + const AType* rt = dynamic_cast<const AType*>(&rhs); + if (!rt) + return false; + else if (var && rt->var) + return id == rt->id; + else if (!var && !rt->var) + return ASTTuple::operator==(rhs); + return false; + } bool var; const Type* ctype; unsigned id; @@ -390,21 +407,19 @@ struct Env : public list< map<K,V> > { struct TypeError : public Error { TypeError (const char* m) : Error(m) {} }; +struct TSubst : public map<AType*, AType*> { + TSubst(AType* s=0, AType* t=0) { if (s && t) insert(make_pair(s, t)); } +}; + /// Type-Time Environment struct TEnv { TEnv(PEnv& p) : penv(p), varID(1) {} - typedef map<const AST*, AType*> Types; - typedef multimap<const AST*, AType*> Constraints; + typedef map<const AST*, AType*> Types; + typedef list< pair<AType*, AType*> > Constraints; AType* var() { return new AType(varID++); } AType* type(const AST* ast) { Types::iterator t = types.find(ast); - if (t != types.end()) { - return t->second; - } else { - AType* tvar = var(); - constrain(ast, tvar); - return tvar; - } + return (t != types.end()) ? t->second : (types[ast] = var()); } AType* named(const string& name) const { Types::const_iterator i = namedTypes.find(penv.sym(name)); @@ -415,10 +430,12 @@ struct TEnv { ASTSymbol* sym = penv.sym(name); namedTypes[sym] = new AType(penv.sym(name), type); } - void constrain(const AST* ast, AType* type) { - constraints.insert(make_pair(ast, type)); + void constrain(const AST* o, AType* t) { + constraints.push_back(make_pair(type(o), t)); } - AType* unify(AST* root); + void solve() { apply(unify(constraints)); } + void apply(const TSubst& substs); + static TSubst unify(const Constraints& c); PEnv& penv; Types types; Types namedTypes; @@ -436,7 +453,9 @@ ASTTuple::constrain(TEnv& tenv) const (*p)->constrain(tenv); texp.push_back(tenv.type(*p)); } - tenv.constrain(this, new AType(texp)); + AType* t = tenv.type(this); + t->var = false; + t->tup = texp; } void @@ -499,71 +518,96 @@ ASTPrimitive::constrain(TEnv& tenv) const } } -static bool +static void substitute(ASTTuple* tup, AST* from, AST* to) { - bool progress = false; - for (size_t i = 0; i < tup->tup.size(); ++i) { - if (*tup->tup[i] == *from) { + if (!tup) return; + for (size_t i = 0; i < tup->tup.size(); ++i) + if (*tup->tup[i] == *from) tup->tup[i] = to; - progress = true; - } + else + substitute(dynamic_cast<ASTTuple*>(tup->tup[i]), from, to); +} + +bool +ASTTuple::contains(AST* child) const +{ + if (*this == *child) return true; + FOREACH(TupV::const_iterator, p, tup) + if (**p == *child || (*p)->contains(child)) + return true; + return false; +} + +TSubst +compose(const TSubst& delta, const TSubst& gamma) // TAPL 22.1.1 +{ + TSubst r; + for (TSubst::const_iterator g = gamma.begin(); g != gamma.end(); ++g) { + TSubst::const_iterator i = delta.find(g->second); + if (i != delta.end()) + r.insert(make_pair(g->first, ((i != delta.end()) ? i : g)->second)); + else + r.insert(make_pair(g->first, g->second)); + } + for (TSubst::const_iterator d = delta.begin(); d != delta.end(); ++d) { + if (gamma.find(d->first) == gamma.end()) + r.insert(make_pair(d->first, d->second)); } - return progress; + return r; } -AType* -TEnv::unify(AST* root) +void +substConstraints(TEnv::Constraints& constraints, AType* s, AType* t) { - root->constrain(*this); - assert(constraints.find(root) != constraints.end()); - //constrain(root, var()); - typedef map<AType*, AType*> Substitutions; - Substitutions subst; - for (bool progress = true; progress; progress = false) { - //std::cout << "==== " << constraints.size() << endl; - for (Constraints::iterator c = constraints.begin(); c != constraints.end();) { - Constraints::iterator next = c; - ++next; - const AST* o = c->first; - AType* t = c->second; - //std::cout << "Constr : " << o->str() << " = " << t->str() << endl; - if (t->concrete()) { - Types::iterator ot = types.find(o); - if (ot == types.end()) { - //std::cout << "Resolv : " << o->str() << endl; - types.insert(make_pair(o, t)); - //constraints.erase(c); - progress = true; - } - } else { - Types::iterator ot = types.find(o); - if (ot != types.end()) { - subst[t] = ot->second; - progress = true; - } - } - c = next; - } - for (Substitutions::iterator s = subst.begin(); s != subst.end(); ++s) { - //std::cout << "Subst : " << s->first->str() << " => " << s->second->str() << endl; - for (Constraints::iterator c = constraints.begin(); c != constraints.end(); ++c) { - AType* objT = c->second; - if (objT == s->first && objT != s->second) { - c->second = s->second; - progress = true; - } - progress = substitute(c->second, s->first, s->second) || progress; - } - } + for (TEnv::Constraints::iterator c = constraints.begin(); c != constraints.end();) { + TEnv::Constraints::iterator next = c; ++next; + if (*c->first == *s) c->first = t; + if (*c->second == *s) c->second = t; + substitute(c->first, s, t); + substitute(c->second, s, t); + c = next; } - //std::cout << "======== Done Unifying Types " << constraints.size() << endl; - Constraints::iterator i = constraints.find(root); - if (i != constraints.end()) { - types[root] = i->second; - return i->second; +} + +TSubst +TEnv::unify(const Constraints& constraints) // TAPL 22.4 +{ + if (constraints.empty()) return TSubst(); + AType* s = constraints.begin()->first; + AType* t = constraints.begin()->second; + Constraints cp = constraints; + cp.erase(cp.begin()); + + if (*s == *t) { + return unify(cp); + } else if (s->var && !t->contains(s)) { + substConstraints(cp, s, t); + return compose(unify(cp), TSubst(s, t)); + } else if (t->var && !s->contains(t)) { + substConstraints(cp, t, s); + return compose(unify(cp), TSubst(t, s)); + } else if (s->isForm("Fn") && t->isForm("Fn")) { + AType* s1 = dynamic_cast<AType*>(s->tup[1]); + AType* t1 = dynamic_cast<AType*>(t->tup[1]); + AType* s2 = dynamic_cast<AType*>(s->tup[2]); + AType* t2 = dynamic_cast<AType*>(t->tup[2]); + assert(s1 && t1 && s2 && t2); + cp.push_back(make_pair(s1, t1)); + cp.push_back(make_pair(s2, t2)); + return unify(cp); + } else { + throw TypeError("Type unification failed"); } - return NULL; +} + +void +TEnv::apply(const TSubst& substs) +{ + FOREACH(TSubst::const_iterator, s, substs) + FOREACH(Types::iterator, t, types) + if (*t->second == *s->first) + t->second = s->second; } @@ -648,8 +692,7 @@ Value* ASTSymbol::compile(CEnv& cenv) { Value** v = cenv.vals.ref(this); - if (v) - return *v; + if (v) return *v; AST** c = cenv.code.ref(this); if (c) { @@ -879,16 +922,18 @@ main() cenv.code.def(penv.sym("false"), new ASTLiteral<bool>(false)); while (1) { - std::cout << "(= "; + std::cout << "() "; std::cout.flush(); SExp exp = readExpression(std::cin); if (exp.type == SExp::LIST && exp.list.empty()) break; try { - AST* body = parseExpression(penv, exp); - AType* bodyT = cenv.tenv.unify(body); + AST* body = parseExpression(penv, exp); // Parse input + body->constrain(cenv.tenv); // Constrain types + cenv.tenv.solve(); // Solve and apply type constraints + AType* bodyT = cenv.tenv.type(body); if (!bodyT) throw TypeError("REPL call to untyped body"); if (bodyT->var) throw TypeError("REPL call to variable typed body"); @@ -900,7 +945,6 @@ main() Function* f = compileFunction(cenv, cenv.gensym("_repl"), *prot, bodyT->ctype); BasicBlock* bb = BasicBlock::Create("entry", f); cenv.builder.SetInsertPoint(bb); - try { Value* retVal = body->compile(cenv); cenv.builder.CreateRet(retVal); // Finish function @@ -910,22 +954,18 @@ main() f->eraseFromParent(); // Error reading body, remove function throw e; } - void* fp = engine->getPointerToFunction(f); if (bodyT->ctype == Type::Int32Ty) - std::cout << " " << ((int32_t (*)())fp)(); + std::cout << "; " << ((int32_t (*)())fp)(); else if (bodyT->ctype == Type::FloatTy) - std::cout << " " << ((float (*)())fp)(); + std::cout << "; " << ((float (*)())fp)(); else if (bodyT->ctype == Type::Int1Ty) - std::cout << " " << ((bool (*)())fp)(); - else - std::cout << " ?"; - + std::cout << "; " << ((bool (*)())fp)(); } else { Value* val = body->compile(cenv); - std::cout << " " << val; + std::cout << "; " << val; } - std::cout << " : " << cenv.tenv.type(body)->str() << ")" << endl; + std::cout << " : " << cenv.tenv.type(body)->str() << endl; } catch (Error e) { std::cerr << "Error: " << e.what() << endl; |