aboutsummaryrefslogtreecommitdiffstats
path: root/tuplr.cpp
diff options
context:
space:
mode:
authorDavid Robillard <d@drobilla.net>2009-03-05 18:14:43 +0000
committerDavid Robillard <d@drobilla.net>2009-03-05 18:14:43 +0000
commita61a8571f2f1490b277d3c4235c7050ccf21cdc3 (patch)
tree2d61d99c0a9f4847fe4c72e4cdfaceeee96bde13 /tuplr.cpp
parent51769d4de84e0122b4dc388592c8b069e218eabf (diff)
downloadresp-a61a8571f2f1490b277d3c4235c7050ccf21cdc3.tar.gz
resp-a61a8571f2f1490b277d3c4235c7050ccf21cdc3.tar.bz2
resp-a61a8571f2f1490b277d3c4235c7050ccf21cdc3.zip
Proper type inferencing for functions (and more generic) (type the identity function polymorphically correctly).
git-svn-id: http://svn.drobilla.net/resp/tuplr@44 ad02d1e2-f140-0410-9f75-f8b11f17cedd
Diffstat (limited to 'tuplr.cpp')
-rw-r--r--tuplr.cpp72
1 files changed, 42 insertions, 30 deletions
diff --git a/tuplr.cpp b/tuplr.cpp
index ff280df..94bd2a5 100644
--- a/tuplr.cpp
+++ b/tuplr.cpp
@@ -275,16 +275,16 @@ struct Funcs : public list< pair<AType*, Function*> > {
/// Closure (first-class function with captured lexical bindings)
struct ASTClosure : public ASTTuple {
ASTClosure(ASTTuple* p, AST* b, const string& n="")
- : ASTTuple(0, p, b), prot(p), func(0), name(n) {}
+ : ASTTuple(0, p, b), name(n) {}
bool operator==(const AST& rhs) const { return this == &rhs; }
string str() const { return (format("%1%") % this).str(); }
void constrain(TEnv& tenv) const;
void lift(CEnv& cenv);
Value* compile(CEnv& cenv);
- ASTTuple* const prot;
+ ASTTuple* prot() const { return dynamic_cast<ASTTuple*>(at(1)); }
private:
- Function* func;
- string name;
+ Funcs funcs;
+ string name;
};
/// Function call/application, e.g. "(func arg1 arg2)"
@@ -498,6 +498,7 @@ struct TEnv {
namedTypes[sym] = new AType(penv.sym(name), type);
}
void constrain(const AST* o, AType* t) {
+ assert(!dynamic_cast<const AType*>(o));
constraints.push_back(make_pair(type(o), t));
}
void solve() { apply(unify(constraints)); }
@@ -520,17 +521,18 @@ ASTTuple::constrain(TEnv& tenv) const
(*p)->constrain(tenv);
t->push_back(tenv.type(*p));
}
- tenv.constrain(tenv.type(this), t);
+ tenv.constrain(this, t);
}
void
ASTClosure::constrain(TEnv& tenv) const
{
- prot->constrain(tenv);
+ at(1)->constrain(tenv);
at(2)->constrain(tenv);
+ AType* protT = tenv.type(at(1));
AType* bodyT = tenv.type(at(2));
- tenv.constrain(this, new AType(
- ASTTuple(tenv.penv.sym("Fn"), tenv.type(prot), bodyT, 0)));
+ tenv.constrain(this, new AType(ASTTuple(
+ tenv.penv.sym("Fn"), protT, bodyT, 0)));
}
void
@@ -539,8 +541,11 @@ ASTCall::constrain(TEnv& tenv) const
FOREACH(const_iterator, p, *this)
(*p)->constrain(tenv);
AType* retT = tenv.type(this);
+ AType* argsT = new AType(ASTTuple());
+ for (size_t i = 1; i < size(); ++i)
+ argsT->push_back(tenv.type(at(i)));
tenv.constrain(at(0), new AType(ASTTuple(
- tenv.penv.sym("Fn"), tenv.var(), retT, NULL)));
+ tenv.penv.sym("Fn"), argsT, retT, NULL)));
}
void
@@ -686,15 +691,13 @@ TEnv::unify(const Constraints& constraints) // TAPL 22.4
} else if (t->var && !s->contains(t)) {
substConstraints(cp, t, s);
return compose(unify(cp), TSubst(t, s));
- } else if ((s->isForm("Fn") && t->isForm("Fn"))
- || (s->isForm("Pair") && t->isForm("Pair"))) {
- AType* s1 = dynamic_cast<AType*>(s->at(1));
- AType* t1 = dynamic_cast<AType*>(t->at(1));
- AType* s2 = dynamic_cast<AType*>(s->at(2));
- AType* t2 = dynamic_cast<AType*>(t->at(2));
- assert(s1 && t1 && s2 && t2);
- cp.push_back(make_pair(s1, t1));
- cp.push_back(make_pair(s2, t2));
+ } else if (s->size() == t->size()) {
+ for (size_t i = 0; i < s->size(); ++i) {
+ AType* si = dynamic_cast<AType*>(s->at(i));
+ AType* ti = dynamic_cast<AType*>(t->at(i));
+ if (si && ti)
+ cp.push_back(make_pair(si, ti));
+ }
return unify(cp);
} else {
throw Error("Type unification failed");
@@ -708,6 +711,8 @@ TEnv::apply(const TSubst& substs)
FOREACH(Types::iterator, t, types)
if (*t->second == *s->first)
t->second = s->second;
+ else
+ substitute(t->second, s->first, s->second);
}
@@ -804,29 +809,33 @@ Value*
ASTSymbol::compile(CEnv& cenv)
{
AST** c = cenv.code.ref(this);
- if (!c) throw Error((string("Undefined symbol: ") + cppstr).c_str());
+ if (!c) throw Error((string("undefined symbol `") + cppstr + "'").c_str());
return cenv.compile(*c);
}
void
ASTClosure::lift(CEnv& cenv)
{
- if (cenv.tenv.type(at(2))->var || !cenv.tenv.type(prot)->concrete()) {
- std::cerr << "Closure has variable type, not lifting" << endl;
+ AType* type = cenv.tenv.type(this);
+ if (!type->concrete()) {
+ std::cerr << "closure is untyped, not lifting" << endl;
return;
}
- assert(!func);
+
+ if (funcs.find(type))
+ return;
+
cenv.push();
// Write function declaration
string name = this->name == "" ? cenv.gensym("_fn") : this->name;
- Function* f = compileFunction(cenv, name, cenv.tenv.type(at(2))->type(), *prot);
+ Function* f = compileFunction(cenv, name, cenv.tenv.type(at(2))->type(), *prot());
BasicBlock* bb = BasicBlock::Create("entry", f);
cenv.builder.SetInsertPoint(bb);
// Bind argument values in CEnv
vector<Value*> args;
- const_iterator p = prot->begin();
+ const_iterator p = prot()->begin();
for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p)
cenv.vals.def(dynamic_cast<ASTSymbol*>(*p), &*a);
@@ -836,20 +845,19 @@ ASTClosure::lift(CEnv& cenv)
Value* retVal = cenv.compile(at(2));
cenv.builder.CreateRet(retVal); // Finish function
cenv.optimise(*f);
- func = f;
+ funcs.insert(type, f);
} catch (Error& e) {
f->eraseFromParent(); // Error reading body, remove function
throw e;
}
- assert(func);
cenv.pop();
}
Value*
ASTClosure::compile(CEnv& cenv)
{
- return func;
+ return funcs.find(cenv.tenv.type(this));
}
void
@@ -869,13 +877,13 @@ ASTCall::lift(CEnv& cenv)
// Extend environment with bound and typed parameters
cenv.push();
- if (c->prot->size() < size() - 1)
+ if (c->prot()->size() < size() - 1)
throw Error((format("too many arguments to function `%1%'") % at(0)->str()).str(), exp.loc);
- if (c->prot->size() > size() - 1)
+ if (c->prot()->size() > size() - 1)
throw Error((format("too few arguments to function `%1%'") % at(0)->str()).str(), exp.loc);
for (size_t i = 1; i < size(); ++i)
- cenv.code.def(c->prot->at(i-1), at(i));
+ cenv.code.def(c->prot()->at(i-1), at(i));
c->lift(cenv); // Lift called closure
cenv.pop(); // Restore environment
@@ -1136,6 +1144,10 @@ eval(CEnv& cenv, ExecutionEngine* engine, const string& name, istream& is)
// Create function for top-level of program
ASTTuple prot;
const Type* ctype = resultType->type();
+ if (!ctype) {
+ std::cerr << "program body has non-compilable type" << endl;
+ return 2;
+ }
assert(ctype);
Function* f = compileFunction(cenv, cenv.gensym("input"), ctype, prot);
BasicBlock* bb = BasicBlock::Create("entry", f);