From a61a8571f2f1490b277d3c4235c7050ccf21cdc3 Mon Sep 17 00:00:00 2001 From: David Robillard Date: Thu, 5 Mar 2009 18:14:43 +0000 Subject: Proper type inferencing for functions (and more generic) (type the identity function polymorphically correctly). git-svn-id: http://svn.drobilla.net/resp/tuplr@44 ad02d1e2-f140-0410-9f75-f8b11f17cedd --- tuplr.cpp | 72 +++++++++++++++++++++++++++++++++++++-------------------------- 1 file changed, 42 insertions(+), 30 deletions(-) (limited to 'tuplr.cpp') diff --git a/tuplr.cpp b/tuplr.cpp index ff280df..94bd2a5 100644 --- a/tuplr.cpp +++ b/tuplr.cpp @@ -275,16 +275,16 @@ struct Funcs : public list< pair > { /// Closure (first-class function with captured lexical bindings) struct ASTClosure : public ASTTuple { ASTClosure(ASTTuple* p, AST* b, const string& n="") - : ASTTuple(0, p, b), prot(p), func(0), name(n) {} + : ASTTuple(0, p, b), name(n) {} bool operator==(const AST& rhs) const { return this == &rhs; } string str() const { return (format("%1%") % this).str(); } void constrain(TEnv& tenv) const; void lift(CEnv& cenv); Value* compile(CEnv& cenv); - ASTTuple* const prot; + ASTTuple* prot() const { return dynamic_cast(at(1)); } private: - Function* func; - string name; + Funcs funcs; + string name; }; /// Function call/application, e.g. "(func arg1 arg2)" @@ -498,6 +498,7 @@ struct TEnv { namedTypes[sym] = new AType(penv.sym(name), type); } void constrain(const AST* o, AType* t) { + assert(!dynamic_cast(o)); constraints.push_back(make_pair(type(o), t)); } void solve() { apply(unify(constraints)); } @@ -520,17 +521,18 @@ ASTTuple::constrain(TEnv& tenv) const (*p)->constrain(tenv); t->push_back(tenv.type(*p)); } - tenv.constrain(tenv.type(this), t); + tenv.constrain(this, t); } void ASTClosure::constrain(TEnv& tenv) const { - prot->constrain(tenv); + at(1)->constrain(tenv); at(2)->constrain(tenv); + AType* protT = tenv.type(at(1)); AType* bodyT = tenv.type(at(2)); - tenv.constrain(this, new AType( - ASTTuple(tenv.penv.sym("Fn"), tenv.type(prot), bodyT, 0))); + tenv.constrain(this, new AType(ASTTuple( + tenv.penv.sym("Fn"), protT, bodyT, 0))); } void @@ -539,8 +541,11 @@ ASTCall::constrain(TEnv& tenv) const FOREACH(const_iterator, p, *this) (*p)->constrain(tenv); AType* retT = tenv.type(this); + AType* argsT = new AType(ASTTuple()); + for (size_t i = 1; i < size(); ++i) + argsT->push_back(tenv.type(at(i))); tenv.constrain(at(0), new AType(ASTTuple( - tenv.penv.sym("Fn"), tenv.var(), retT, NULL))); + tenv.penv.sym("Fn"), argsT, retT, NULL))); } void @@ -686,15 +691,13 @@ TEnv::unify(const Constraints& constraints) // TAPL 22.4 } else if (t->var && !s->contains(t)) { substConstraints(cp, t, s); return compose(unify(cp), TSubst(t, s)); - } else if ((s->isForm("Fn") && t->isForm("Fn")) - || (s->isForm("Pair") && t->isForm("Pair"))) { - AType* s1 = dynamic_cast(s->at(1)); - AType* t1 = dynamic_cast(t->at(1)); - AType* s2 = dynamic_cast(s->at(2)); - AType* t2 = dynamic_cast(t->at(2)); - assert(s1 && t1 && s2 && t2); - cp.push_back(make_pair(s1, t1)); - cp.push_back(make_pair(s2, t2)); + } else if (s->size() == t->size()) { + for (size_t i = 0; i < s->size(); ++i) { + AType* si = dynamic_cast(s->at(i)); + AType* ti = dynamic_cast(t->at(i)); + if (si && ti) + cp.push_back(make_pair(si, ti)); + } return unify(cp); } else { throw Error("Type unification failed"); @@ -708,6 +711,8 @@ TEnv::apply(const TSubst& substs) FOREACH(Types::iterator, t, types) if (*t->second == *s->first) t->second = s->second; + else + substitute(t->second, s->first, s->second); } @@ -804,29 +809,33 @@ Value* ASTSymbol::compile(CEnv& cenv) { AST** c = cenv.code.ref(this); - if (!c) throw Error((string("Undefined symbol: ") + cppstr).c_str()); + if (!c) throw Error((string("undefined symbol `") + cppstr + "'").c_str()); return cenv.compile(*c); } void ASTClosure::lift(CEnv& cenv) { - if (cenv.tenv.type(at(2))->var || !cenv.tenv.type(prot)->concrete()) { - std::cerr << "Closure has variable type, not lifting" << endl; + AType* type = cenv.tenv.type(this); + if (!type->concrete()) { + std::cerr << "closure is untyped, not lifting" << endl; return; } - assert(!func); + + if (funcs.find(type)) + return; + cenv.push(); // Write function declaration string name = this->name == "" ? cenv.gensym("_fn") : this->name; - Function* f = compileFunction(cenv, name, cenv.tenv.type(at(2))->type(), *prot); + Function* f = compileFunction(cenv, name, cenv.tenv.type(at(2))->type(), *prot()); BasicBlock* bb = BasicBlock::Create("entry", f); cenv.builder.SetInsertPoint(bb); // Bind argument values in CEnv vector args; - const_iterator p = prot->begin(); + const_iterator p = prot()->begin(); for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p) cenv.vals.def(dynamic_cast(*p), &*a); @@ -836,20 +845,19 @@ ASTClosure::lift(CEnv& cenv) Value* retVal = cenv.compile(at(2)); cenv.builder.CreateRet(retVal); // Finish function cenv.optimise(*f); - func = f; + funcs.insert(type, f); } catch (Error& e) { f->eraseFromParent(); // Error reading body, remove function throw e; } - assert(func); cenv.pop(); } Value* ASTClosure::compile(CEnv& cenv) { - return func; + return funcs.find(cenv.tenv.type(this)); } void @@ -869,13 +877,13 @@ ASTCall::lift(CEnv& cenv) // Extend environment with bound and typed parameters cenv.push(); - if (c->prot->size() < size() - 1) + if (c->prot()->size() < size() - 1) throw Error((format("too many arguments to function `%1%'") % at(0)->str()).str(), exp.loc); - if (c->prot->size() > size() - 1) + if (c->prot()->size() > size() - 1) throw Error((format("too few arguments to function `%1%'") % at(0)->str()).str(), exp.loc); for (size_t i = 1; i < size(); ++i) - cenv.code.def(c->prot->at(i-1), at(i)); + cenv.code.def(c->prot()->at(i-1), at(i)); c->lift(cenv); // Lift called closure cenv.pop(); // Restore environment @@ -1136,6 +1144,10 @@ eval(CEnv& cenv, ExecutionEngine* engine, const string& name, istream& is) // Create function for top-level of program ASTTuple prot; const Type* ctype = resultType->type(); + if (!ctype) { + std::cerr << "program body has non-compilable type" << endl; + return 2; + } assert(ctype); Function* f = compileFunction(cenv, cenv.gensym("input"), ctype, prot); BasicBlock* bb = BasicBlock::Create("entry", f); -- cgit v1.2.1