From 6671a54283fc7d9323fc14c9feee525f01b2821d Mon Sep 17 00:00:00 2001 From: David Robillard Date: Mon, 26 Jan 2009 07:37:38 +0000 Subject: Somewhat functional type inference. git-svn-id: http://svn.drobilla.net/resp/llvm-lisp@16 ad02d1e2-f140-0410-9f75-f8b11f17cedd --- ll.cpp | 179 ++++++++++++++++++++++++++++++++++++++--------------------------- 1 file changed, 104 insertions(+), 75 deletions(-) (limited to 'll.cpp') diff --git a/ll.cpp b/ll.cpp index cc971a5..0dfa04e 100644 --- a/ll.cpp +++ b/ll.cpp @@ -18,6 +18,7 @@ * along with This program. If not, see . */ +#include #include #include #include @@ -176,7 +177,7 @@ struct ASTTuple : public AST { /// Type Expression ::= (TName TExpr*) | ?Num struct AType : public ASTTuple { AType(const vector& t) : ASTTuple(t), var(false), ctype(0) {} - AType(unsigned i) : var(true), ctype(0), id(id) {} + AType(unsigned i) : var(true), ctype(0), id(i) {} AType(ASTSymbol* n, const Type* t) : var(false), ctype(t) { tup.push_back(n); } @@ -189,6 +190,15 @@ struct AType : public ASTTuple { } void constrain(TEnv& tenv) const {} Value* compile(CEnv& cenv) { return NULL; } + bool concrete() const { + if (var) return false; + FOREACH(vector::const_iterator, t, tup) { + AType* kid = dynamic_cast(*t); + if (kid && !kid->concrete()) + return false; + } + return true; + } bool var; const Type* ctype; unsigned id; @@ -212,7 +222,7 @@ struct ASTLiteral : public AST { struct ASTClosure : public AST { ASTClosure(ASTTuple* p, AST* b) : prot(p), body(b), func(0) {} bool operator==(const AST& rhs) const { return this == &rhs; } - string str() const { return "(fn)"; } + string str() const { ostringstream s; s << this; return s.str(); } void constrain(TEnv& tenv) const; void lift(CEnv& cenv); Value* compile(CEnv& cenv); @@ -381,16 +391,19 @@ struct TypeError : public Error { TypeError (const char* m) : Error(m) {} }; /// Type-Time Environment struct TEnv { - TEnv(PEnv& p) : penv(p), varID(0) {} + TEnv(PEnv& p) : penv(p), varID(1) {} typedef map Types; typedef multimap Constraints; AType* var() { return new AType(varID++); } AType* type(const AST* ast) { Types::iterator t = types.find(ast); - if (t != types.end()) + if (t != types.end()) { return t->second; - else - return (types[ast] = var()); + } else { + AType* tvar = var(); + constrain(ast, tvar); + return tvar; + } } AType* named(const string& name) const { Types::const_iterator i = namedTypes.find(penv.sym(name)); @@ -404,8 +417,7 @@ struct TEnv { void constrain(const AST* ast, AType* type) { constraints.insert(make_pair(ast, type)); } - void unify(); - + AType* unify(AST* root); PEnv& penv; Types types; Types namedTypes; @@ -415,6 +427,18 @@ struct TEnv { #define OP_IS_A(o, t) ((o) >= t ## Begin && (o) < t ## End) +vector +tuple(AST* ast, ...) +{ + vector tup(1, ast); + va_list args; + va_start(args, ast); + for (AST* a = va_arg(args, AST*); a; a = va_arg(args, AST*)) + tup.push_back(a); + va_end(args); + return tup; +} + void ASTTuple::constrain(TEnv& tenv) const { @@ -423,7 +447,7 @@ ASTTuple::constrain(TEnv& tenv) const (*p)->constrain(tenv); AType* tvar = tenv.var(); texp.push_back(tvar); - tenv.constrain(tvar, tenv.type(*p)); + tenv.constrain(*p, tvar); } tenv.constrain(this, new AType(texp)); } @@ -433,31 +457,21 @@ ASTClosure::constrain(TEnv& tenv) const { prot->constrain(tenv); body->constrain(tenv); - vector texp(3); - texp[0] = tenv.penv.sym("Fn"); - texp[1] = prot; - texp[2] = body; - AType* tvar = tenv.var(); - tenv.constrain(texp[2], tvar); - tenv.constrain(this, tvar); + AType* bodyT = tenv.var(); + tenv.constrain(body, bodyT); + tenv.constrain(this, new AType(tuple( + tenv.penv.sym("Fn"), tenv.type(prot), bodyT, 0))); } void ASTCall::constrain(TEnv& tenv) const { - ASTTuple::constrain(tenv); -#if 0 - AST* callee = tup[0]; - ASTSymbol* sym = dynamic_cast(tup[0]); - if (sym) { - AST** val = tenv.code.ref(sym); - if (val) - callee = *val; - } - ASTClosure* c = dynamic_cast(callee); - if (!c) throw TypeError("Call to non-closure"); - tenv.contraints[this] = c->body->type(tenv); -#endif + FOREACH(vector::const_iterator, p, tup) + (*p)->constrain(tenv); + AType* retT = tenv.var(); + vector texp = tuple(tenv.penv.sym("Fn"), tenv.var(), retT, NULL); + tenv.constrain(new AType(texp), tenv.var()); + tenv.constrain(this, retT); } void @@ -503,52 +517,71 @@ ASTPrimitive::constrain(TEnv& tenv) const } } -void -TEnv::unify() +static bool +substitute(ASTTuple* tup, AST* from, AST* to) { - typedef map Substitutions; - bool progress = false; - do { - progress = false; - //std::cout << "========" << endl; - Substitutions subst; + for (size_t i = 0; i < tup->tup.size(); ++i) { + if (*tup->tup[i] == *from) { + tup->tup[i] = to; + progress = true; + } + } + return progress; +} + +AType* +TEnv::unify(AST* root) +{ + root->constrain(*this); + assert(constraints.find(root) != constraints.end()); + //constrain(root, var()); + typedef map Substitutions; + Substitutions subst; + for (bool progress = true; progress; progress = false) { + //std::cout << "==== " << constraints.size() << endl; for (Constraints::iterator c = constraints.begin(); c != constraints.end();) { Constraints::iterator next = c; ++next; const AST* o = c->first; AType* t = c->second; - //std::cout << "Constraint: " << o->str() << " = " << t->str() << endl; - if (t->var) { - Types::iterator ot = types.find(o); - if (ot != types.end()) - subst[t] = ot->second; - } else { + //std::cout << "Constr : " << o->str() << " = " << t->str() << endl; + if (t->concrete()) { Types::iterator ot = types.find(o); if (ot == types.end()) { - //std::cout << "Resolve: " << o->str() << endl; + //std::cout << "Resolv : " << o->str() << endl; types.insert(make_pair(o, t)); - constraints.erase(c); + //constraints.erase(c); + progress = true; + } + } else { + Types::iterator ot = types.find(o); + if (ot != types.end()) { + subst[t] = ot->second; + progress = true; } } c = next; } - for (Substitutions::iterator s = subst.begin(); s != subst.end(); ++s) { + //std::cout << "Subst : " << s->first->str() << " => " << s->second->str() << endl; for (Constraints::iterator c = constraints.begin(); c != constraints.end(); ++c) { - if (c->second == s->first) { - //std::cout << c->second->str() << " => " << s->second->str() << endl; + AType* objT = c->second; + if (objT == s->first && objT != s->second) { c->second = s->second; progress = true; } + progress = substitute(c->second, s->first, s->second) || progress; } } - - } while (progress); - - //std::cout << "======== Done unification" << endl; - - constraints.clear(); + } + //std::cout << "======== Done Unifying Types " << constraints.size() << endl; + Constraints::iterator i = constraints.find(root); + if (i != constraints.end()) { + types[root] = i->second; + return i->second; + } + return NULL; } @@ -594,16 +627,16 @@ struct CEnv { Vals vals; }; -#define LITERAL(CT, VT, NAME, COMPILED) \ +#define LITERAL(CT, NAME, COMPILED) \ template<> Value* \ ASTLiteral::compile(CEnv& cenv) { return (COMPILED); } \ template<> void \ ASTLiteral::constrain(TEnv& tenv) const { tenv.constrain(this, tenv.named(NAME)); } /// Literal template instantiations -LITERAL(int32_t, Type::Int32Ty, "Int", ConstantInt::get(Type::Int32Ty, val, true)); -LITERAL(float, Type::FloatTy, "Float", ConstantFP::get(Type::FloatTy, val)); -LITERAL(bool, Type::Int1Ty, "Bool", ConstantInt::get(Type::Int1Ty, val, false)); +LITERAL(int32_t, "Int", ConstantInt::get(Type::Int32Ty, val, true)); +LITERAL(float, "Float", ConstantFP::get(Type::FloatTy, val)); +LITERAL(bool, "Bool", ConstantInt::get(Type::Int1Ty, val, false)); static Function* compileFunction(CEnv& cenv, const std::string& name, ASTTuple& prot, const Type* retT) @@ -893,29 +926,25 @@ main() cenv.tenv.name("Float", Type::FloatTy); while (1) { - std::cout << "(=>) "; + std::cout << "(= "; std::cout.flush(); SExp exp = readExpression(std::cin); if (exp.type == SExp::LIST && exp.list.empty()) break; try { - AST* body = parseExpression(penv, exp); - - body->constrain(cenv.tenv); - cenv.tenv.unify(); - - ASTTuple* prot = new ASTTuple(); - AType* bodyT = cenv.tenv.type(body); + AST* body = parseExpression(penv, exp); + AType* bodyT = cenv.tenv.unify(body); if (!bodyT) throw TypeError("REPL call to untyped body"); if (bodyT->var) throw TypeError("REPL call to variable typed body"); body->lift(cenv); - + if (bodyT->ctype) { // Create anonymous function to insert code into. - Function* f = compileFunction(cenv, cenv.gensym("_repl"), *prot, bodyT->ctype); + ASTTuple* prot = new ASTTuple(); + Function* f = compileFunction(cenv, cenv.gensym("_repl"), *prot, bodyT->ctype); BasicBlock* bb = BasicBlock::Create("entry", f); cenv.builder.SetInsertPoint(bb); @@ -931,26 +960,26 @@ main() void* fp = engine->getPointerToFunction(f); if (bodyT->ctype == Type::Int32Ty) - std::cout << ((int32_t (*)())fp)(); + std::cout << " " <<((int32_t (*)())fp)(); else if (bodyT->ctype == Type::FloatTy) - std::cout << ((float (*)())fp)(); + std::cout << " " <<((float (*)())fp)(); else if (bodyT->ctype == Type::Int1Ty) - std::cout << ((bool (*)())fp)(); + std::cout << " " << ((bool (*)())fp)(); else - std::cout << "?"; + std::cout << " ?"; } else { Value* val = body->compile(cenv); - std::cout << val; + std::cout << " " << val; } - std::cout << " : " << cenv.tenv.type(body)->str() << endl; + std::cout << " : " << cenv.tenv.type(body)->str() << ")" << endl; } catch (Error e) { std::cerr << "Error: " << e.what() << endl; } } - std::cout << "Generated code:" << endl; + std::cout << endl << "Generated code:" << endl; module->dump(); return 0; } -- cgit v1.2.1