aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDavid Robillard <d@drobilla.net>2010-04-13 02:28:56 +0000
committerDavid Robillard <d@drobilla.net>2010-04-13 02:28:56 +0000
commit8675beae4f7a8415fc2e88451da95dc068719194 (patch)
tree599de9b6730a14035a25f7d9e0467f96866185ed
parent1f988f420ba3827941886962680f3e2ad6f01740 (diff)
downloadresp-8675beae4f7a8415fc2e88451da95dc068719194.tar.gz
resp-8675beae4f7a8415fc2e88451da95dc068719194.tar.bz2
resp-8675beae4f7a8415fc2e88451da95dc068719194.zip
Restructure as a source translation based compiler.
Implement support for closures (via lambda lifting phase). git-svn-id: http://svn.drobilla.net/resp/resp@254 ad02d1e2-f140-0410-9f75-f8b11f17cedd
-rw-r--r--Makefile3
-rw-r--r--src/c.cpp64
-rw-r--r--src/compile.cpp30
-rw-r--r--src/constrain.cpp33
-rw-r--r--src/lift.cpp188
-rw-r--r--src/llvm.cpp88
-rw-r--r--src/repl.cpp109
-rw-r--r--src/resp.cpp32
-rw-r--r--src/resp.hpp245
-rw-r--r--src/unify.cpp49
-rwxr-xr-xtest.sh16
-rw-r--r--test/closure.resp5
-rw-r--r--test/def.resp4
-rw-r--r--test/deffn.resp3
-rw-r--r--test/inlinefn.resp1
-rw-r--r--test/nest.resp2
16 files changed, 568 insertions, 304 deletions
diff --git a/Makefile b/Makefile
index ad101f4..bdec095 100644
--- a/Makefile
+++ b/Makefile
@@ -26,14 +26,15 @@ OBJECTS = \
build/compile.o \
build/constrain.o \
build/cps.o \
+ build/depoly.o \
build/gc.o \
build/lex.o \
build/lift.o \
build/parse.o \
build/pprint.o \
build/repl.o \
- build/tlsf.o \
build/resp.o \
+ build/tlsf.o \
build/unify.o \
build/resp_gc.o
diff --git a/src/c.cpp b/src/c.cpp
index 7dbb6c2..c456132 100644
--- a/src/c.cpp
+++ b/src/c.cpp
@@ -42,7 +42,9 @@ static inline Function* llFunc(CFunc f) { return static_cast<Function*>(f); }
static const Type*
llType(const AType* t)
{
- if (t->kind == AType::PRIM) {
+ if (t == NULL) {
+ return NULL;
+ } else if (t->kind == AType::PRIM) {
if (t->head()->str() == "Nothing") return new string("void");
if (t->head()->str() == "Bool") return new string("bool");
if (t->head()->str() == "Int") return new string("int");
@@ -66,8 +68,19 @@ llType(const AType* t)
*ret += ")";
return ret;
+ } else if (t->kind == AType::EXPR && t->head()->str() == "Tup") {
+ Type* ret = new Type("struct { void* me; ");
+ for (AType::const_iterator i = t->begin() + 1; i != t->end(); ++i) {
+ const Type* lt = llType((*i)->to<const AType*>());
+ if (!lt)
+ return NULL;
+ ret->append("; ");
+ ret->append(*lt);
+ }
+ ret->append("}*");
+ return ret;
}
- return NULL; // non-primitive type
+ return new Type("void*");
}
@@ -116,7 +129,7 @@ struct CEngine : public Engine {
return f;
}
- void finishFunction(CEnv& cenv, CFunc f, const AType* retT, CVal ret) {
+ void finishFunction(CEnv& cenv, CFunc f, CVal ret) {
out += "return " + *(Value*)ret + ";\n}\n\n";
}
@@ -124,7 +137,7 @@ struct CEngine : public Engine {
cenv.err << "C backend does not support JIT (eraseFunction)" << endl;
}
- CVal compileCall(CEnv& cenv, CFunc func, const vector<CVal>& args) {
+ CVal compileCall(CEnv& cenv, CFunc func, const AType* funcT, const vector<CVal>& args) {
Value* varname = new string(cenv.penv.gensymstr("x"));
Function* f = llFunc(func);
out += (format("const %s %s = %s(") % f->returnType % *varname % f->name).str();
@@ -134,21 +147,21 @@ struct CEngine : public Engine {
return varname;
}
- CFunc compileFunction(CEnv& cenv, AFn* fn, const AType& argsT);
+ CFunc compileFunction(CEnv& cenv, AFn* fn, const AType* type);
CVal compileTup(CEnv& cenv, const AType* type, const vector<CVal>& fields);
CVal compileDot(CEnv& cenv, CVal tup, int32_t index);
CVal compileLiteral(CEnv& cenv, AST* lit);
CVal compilePrimitive(CEnv& cenv, APrimitive* prim);
CVal compileIf(CEnv& cenv, AIf* aif);
- CVal compileGlobal(CEnv& cenv, AType* type, const string& name, CVal val);
+ CVal compileGlobal(CEnv& cenv, const AType* type, const string& name, CVal val);
CVal getGlobal(CEnv& cenv, CVal val);
void writeModule(CEnv& cenv, std::ostream& os) {
os << out;
}
- const string call(CEnv& cenv, CFunc f, AType* retT) {
+ const string call(CEnv& cenv, CFunc f, const AType* retT) {
cenv.err << "C backend does not support JIT (call)" << endl;
return "";
}
@@ -185,28 +198,14 @@ CEngine::compileLiteral(CEnv& cenv, AST* lit)
}
CFunc
-CEngine::compileFunction(CEnv& cenv, AFn* fn, const AType& argsT)
+CEngine::compileFunction(CEnv& cenv, AFn* fn, const AType* type)
{
- CEngine* engine = reinterpret_cast<CEngine*>(cenv.engine());
- AType* genericType = cenv.type(fn);
- AType* thisType = genericType;
- Subst argsSubst;
-
- // Build and apply substitution to get concrete type for this call
- if (!genericType->concrete()) {
- argsSubst = cenv.tenv.buildSubst(genericType, argsT);
- thisType = argsSubst.apply(genericType)->as<AType*>();
- }
-
- THROW_IF(!thisType->concrete(), fn->loc,
- string("call has non-concrete type %1%\n") + thisType->str());
-
- Object::pool.addRoot(thisType);
- CFunc f = fn->impls.find(thisType);
- if (f)
- return f;
+ assert(type->concrete());
- ATuple* protT = thisType->prot();
+ CEngine* engine = reinterpret_cast<CEngine*>(cenv.engine());
+ const AType* argsT = type->prot()->as<const AType*>();
+ const AType* retT = type->last()->as<const AType*>();
+ Subst argsSubst = cenv.tenv.buildSubst(type, *argsT);
vector<string> argNames;
for (ATuple::const_iterator i = fn->prot()->begin(); i != fn->prot()->end(); ++i)
@@ -214,9 +213,7 @@ CEngine::compileFunction(CEnv& cenv, AFn* fn, const AType& argsT)
// Write function declaration
const string name = (fn->name == "") ? cenv.penv.gensymstr("_fn") : fn->name;
- f = llFunc(cenv.engine()->startFunction(cenv, name,
- thisType->last()->to<AType*>(),
- *protT, argNames));
+ Function* f = llFunc(cenv.engine()->startFunction(cenv, name, retT, *argsT, argNames));
cenv.push();
Subst oldSubst = cenv.tsubst;
@@ -225,7 +222,7 @@ CEngine::compileFunction(CEnv& cenv, AFn* fn, const AType& argsT)
// Bind argument values in CEnv
vector<Value*> args;
AFn::const_iterator p = fn->prot()->begin();
- ATuple::const_iterator pT = protT->begin();
+ ATuple::const_iterator pT = argsT->begin();
for (; p != fn->prot()->end(); ++p, ++pT) {
AType* t = (*pT)->as<AType*>();
const Type* lt = llType(t);
@@ -235,11 +232,10 @@ CEngine::compileFunction(CEnv& cenv, AFn* fn, const AType& argsT)
// Write function body
try {
- fn->impls.push_back(make_pair(thisType, f));
CVal retVal = NULL;
for (AFn::iterator i = fn->begin() + 2; i != fn->end(); ++i)
retVal = (*i)->compile(cenv);
- cenv.engine()->finishFunction(cenv, f, cenv.type(fn->last()), retVal);
+ cenv.engine()->finishFunction(cenv, f, retVal);
} catch (Error& e) {
cenv.pop();
throw e;
@@ -314,7 +310,7 @@ CEngine::compilePrimitive(CEnv& cenv, APrimitive* prim)
}
CVal
-CEngine::compileGlobal(CEnv& cenv, AType* type, const string& name, CVal val)
+CEngine::compileGlobal(CEnv& cenv, const AType* type, const string& name, CVal val)
{
return NULL;
}
diff --git a/src/compile.cpp b/src/compile.cpp
index 4f24994..ee92dbd 100644
--- a/src/compile.cpp
+++ b/src/compile.cpp
@@ -43,33 +43,29 @@ ASymbol::compile(CEnv& cenv) throw()
CVal
AFn::compile(CEnv& cenv) throw()
{
- return impls.find(cenv.type(this));
+ const AType* type = cenv.type(this);
+ CFunc f = cenv.findImpl(this, type);
+ if (!f) {
+ f = cenv.engine()->compileFunction(cenv, this, type);
+ cenv.vals.def(cenv.penv.sym(name), f);
+ cenv.addImpl(this, f);
+ }
+ return f;
}
CVal
ACall::compile(CEnv& cenv) throw()
{
- AFn* c = cenv.resolve(head())->to<AFn*>();
-
- if (!c) return NULL; // Primitive
-
- AType protT(loc);
- for (const_iterator i = begin() + 1; i != end(); ++i)
- protT.push_back(cenv.type(*i));
-
- AType fnT(loc);
- fnT.push_back(cenv.tenv.Fn);
- fnT.push_back(&protT);
- fnT.push_back(cenv.type(this));
+ CFunc f = (*begin())->compile(cenv);
- CFunc f = c->impls.find(&fnT);
- THROW_IF(!f, loc, (format("callee failed to compile for type %1%") % fnT.str()).str());
+ if (!f)
+ f = cenv.currentFn; // Recursive call (callee defined as a stub)
vector<CVal> args;
for (const_iterator e = begin() + 1; e != end(); ++e)
args.push_back((*e)->compile(cenv));
- return cenv.engine()->compileCall(cenv, f, args);
+ return cenv.engine()->compileCall(cenv, f, cenv.type(head()), args);
}
CVal
@@ -83,7 +79,7 @@ ADef::compile(CEnv& cenv) throw()
cenv.lock(this);
}
cenv.vals.def(sym(), val);
- return val;
+ return NULL;
}
CVal
diff --git a/src/constrain.cpp b/src/constrain.cpp
index c815bac..024dce4 100644
--- a/src/constrain.cpp
+++ b/src/constrain.cpp
@@ -42,7 +42,7 @@ AString::constrain(TEnv& tenv, Constraints& c) const throw(Error)
void
ASymbol::constrain(TEnv& tenv, Constraints& c) const throw(Error)
{
- AType** ref = tenv.ref(this);
+ const AType** ref = tenv.ref(this);
THROW_IF(!ref, loc, (format("undefined symbol `%1%'") % cppstr).str());
c.constrain(tenv, this, *ref);
}
@@ -53,7 +53,7 @@ ATuple::constrain(TEnv& tenv, Constraints& c) const throw(Error)
AType* t = tup<AType>(loc, NULL);
FOREACHP(const_iterator, p, this) {
(*p)->constrain(tenv, c);
- t->push_back(tenv.var(*p));
+ t->push_back(const_cast<AType*>(tenv.var(*p)));
}
c.constrain(tenv, this, t);
}
@@ -72,9 +72,9 @@ AFn::constrain(TEnv& tenv, Constraints& c) const throw(Error)
THROW_IF(defs.count(sym) != 0, sym->loc,
(format("duplicate parameter `%1%'") % sym->str()).str());
defs.insert(sym);
- AType* tvar = tenv.fresh(sym);
+ const AType* tvar = tenv.fresh(sym);
frame.push_back(make_pair(sym, tvar));
- protT->push_back(tvar);
+ protT->push_back(const_cast<AType*>(tvar));
}
const_iterator i = begin() + 1;
@@ -95,13 +95,12 @@ AFn::constrain(TEnv& tenv, Constraints& c) const throw(Error)
tenv.push(frame);
- c.constrain(tenv, this, tenv.var());
AST* exp = NULL;
for (i = begin() + 2; i != end(); ++i)
(exp = *i)->constrain(tenv, c);
- AType* bodyT = tenv.var(exp);
- AType* fnT = tup<AType>(loc, tenv.Fn, protT, bodyT, 0);
+ const AType* bodyT = tenv.var(exp);
+ const AType* fnT = tup<const AType>(loc, tenv.Fn, protT, bodyT, 0);
Object::pool.addRoot(fnT);
tenv.pop();
@@ -127,10 +126,10 @@ ACall::constrain(TEnv& tenv, Constraints& c) const throw(Error)
(format("expected %1% arguments, got %2%") % numArgs % (size() - 1)).str());
}
- AType* retT = tenv.var();
- AType* argsT = tup<AType>(loc, 0);
+ const AType* retT = tenv.var();
+ AType* argsT = tup<AType>(loc, 0);
for (const_iterator i = begin() + 1; i != end(); ++i)
- argsT->push_back(tenv.var(*i));
+ argsT->push_back(const_cast<AType*>(tenv.var(*i)));
c.constrain(tenv, head(), tup<AType>(head()->loc, tenv.Fn, argsT, retT, 0));
c.constrain(tenv, this, retT);
@@ -143,7 +142,7 @@ ADef::constrain(TEnv& tenv, Constraints& c) const throw(Error)
const ASymbol* sym = this->sym();
THROW_IF(!sym, loc, "`def' has no symbol")
- AType* tvar = tenv.var(body());
+ const AType* tvar = tenv.var(body());
tenv.def(sym, tvar);
body()->constrain(tenv, c);
c.constrain(tenv, sym, tvar);
@@ -157,7 +156,7 @@ AIf::constrain(TEnv& tenv, Constraints& c) const throw(Error)
THROW_IF(size() % 2 != 0, loc, "`if' missing final else clause")
for (const_iterator i = begin() + 1; i != end(); ++i)
(*i)->constrain(tenv, c);
- AType* retT = tenv.var(this);
+ const AType* retT = tenv.var(this);
for (const_iterator i = begin() + 1; true; ++i) {
const_iterator next = i;
++next;
@@ -178,7 +177,7 @@ ACons::constrain(TEnv& tenv, Constraints& c) const throw(Error)
AType* type = tup<AType>(loc, tenv.Tup, 0);
for (const_iterator i = begin() + 1; i != end(); ++i) {
(*i)->constrain(tenv, c);
- type->push_back(tenv.var(*i));
+ type->push_back(const_cast<AType*>(tenv.var(*i)));
}
c.constrain(tenv, this, type);
@@ -194,13 +193,13 @@ ADot::constrain(TEnv& tenv, Constraints& c) const throw(Error)
THROW_IF(!idx, loc, "the 2nd argument to `.' must be a literal integer");
obj->constrain(tenv, c);
- AType* retT = tenv.var(this);
+ const AType* retT = tenv.var(this);
c.constrain(tenv, this, retT);
AType* objT = tup<AType>(loc, tenv.Tup, 0);
for (int i = 0; i < idx->val; ++i)
- objT->push_back(tenv.var());
- objT->push_back(retT);
+ objT->push_back(const_cast<AType*>(tenv.var()));
+ objT->push_back(const_cast<AType*>(retT));
objT->push_back(new AType(obj->loc, AType::DOTS));
c.constrain(tenv, obj, objT);
}
@@ -228,7 +227,7 @@ APrimitive::constrain(TEnv& tenv, Constraints& c) const throw(Error)
i = begin();
- AType* var = NULL;
+ const AType* var = NULL;
switch (type) {
case ARITHMETIC:
if (size() < 3)
diff --git a/src/lift.cpp b/src/lift.cpp
index 94f70e0..6a53165 100644
--- a/src/lift.cpp
+++ b/src/lift.cpp
@@ -17,64 +17,194 @@
/** @file
* @brief Lift functions (compilation pass 1)
+ * After this pass:
+ * - All function definitions are top-level
+ * - All references to functions are replaced with references to
+ * a closure (a tuple with the function and necessary context)
*/
#include "resp.hpp"
using namespace std;
-void
-AFn::lift(CEnv& cenv) throw()
+AST*
+ASymbol::lift(CEnv& cenv, Code& code) throw()
{
+ if (!cenv.liftStack.empty() && cppstr == cenv.liftStack.top().fn->name) {
+ return cenv.penv.sym("me"); // Reference to innermost function
+ } else if (!cenv.penv.handler(true, cppstr)
+ && !cenv.penv.handler(false, cppstr)
+ && !cenv.code.innermost(this)) {
+
+ const int32_t index = cenv.liftStack.top().index(this);
+
+ // Replace symbol with code to access free variable from closure
+ return tup<ADot>(loc, cenv.penv.sym("."),
+ cenv.penv.sym("me"),
+ new ALiteral<int32_t>(index, Cursor()),
+ NULL);
+ } else {
+ return this;
+ }
+}
+
+AST*
+ATuple::lift(CEnv& cenv, Code& code) throw()
+{
+ ATuple* ret = new ATuple(*this);
+ iterator ri = ret->begin();
+ FOREACHP(const_iterator, t, this)
+ *ri++ = (*t)->lift(cenv, code);
+ cenv.setTypeSameAs(ret, this);
+ return ret;
+}
+
+AST*
+AFn::lift(CEnv& cenv, Code& code) throw()
+{
+ AFn* impl = new AFn(this);
+ const string nameBase = cenv.penv.gensymstr(((name != "") ? name : "fn").c_str());
+ impl->name = "_" + nameBase;
+
+ cenv.liftStack.push(CEnv::FreeVars(this, impl->name));
+
// Create a new stub environment frame for parameters
cenv.push();
+ AType::const_iterator tp = cenv.type(this)->prot()->begin();
for (const_iterator p = prot()->begin(); p != prot()->end(); ++p)
- cenv.def((*p)->as<ASymbol*>(), *p, NULL, NULL);
+ cenv.def((*p)->as<ASymbol*>(), *p, (*tp++)->as<AType*>(), NULL);
+
+ /* Add closure parameter with dummy name (undefined symbol).
+ * The name of this parameter will be changed to the name of this
+ * function after lifting the body (so recursive references correctly
+ * refer to this function by the closure parameter).
+ */
+ impl->prot()->push_front(cenv.penv.sym("_"));
+
// Lift body
- for (iterator i = begin() + 2; i != end(); ++i)
- (*i)->lift(cenv);
+ const AType* implRetT = NULL;
+ iterator ci = impl->begin() + 2;
+ for (const_iterator i = begin() + 2; i != end(); ++i, ++ci) {
+ *ci = (*i)->lift(cenv, code);
+ implRetT = cenv.type(*ci);
+ }
cenv.pop();
- AType* type = cenv.type(this);
- if (impls.find(type) || !type->concrete())
- return;
+ // Set name of closure parameter to actual name of this function
+ *impl->prot()->begin() = cenv.penv.sym("me");
- AType* protT = type->prot()->as<AType*>();
- cenv.engine()->compileFunction(cenv, this, *protT);
-}
+ // Create definition for implementation fn
+ ASymbol* implName = cenv.penv.sym(impl->name);
+ ADef* def = tup<ADef>(loc, cenv.penv.sym("def"), implName, impl, NULL);
+ code.push_back(def);
-void
-ACall::lift(CEnv& cenv) throw()
-{
- AFn* c = cenv.resolve(head())->to<AFn*>();
- AType argsT(loc);
+ AType* implT = new AType(*cenv.type(this)); // Type of the implementation function
+ AType* tupT = tup<AType>(loc, cenv.tenv.Tup, cenv.tenv.var(), NULL);
+ AType* consT = tup<AType>(loc, cenv.tenv.Tup, implT, NULL);
+ ACons* cons = tup<ACons>(loc, cenv.penv.sym("cons"), implName, NULL); // Closure
- // Lift arguments and build arguments type
- for (iterator i = begin() + 1; i != end(); ++i) {
- (*i)->lift(cenv);
- argsT.push_back(cenv.type(*i));
+ const CEnv::FreeVars& freeVars = cenv.liftStack.top();
+ for (CEnv::FreeVars::const_iterator i = freeVars.begin(); i != freeVars.end(); ++i) {
+ cons->push_back(*i);
+ tupT->push_back(const_cast<AType*>(cenv.type(*i)));
+ consT->push_back(const_cast<AType*>(cenv.type(*i)));
}
+ cenv.liftStack.pop();
- // Lift callee (if it's not a primitive)
- if (c)
- cenv.engine()->compileFunction(cenv, c, argsT);
+ implT->prot()->push_front(tupT);
+ *(implT->begin() + 2) = const_cast<AType*>(implRetT);
+
+ cenv.setType(impl, implT);
+ cenv.setType(cons, consT);
+
+ cenv.def(implName, impl, implT, NULL);
+ if (name != "")
+ cenv.def(cenv.penv.sym(name), this, consT, NULL);
+
+ return cons;
}
-void
-ADot::lift(CEnv& cenv) throw()
+AST*
+ACall::lift(CEnv& cenv, Code& code) throw()
{
- (*(begin() + 1))->lift(cenv);
+ ACall* copy = new ACall(this);
+ ATuple::iterator ri = copy->begin();
+
+ // Lift all children (callee and arguments, recursively)
+ for (const_iterator i = begin(); i != end(); ++i)
+ *ri++ = (*i)->lift(cenv, code);
+
+ ASymbol* sym = head()->to<ASymbol*>();
+ if (sym && !cenv.liftStack.empty() && sym->cppstr == cenv.liftStack.top().fn->name) {
+ /* Recursive call to innermost function, call implementation directly,
+ * reusing the current "me" closure parameter (no cons or .).
+ */
+ copy->push_front(cenv.penv.sym(cenv.liftStack.top().implName));
+ } else if (head()->to<AFn*>()) {
+ /* Special case: ((fn ...) ...)
+ * Lifting (fn ...) yields: (cons _impl ...).
+ * We don't want ((cons _impl ...) (cons _impl ...) ...),
+ * so call the implementation function (_impl) directly:
+ * (_impl (cons _impl ...) ...)
+ */
+ ACons* closure = (*copy->begin())->as<ACons*>();
+ ASymbol* implSym = (*(closure->begin() + 1))->as<ASymbol*>();
+ const AType* implT = cenv.type(cenv.resolve(implSym));
+ copy->push_front(implSym);
+ cenv.setType(copy, (*(implT->begin() + 2))->as<const AType*>());
+ } else {
+ // Call to a closure, prepend code to access implementation function
+ ADot* getFn = tup<ADot>(loc, cenv.penv.sym("."),
+ copy->head(),
+ new ALiteral<int32_t>(0, Cursor()), NULL);
+ const AType* calleeT = cenv.type(copy->head());
+ assert(**calleeT->begin() == *cenv.tenv.Tup);
+ const AType* implT = (*(calleeT->begin() + 1))->as<const AType*>();
+ copy->push_front(getFn);
+ cenv.setType(getFn, implT);
+ cenv.setType(copy, (*(implT->begin() + 2))->as<const AType*>());
+ }
+
+ return copy;
}
-void
-ADef::lift(CEnv& cenv) throw()
+AST*
+ADef::lift(CEnv& cenv, Code& code) throw()
{
// Define stub first for recursion
cenv.def(sym(), body(), cenv.type(body()), NULL);
AFn* c = body()->to<AFn*>();
if (c)
c->name = sym()->str();
- body()->lift(cenv);
+
+ ADef* copy = new ADef(ATuple::lift(cenv, code)->as<ATuple*>());
+
+ if (copy->sym() == copy->body())
+ return NULL; // Definition created by AFn::lift when body was lifted
+
+ cenv.def(copy->sym(), copy->body(), cenv.type(copy->body()), NULL);
+ cenv.setTypeSameAs(copy, this);
+ return copy;
+}
+
+template<typename T>
+AST*
+lift_builtin_call(CEnv& cenv, T* call, Code& code) throw()
+{
+ ATuple* copy = new T(call);
+ ATuple::iterator ri = copy->begin() + 1;
+
+ // Lift all arguments
+ for (typename T::const_iterator i = call->begin() + 1; i != call->end(); ++i)
+ *ri++ = (*i)->lift(cenv, code);
+
+ cenv.setTypeSameAs(copy, call);
+ return copy;
}
+
+AST* AIf::lift(CEnv& cenv, Code& code) throw() { return lift_builtin_call(cenv, this, code); }
+AST* ACons::lift(CEnv& cenv, Code& code) throw() { return lift_builtin_call(cenv, this, code); }
+AST* ADot::lift(CEnv& cenv, Code& code) throw() { return lift_builtin_call(cenv, this, code); }
+AST* APrimitive::lift(CEnv& cenv, Code& code) throw() { return lift_builtin_call(cenv, this, code); }
diff --git a/src/llvm.cpp b/src/llvm.cpp
index 5b6d69f..6243919 100644
--- a/src/llvm.cpp
+++ b/src/llvm.cpp
@@ -123,7 +123,7 @@ struct LLVMEngine : public Engine {
return PointerType::get(StructType::get(context, ctypes, false), 0);
}
- return NULL; // non-primitive type
+ return PointerType::get(Type::getVoidTy(context), NULL);
}
CFunc startFunction(CEnv& cenv,
@@ -161,13 +161,12 @@ struct LLVMEngine : public Engine {
return f;
}
- void finishFunction(CEnv& cenv, CFunc f, const AType* retT, CVal ret) {
- if (retT->concrete())
- builder.CreateRet(llVal(ret));
- else
- builder.CreateRetVoid();
-
- verifyFunction(*static_cast<Function*>(f));
+ void finishFunction(CEnv& cenv, CFunc f, CVal ret) {
+ builder.CreateRet(llVal(ret));
+ if (verifyFunction(*static_cast<Function*>(f), llvm::PrintMessageAction)) {
+ module->dump();
+ throw Error(Cursor(), "Broken module");
+ }
if (cenv.args.find("-g") == cenv.args.end())
opt->run(*static_cast<Function*>(f));
}
@@ -177,19 +176,23 @@ struct LLVMEngine : public Engine {
llFunc(f)->eraseFromParent();
}
- CVal compileCall(CEnv& cenv, CFunc f, const vector<CVal>& args) {
- const vector<Value*>& llArgs = *reinterpret_cast<const vector<Value*>*>(&args);
+ CVal compileCall(CEnv& cenv, CFunc f, const AType* funcT, const vector<CVal>& args) {
+ vector<Value*> llArgs(*reinterpret_cast<const vector<Value*>*>(&args));
+ Value* closure = builder.CreateBitCast(llArgs[0],
+ llType(funcT->prot()->head()->as<const AType*>()),
+ cenv.penv.gensymstr("you"));
+ llArgs[0] = closure;
return builder.CreateCall(llFunc(f), llArgs.begin(), llArgs.end());
}
- CFunc compileFunction(CEnv& cenv, AFn* fn, const AType& argsT);
+ CFunc compileFunction(CEnv& cenv, AFn* fn, const AType* type);
CVal compileTup(CEnv& cenv, const AType* type, const vector<CVal>& fields);
CVal compileDot(CEnv& cenv, CVal tup, int32_t index);
CVal compileLiteral(CEnv& cenv, AST* lit);
CVal compilePrimitive(CEnv& cenv, APrimitive* prim);
CVal compileIf(CEnv& cenv, AIf* aif);
- CVal compileGlobal(CEnv& cenv, AType* type, const string& name, CVal val);
+ CVal compileGlobal(CEnv& cenv, const AType* type, const string& name, CVal val);
CVal getGlobal(CEnv& cenv, CVal val);
void writeModule(CEnv& cenv, std::ostream& os) {
@@ -197,7 +200,7 @@ struct LLVMEngine : public Engine {
module->print(os, &writer);
}
- const string call(CEnv& cenv, CFunc f, AType* retT) {
+ const string call(CEnv& cenv, CFunc f, const AType* retT) {
void* fp = engine->getPointerToFunction(llFunc(f));
const Type* t = llType(retT);
THROW_IF(!fp, Cursor(), "unable to get function pointer");
@@ -248,8 +251,9 @@ LLVMEngine::compileTup(CEnv& cenv, const AType* type, const vector<CVal>& fields
{
// Find size of memory required
size_t s = 0;
+ assert(type->begin() != type->end());
for (AType::const_iterator i = type->begin() + 1; i != type->end(); ++i)
- s += llType((*i)->as<AType*>())->getPrimitiveSizeInBits();
+ s += engine->getTargetData()->getTypeSizeInBits(llType((*i)->as<AType*>()));
// Allocate struct
Value* structSize = ConstantInt::get(Type::getInt32Ty(context), bitsToBytes(s));
@@ -259,8 +263,8 @@ LLVMEngine::compileTup(CEnv& cenv, const AType* type, const vector<CVal>& fields
// Set struct fields
size_t i = 0;
for (vector<CVal>::const_iterator f = fields.begin(); f != fields.end(); ++f, ++i) {
- Value* v = builder.CreateStructGEP(structPtr, i, (format("tup%1%") % i).str().c_str());
- builder.CreateStore(llVal(*f), v);
+ builder.CreateStore(llVal(*f),
+ builder.CreateStructGEP(structPtr, i, (format("tup%1%") % i).str().c_str()));
}
return structPtr;
@@ -292,28 +296,13 @@ LLVMEngine::compileLiteral(CEnv& cenv, AST* lit)
}
CFunc
-LLVMEngine::compileFunction(CEnv& cenv, AFn* fn, const AType& argsT)
+LLVMEngine::compileFunction(CEnv& cenv, AFn* fn, const AType* type)
{
- LLVMEngine* engine = reinterpret_cast<LLVMEngine*>(cenv.engine());
- AType* genericType = cenv.type(cenv.resolve(fn));
- AType* thisType = genericType;
- Subst argsSubst;
-
- // Build and apply substitution to get concrete type for this call
- if (!genericType->concrete()) {
- argsSubst = cenv.tenv.buildSubst(genericType, argsT);
- thisType = argsSubst.apply(genericType)->as<AType*>();
- }
+ assert(type->concrete());
- THROW_IF(!thisType->concrete(), fn->loc,
- string("call has non-concrete type %1%\n") + thisType->str());
-
- Object::pool.addRoot(thisType);
- Function* f = (Function*)fn->impls.find(thisType);
- if (f)
- return f;
-
- ATuple* protT = thisType->prot();
+ LLVMEngine* engine = reinterpret_cast<LLVMEngine*>(cenv.engine());
+ const AType* argsT = type->prot()->as<const AType*>();
+ const AType* retT = type->last()->as<const AType*>();
vector<string> argNames;
for (ATuple::const_iterator i = fn->prot()->begin(); i != fn->prot()->end(); ++i)
@@ -321,39 +310,40 @@ LLVMEngine::compileFunction(CEnv& cenv, AFn* fn, const AType& argsT)
// Write function declaration
const string name = (fn->name == "") ? cenv.penv.gensymstr("_fn") : fn->name;
- f = llFunc(cenv.engine()->startFunction(cenv, name,
- thisType->last()->to<AType*>(),
- *protT, argNames));
+ Function* f = llFunc(cenv.engine()->startFunction(cenv, name, retT, *argsT, argNames));
cenv.push();
- Subst oldSubst = cenv.tsubst;
- cenv.tsubst = Subst::compose(cenv.tsubst, argsSubst);
// Bind argument values in CEnv
vector<Value*> args;
AFn::const_iterator p = fn->prot()->begin();
- ATuple::const_iterator pT = protT->begin();
+ ATuple::const_iterator pT = argsT->begin();
+ assert(fn->prot()->size() == argsT->size());
+ assert(fn->prot()->size() == f->num_args());
for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p, ++pT) {
- AType* t = (*pT)->as<AType*>();
- const Type* lt = llType(t);
+ const AType* t = (*pT)->as<const AType*>();
+ const Type* lt = llType(t);
THROW_IF(!lt, fn->loc, "untyped parameter\n");
cenv.def((*p)->as<ASymbol*>(), *p, t, &*a);
}
+ assert(!cenv.currentFn);
+
// Write function body
try {
- fn->impls.push_back(make_pair(thisType, f));
+ cenv.currentFn = f;
CVal retVal = NULL;
for (AFn::iterator i = fn->begin() + 2; i != fn->end(); ++i)
retVal = (*i)->compile(cenv);
- cenv.engine()->finishFunction(cenv, f, cenv.type(fn->last()), retVal);
+ cenv.engine()->finishFunction(cenv, f, retVal);
} catch (Error& e) {
f->eraseFromParent(); // Error reading body, remove function
cenv.pop();
+ cenv.currentFn = NULL;
throw e;
}
- cenv.tsubst = oldSubst;
cenv.pop();
+ cenv.currentFn = NULL;
return f;
}
@@ -456,12 +446,12 @@ LLVMEngine::compilePrimitive(CEnv& cenv, APrimitive* prim)
}
CVal
-LLVMEngine::compileGlobal(CEnv& cenv, AType* type, const string& name, CVal val)
+LLVMEngine::compileGlobal(CEnv& cenv, const AType* type, const string& sym, CVal val)
{
LLVMEngine* engine = reinterpret_cast<LLVMEngine*>(cenv.engine());
Constant* init = Constant::getNullValue(llType(type));
GlobalVariable* global = new GlobalVariable(*module, llType(type), false,
- GlobalValue::ExternalLinkage, Constant::getNullValue(llType(type)), name);
+ GlobalValue::ExternalLinkage, Constant::getNullValue(llType(type)), sym);
engine->builder.CreateStore(llVal(val), global);
return global;
diff --git a/src/repl.cpp b/src/repl.cpp
index 9fa5781..92fb621 100644
--- a/src/repl.cpp
+++ b/src/repl.cpp
@@ -27,17 +27,18 @@
using namespace std;
static bool
-readParseTypeLift(CEnv& cenv, Cursor& cursor, istream& is, AST*& exp, AST*& ast, AType*& type)
+readParseType(CEnv& cenv, Cursor& cursor, istream& is, AST*& exp, AST*& ast)
{
exp = readExpression(cursor, is);
if (exp->to<ATuple*>() && exp->to<ATuple*>()->empty())
return false;
ast = cenv.penv.parse(exp); // Parse input
- Constraints c;
+
+ Constraints c(cenv.tsubst);
ast->constrain(cenv.tenv, c); // Constrain types
- cenv.tsubst = Subst::compose(cenv.tsubst, unify(c)); // Solve type constraints
+ cenv.tsubst = unify(c); // Solve type constraints
// Add types in type substition as GC roots
for (Subst::iterator i = cenv.tsubst.begin(); i != cenv.tsubst.end(); ++i) {
@@ -45,16 +46,11 @@ readParseTypeLift(CEnv& cenv, Cursor& cursor, istream& is, AST*& exp, AST*& ast,
Object::pool.addRoot(i->second);
}
- type = cenv.type(ast);
- THROW_IF(!type, cursor, "call to untyped body");
-
- ast->lift(cenv); // Lift functions
-
return true;
}
static void
-callPrintCollect(CEnv& cenv, CFunc f, AST* result, AType* resultT, bool execute)
+callPrintCollect(CEnv& cenv, CFunc f, AST* result, const AType* resultT, bool execute)
{
if (execute)
cenv.out << cenv.engine()->call(cenv, f, resultT);
@@ -64,43 +60,86 @@ callPrintCollect(CEnv& cenv, CFunc f, AST* result, AType* resultT, bool execute)
cenv.out << " : " << resultT << endl;
Object::pool.collect(Object::pool.roots());
-
- if (cenv.args.find("-d") != cenv.args.end())
- cenv.engine()->writeModule(cenv, cenv.out);
}
/// Compile and evaluate code from @a is
int
eval(CEnv& cenv, const string& name, istream& is, bool execute)
{
- AST* exp = NULL;
- AST* ast = NULL;
- AType* type = NULL;
- list< pair<AST*, AST*> > exprs;
+ AST* exp = NULL;
+ AST* ast = NULL;
+ list<AST*> parsed;
Cursor cursor(name);
try {
- while (readParseTypeLift(cenv, cursor, is, exp, ast, type))
- exprs.push_back(make_pair(exp, ast));
+ while (readParseType(cenv, cursor, is, exp, ast))
+ parsed.push_back(ast);
- //for (list< pair<SExp, AST*> >::const_iterator i = exprs.begin(); i != exprs.end(); ++i)
+ //for (list< pair<SExp, AST*> >::const_iterator i = parsed.begin(); i != parsed.end(); ++i)
// pprint(cout, i->second->cps(cenv.tenv, cenv.penv.sym("cont")));
CVal val = NULL;
CFunc f = NULL;
- if (type->concrete()) {
- // Create function for top-level of program
- f = cenv.engine()->startFunction(cenv, "main", type, ATuple(cursor));
- // Compile all expressions into it
- for (list< pair<AST*, AST*> >::const_iterator i = exprs.begin(); i != exprs.end(); ++i)
- val = i->second->compile(cenv);
+ /*
+ // De-poly all expressions
+ Code concrete;
+ for (list<AST*>::iterator i = parsed.begin(); i != parsed.end(); ++i) {
+ AST* c = (*i)->depoly(cenv, concrete);
+ if (c)
+ concrete.push_back(c);
+ }
- // Finish compilation
- cenv.engine()->finishFunction(cenv, f, type, val);
+ cout << endl << ";;;; CONCRETE {" << endl << endl;
+ for (Code::iterator i = concrete.begin(); i != concrete.end(); ++i)
+ cout << *i << endl << endl;
+ cout << ";;;; } CONCRETE" << endl << endl;*/
+
+ // Lift all expressions
+ Code lifted;
+ for (list<AST*>::iterator i = parsed.begin(); i != parsed.end(); ++i) {
+ AST* l = (*i)->lift(cenv, lifted);
+ if (l)
+ lifted.push_back(l);
+ }
+
+ if (cenv.args.find("-d") != cenv.args.end()) {
+ cout << endl << ";;;; LIFTED {" << endl << endl;
+ for (Code::iterator i = lifted.begin(); i != lifted.end(); ++i) {
+ cout << *i << endl;
+ ADef* def = (*i)->to<ADef*>();
+ if (def)
+ std::cout << " :: " << cenv.type(def->body()) << std::endl;
+ cout << endl;
+ }
+ cout << ";;;; } LIFTED" << endl << endl;
}
+ // Compile top-level (lifted) functions
+ Code exprs;
+ for (Code::iterator i = lifted.begin(); i != lifted.end(); ++i) {
+ ADef* def = (*i)->to<ADef*>();
+ if (def && (*(def->begin() + 2))->to<AFn*>()) {
+ val = def->compile(cenv);
+ } else {
+ exprs.push_back(*i);
+ }
+ }
+
+ const AType* type = cenv.type(exprs.back());
+
+ // Create function for top-level of program
+ f = cenv.engine()->startFunction(cenv, "main", type, ATuple(cursor));
+
+ // Compile expressions (other than function definitions) into it
+ for (list<AST*>::const_iterator i = exprs.begin(); i != exprs.end(); ++i)
+ val = (*i)->compile(cenv);
+
+ // Finish compilation
+ cenv.engine()->finishFunction(cenv, f, val);
+
// Call and print ast
- callPrintCollect(cenv, f, ast, type, execute);
+ if (cenv.args.find("-S") == cenv.args.end())
+ callPrintCollect(cenv, f, ast, type, execute);
} catch (Error& e) {
cenv.err << e.what() << endl;
@@ -113,9 +152,8 @@ eval(CEnv& cenv, const string& name, istream& is, bool execute)
int
repl(CEnv& cenv)
{
- AST* exp = NULL;
- AST* ast = NULL;
- AType* type = NULL;
+ AST* exp = NULL;
+ AST* ast = NULL;
const string replFnName = cenv.penv.gensymstr("_repl");
while (1) {
cenv.out << "() ";
@@ -123,15 +161,20 @@ repl(CEnv& cenv)
Cursor cursor("(stdin)");
try {
- if (!readParseTypeLift(cenv, cursor, std::cin, exp, ast, type))
+ if (!readParseType(cenv, cursor, std::cin, exp, ast))
break;
+ Code lifted;
+ ast = ast->lift(cenv, lifted);
+ const AType* type = cenv.type(ast);
CFunc f = NULL;
try {
// Create function for this repl loop
f = cenv.engine()->startFunction(cenv, replFnName, type, ATuple(cursor));
- cenv.engine()->finishFunction(cenv, f, type, ast->compile(cenv));
+ cenv.engine()->finishFunction(cenv, f, ast->compile(cenv));
callPrintCollect(cenv, f, ast, type, true);
+ if (cenv.args.find("-d") != cenv.args.end())
+ cenv.engine()->writeModule(cenv, cenv.out);
} catch (Error& e) {
cenv.out << e.msg << endl;
cenv.engine()->eraseFunction(cenv, f);
diff --git a/src/resp.cpp b/src/resp.cpp
index 003a76c..89b77ef 100644
--- a/src/resp.cpp
+++ b/src/resp.cpp
@@ -35,14 +35,15 @@ print_usage(char* name, bool error)
os << "Usage: " << name << " [OPTION]... [FILE]..." << endl;
os << "Evaluate and/or compile Resp code" << endl;
os << endl;
- os << " -h Display this help and exit" << endl;
- os << " -r Enter REPL after evaluating files" << endl;
- os << " -p Pretty-print input only" << endl;
- os << " -b BACKEND Backend (llvm or c)" << endl;
- os << " -g Debug (disable optimisation)" << endl;
- os << " -d Dump assembly output" << endl;
- os << " -e EXPRESSION Evaluate EXPRESSION" << endl;
- os << " -o FILE Compile output to FILE (don't run)" << endl;
+ os << " -h Display this help and exit" << endl;
+ os << " -r Enter REPL after evaluating files" << endl;
+ os << " -p Pretty-print input only" << endl;
+ os << " -b BACKEND Use backend (llvm or c)" << endl;
+ os << " -g Debug (disable optimisation)" << endl;
+ os << " -d Dump generated code during compilation" << endl;
+ os << " -S Stop after compilation (output assembler)" << endl;
+ os << " -e EXPRESSION Evaluate EXPRESSION" << endl;
+ os << " -o FILE Compile output to FILE (don't run)" << endl;
return error ? 1 : 0;
}
@@ -60,7 +61,8 @@ main(int argc, char** argv)
} else if (!strncmp(argv[i], "-r", 3)
|| !strncmp(argv[i], "-p", 3)
|| !strncmp(argv[i], "-g", 3)
- || !strncmp(argv[i], "-d", 3)) {
+ || !strncmp(argv[i], "-d", 3)
+ || !strncmp(argv[i], "-S", 3)) {
args.insert(make_pair(argv[i], ""));
} else if (i == argc-1 || argv[i+1][0] == '-') {
return print_usage(argv[0], true);
@@ -131,15 +133,21 @@ main(int argc, char** argv)
if (args.find("-r") != args.end() || (files.empty() && args.find("-e") == args.end()))
ret = repl(*cenv);
- if (output != "") {
- ofstream os(output.c_str());
+ if (cenv->args.find("-S") != cenv->args.end() || cenv->args.find("-d") != cenv->args.end()) {
+ ofstream fs;
+ if (output != "")
+ fs.open(output.c_str());
+
+ ostream& os = (output != "") ? fs : cout;
+
if (os.good()) {
cenv->engine()->writeModule(*cenv, os);
} else {
cerr << argv[0] << ": " << a->second << ": " << strerror(errno) << endl;
++ret;
}
- os.close();
+ if (output != "")
+ fs.close();
}
delete cenv;
diff --git a/src/resp.hpp b/src/resp.hpp
index ade7257..f172548 100644
--- a/src/resp.hpp
+++ b/src/resp.hpp
@@ -30,6 +30,7 @@
#include <map>
#include <set>
#include <sstream>
+#include <stack>
#include <string>
#include <vector>
#include <boost/format.hpp>
@@ -92,6 +93,12 @@ struct Env : public list< vector< pair<K,V> > > {
return true;
return false;
}
+ bool innermost(const K& key) const {
+ for (typename Frame::const_iterator b = this->front().begin(); b != this->front().end(); ++b)
+ if (b->first == key)
+ return true;
+ return false;
+ }
};
template<typename K, typename V>
@@ -187,6 +194,8 @@ struct CEnv; ///< Compile-Time Environment
struct AST;
extern ostream& operator<<(ostream& out, const AST* ast);
+typedef list<AST*> Code;
+
/// Base class for all AST nodes
struct AST : public Object {
AST(Cursor c=Cursor()) : loc(c) {}
@@ -196,7 +205,8 @@ struct AST : public Object {
virtual bool contains(const AST* child) const { return false; }
virtual void constrain(TEnv& tenv, Constraints& c) const throw(Error) {}
virtual AST* cps(TEnv& tenv, AST* cont) const;
- virtual void lift(CEnv& cenv) throw() {}
+ virtual AST* lift(CEnv& cenv, Code& code) throw() { return this; }
+ virtual AST* depoly(CEnv& cenv, Code& code) throw() { return this; }
virtual CVal compile(CEnv& cenv) throw() = 0;
string str() const { ostringstream ss; ss << this; return ss.str(); }
template<typename T> T to() { return dynamic_cast<T>(this); }
@@ -247,6 +257,7 @@ struct AString : public AST, public std::string {
struct ASymbol : public AST {
bool operator==(const AST& rhs) const { return this == &rhs; }
void constrain(TEnv& tenv, Constraints& c) const throw(Error);
+ AST* lift(CEnv& cenv, Code& code) throw();
CVal compile(CEnv& cenv) throw();
const string cppstr;
private:
@@ -273,6 +284,12 @@ struct ATuple : public AST {
newvec[_len++] = ast;
_vec = newvec;
}
+ void push_front(AST* ast) {
+ AST** newvec = (AST**)malloc(sizeof(AST*) * (_len + 1));
+ newvec[0] = ast;
+ memcpy(newvec + 1, _vec, sizeof(AST*) * _len++);
+ _vec = newvec;
+ }
const AST* head() const { assert(_len > 0); return _vec[0]; }
AST* head() { assert(_len > 0); return _vec[0]; }
const AST* last() const { return _vec[_len - 1]; }
@@ -297,7 +314,7 @@ struct ATuple : public AST {
return false;
return true;
}
- bool contains(AST* child) const {
+ bool contains(const AST* child) const {
if (*this == *child) return true;
FOREACHP(const_iterator, p, this)
if (**p == *child || (*p)->contains(child))
@@ -305,7 +322,8 @@ struct ATuple : public AST {
return false;
}
void constrain(TEnv& tenv, Constraints& c) const throw(Error);
- void lift(CEnv& cenv) throw() { FOREACHP(iterator, t, this) (*t)->lift(cenv); }
+ AST* lift(CEnv& cenv, Code& code) throw();
+ AST* depoly(CEnv& cenv, Code& code) throw();
CVal compile(CEnv& cenv) throw() { return NULL; }
@@ -354,63 +372,18 @@ struct AType : public ATuple {
unsigned id;
};
-/// Type substitution
-struct Subst : public list< pair<const AType*,AType*> > {
- Subst(AType* s=0, AType* t=0) { if (s && t) { assert(s != t); push_back(make_pair(s, t)); } }
- static Subst compose(const Subst& delta, const Subst& gamma);
- void add(const AType* from, AType* to) { push_back(make_pair(from, to)); }
- const_iterator find(const AType* t) const {
- for (const_iterator j = begin(); j != end(); ++j)
- if (*j->first == *t)
- return j;
- return end();
- }
- AType* apply(const AType* in) const {
- if (in->kind == AType::EXPR) {
- AType* out = tup<AType>(in->loc, NULL);
- for (ATuple::const_iterator i = in->begin(); i != in->end(); ++i)
- out->push_back(apply((*i)->as<AType*>()));
- return out;
- } else {
- const_iterator i = find(in);
- if (i != end()) {
- AType* out = i->second->as<AType*>();
- if (out->kind == AType::EXPR && !out->concrete())
- out = apply(out->as<AType*>());
- return out;
- } else {
- return new AType(*in);
- }
- }
- }
-};
-
-inline ostream& operator<<(ostream& out, const Subst& s) {
- for (Subst::const_iterator i = s.begin(); i != s.end(); ++i)
- out << i->first << " => " << i->second << endl;
- return out;
-}
-
/// Fn (first-class function with captured lexical bindings)
struct AFn : public ATuple {
+ AFn(const ATuple* exp) : ATuple(*exp) {}
AFn(Cursor c, AST* ast, va_list args) : ATuple(c, ast, args) {}
bool operator==(const AST& rhs) const { return this == &rhs; }
void constrain(TEnv& tenv, Constraints& c) const throw(Error);
AST* cps(TEnv& tenv, AST* cont) const;
- void lift(CEnv& cenv) throw();
+ AST* lift(CEnv& cenv, Code& code) throw();
+ AST* depoly(CEnv& cenv, Code& code) throw();
CVal compile(CEnv& cenv) throw();
const ATuple* prot() const { return (*(begin() + 1))->to<const ATuple*>(); }
ATuple* prot() { return (*(begin() + 1))->to<ATuple*>(); }
- /// System level implementations of this (polymorphic) fn
- struct Impls : public list< pair<AType*, CFunc> > {
- CFunc find(AType* type) const {
- for (const_iterator f = begin(); f != end(); ++f)
- if (*f->first == *type)
- return f->second;
- return NULL;
- }
- };
- Impls impls;
string name;
};
@@ -420,7 +393,8 @@ struct ACall : public ATuple {
ACall(Cursor c, AST* ast, va_list args) : ATuple(c, ast, args) {}
void constrain(TEnv& tenv, Constraints& c) const throw(Error);
AST* cps(TEnv& tenv, AST* cont) const;
- void lift(CEnv& cenv) throw();
+ AST* lift(CEnv& cenv, Code& code) throw();
+ AST* depoly(CEnv& cenv, Code& code) throw();
CVal compile(CEnv& cenv) throw();
};
@@ -442,7 +416,8 @@ struct ADef : public ACall {
AST* body() { return *(begin() + 2); }
void constrain(TEnv& tenv, Constraints& c) const throw(Error);
AST* cps(TEnv& tenv, AST* cont) const;
- void lift(CEnv& cenv) throw();
+ AST* lift(CEnv& cenv, Code& code) throw();
+ AST* depoly(CEnv& cenv, Code& code) throw();
CVal compile(CEnv& cenv) throw();
};
@@ -452,19 +427,26 @@ struct AIf : public ACall {
AIf(Cursor c, AST* ast, va_list args) : ACall(c, ast, args) {}
void constrain(TEnv& tenv, Constraints& c) const throw(Error);
AST* cps(TEnv& tenv, AST* cont) const;
+ AST* lift(CEnv& cenv, Code& code) throw();
+ AST* depoly(CEnv& cenv, Code& code) throw();
CVal compile(CEnv& cenv) throw();
};
struct ACons : public ACall {
ACons(const ATuple* exp) : ACall(exp) {}
+ ACons(Cursor c, AST* ast, va_list args) : ACall(c, ast, args) {}
void constrain(TEnv& tenv, Constraints& c) const throw(Error);
+ AST* lift(CEnv& cenv, Code& code) throw();
+ AST* depoly(CEnv& cenv, Code& code) throw();
CVal compile(CEnv& cenv) throw();
};
struct ADot : public ACall {
ADot(const ATuple* exp) : ACall(exp) {}
+ ADot(Cursor c, AST* ast, va_list args) : ACall(c, ast, args) {}
void constrain(TEnv& tenv, Constraints& c) const throw(Error);
- void lift(CEnv& cenv) throw();
+ AST* lift(CEnv& cenv, Code& code) throw();
+ AST* depoly(CEnv& cenv, Code& code) throw();
CVal compile(CEnv& cenv) throw();
};
@@ -480,6 +462,8 @@ struct APrimitive : public ACall {
}
void constrain(TEnv& tenv, Constraints& c) const throw(Error);
AST* cps(TEnv& tenv, AST* cont) const;
+ AST* lift(CEnv& cenv, Code& code) throw();
+ AST* depoly(CEnv& cenv, Code& code) throw();
CVal compile(CEnv& cenv) throw();
};
@@ -512,7 +496,7 @@ struct PEnv : private map<const string, ASymbol*> {
map<string, MF>::const_iterator i = macros.find(s);
return (i != macros.end()) ? i->second : NULL;
}
- string gensymstr(const char* s="_") { return (format("%s%d") % s % symID++).str(); }
+ string gensymstr(const char* s="_") { return (format("%s_%d") % s % symID++).str(); }
ASymbol* gensym(const char* s="_") { return sym(gensymstr(s)); }
ASymbol* sym(const string& s, Cursor c=Cursor()) {
const const_iterator i = find(s);
@@ -571,17 +555,62 @@ struct PEnv : private map<const string, ASymbol*> {
***************************************************************************/
/// Type constraint
-struct Constraint : public pair<AType*,AType*> {
- Constraint(AType* a, AType* b, Cursor c) : pair<AType*,AType*>(a, b), loc(c) {}
+struct Constraint : public pair<const AType*,const AType*> {
+ Constraint(const AType* a, const AType* b, Cursor c)
+ : pair<const AType*, const AType*>(a, b), loc(c) {}
Cursor loc;
};
+/// Type substitution
+struct Subst : public list<Constraint> {
+ Subst(const AType* s=0, const AType* t=0) {
+ if (s && t) { assert(s != t); push_back(Constraint(s, t, t->loc)); }
+ }
+ static Subst compose(const Subst& delta, const Subst& gamma);
+ void add(const AType* from, const AType* to) { push_back(Constraint(from, to, Cursor())); }
+ const_iterator find(const AType* t) const {
+ for (const_iterator j = begin(); j != end(); ++j)
+ if (*j->first == *t)
+ return j;
+ return end();
+ }
+ const AType* apply(const AType* in) const {
+ if (in->kind == AType::EXPR) {
+ AType* out = tup<AType>(in->loc, NULL);
+ for (ATuple::const_iterator i = in->begin(); i != in->end(); ++i)
+ out->push_back(const_cast<AType*>(apply((*i)->as<AType*>())));
+ return out;
+ } else {
+ const_iterator i = find(in);
+ if (i != end()) {
+ const AType* out = i->second->as<const AType*>();
+ if (out->kind == AType::EXPR && !out->concrete())
+ out = const_cast<AType*>(apply(out->as<const AType*>()));
+ return out;
+ } else {
+ return new AType(*in);
+ }
+ }
+ }
+};
+
+inline ostream& operator<<(ostream& out, const Subst& s) {
+ for (Subst::const_iterator i = s.begin(); i != s.end(); ++i)
+ out << i->first << " => " << i->second << endl;
+ return out;
+}
+
/// Type constraint set
struct Constraints : public list<Constraint> {
Constraints() : list<Constraint>() {}
+ Constraints(const Subst& subst) : list<Constraint>() {
+ FOREACH(Subst::const_iterator, i, subst) {
+ push_back(Constraint(new AType(*i->first), new AType(*i->second), Cursor()));
+ }
+ }
Constraints(const_iterator begin, const_iterator end) : list<Constraint>(begin, end) {}
- void constrain(TEnv& tenv, const AST* o, AType* t);
- Constraints& replace(AType* s, AType* t);
+ void constrain(TEnv& tenv, const AST* o, const AType* t);
+ Constraints replace(const AType* s, const AType* t);
};
inline ostream& operator<<(ostream& out, const Constraints& c) {
@@ -591,7 +620,7 @@ inline ostream& operator<<(ostream& out, const Constraints& c) {
}
/// Type-Time Environment
-struct TEnv : public Env<const ASymbol*, AType*> {
+struct TEnv : public Env<const ASymbol*, const AType*> {
TEnv(PEnv& p)
: penv(p)
, varID(1)
@@ -600,10 +629,10 @@ struct TEnv : public Env<const ASymbol*, AType*> {
{
Object::pool.addRoot(Fn);
}
- AType* fresh(const ASymbol* sym) {
+ const AType* fresh(const ASymbol* sym) {
return def(sym, new AType(sym->loc, varID++));
}
- AType* var(const AST* ast=0) {
+ const AType* var(const AST* ast=0) {
if (!ast)
return new AType(Cursor(), varID++);
@@ -617,12 +646,12 @@ struct TEnv : public Env<const ASymbol*, AType*> {
return (vars[ast] = new AType(ast->loc, varID++));
}
- AType* named(const string& name) {
+ const AType* named(const string& name) {
return *ref(penv.sym(name));
}
- static Subst buildSubst(AType* fnT, const AType& argsT);
+ static Subst buildSubst(const AType* fnT, const AType& argsT);
- typedef map<const AST*, AType*> Vars;
+ typedef map<const AST*, const AType*> Vars;
Vars vars;
PEnv& penv;
@@ -650,20 +679,22 @@ struct Engine {
const ATuple& argsT,
const vector<string> argNames=vector<string>()) = 0;
- virtual void finishFunction(CEnv& cenv, CFunc f, const AType* retT, CVal ret) = 0;
- virtual void eraseFunction(CEnv& cenv, CFunc f) = 0;
- virtual CFunc compileFunction(CEnv& cenv, AFn* fn, const AType& argsT) = 0;
- virtual CVal compileTup(CEnv& cenv, const AType* t, const vector<CVal>& f) = 0;
- virtual CVal compileDot(CEnv& cenv, CVal tup, int32_t index) = 0;
- virtual CVal compileLiteral(CEnv& cenv, AST* lit) = 0;
- virtual CVal compileCall(CEnv& cenv, CFunc f, const vector<CVal>& args) = 0;
- virtual CVal compilePrimitive(CEnv& cenv, APrimitive* prim) = 0;
- virtual CVal compileIf(CEnv& cenv, AIf* aif) = 0;
- virtual CVal compileGlobal(CEnv& cenv, AType* t, const string& sym, CVal val) = 0;
- virtual CVal getGlobal(CEnv& cenv, CVal val) = 0;
- virtual void writeModule(CEnv& cenv, std::ostream& os) = 0;
-
- virtual const string call(CEnv& cenv, CFunc f, AType* retT) = 0;
+ typedef const vector<CVal> ValVec;
+
+ virtual void finishFunction(CEnv& cenv, CFunc f, CVal ret) = 0;
+ virtual void eraseFunction(CEnv& cenv, CFunc f) = 0;
+ virtual CFunc compileFunction(CEnv& cenv, AFn* fn, const AType* type) = 0;
+ virtual CVal compileTup(CEnv& cenv, const AType* t, ValVec& f) = 0;
+ virtual CVal compileDot(CEnv& cenv, CVal tup, int32_t index) = 0;
+ virtual CVal compileLiteral(CEnv& cenv, AST* lit) = 0;
+ virtual CVal compileCall(CEnv& cenv, CFunc f, const AType* fT, ValVec& args) = 0;
+ virtual CVal compilePrimitive(CEnv& cenv, APrimitive* prim) = 0;
+ virtual CVal compileIf(CEnv& cenv, AIf* aif) = 0;
+ virtual CVal compileGlobal(CEnv& cenv, const AType* t, const string& sym, CVal val) = 0;
+ virtual CVal getGlobal(CEnv& cenv, CVal val) = 0;
+ virtual void writeModule(CEnv& cenv, std::ostream& os) = 0;
+
+ virtual const string call(CEnv& cenv, CFunc f, const AType* retT) = 0;
};
Engine* resp_new_llvm_engine();
@@ -683,14 +714,19 @@ struct CEnv {
void push() { code.push(); tenv.push(); vals.push(); }
void pop() { code.pop(); tenv.pop(); vals.pop(); }
void lock(AST* ast) { Object::pool.addRoot(ast); Object::pool.addRoot(type(ast)); }
- AType* type(AST* ast, const Subst& subst = Subst()) const {
+ const AType* type(AST* ast, const Subst& subst = Subst()) const {
ASymbol* sym = ast->to<ASymbol*>();
- if (sym)
- return *tenv.ref(sym);
- assert(tenv.vars[ast]);
- return tsubst.apply(subst.apply(tenv.vars[ast]))->to<AType*>();
+ if (sym) {
+ const AType** rec = tenv.ref(sym);
+ return rec ? *rec : NULL;
+ }
+ const AType* var = tenv.vars[ast];
+ if (var) {
+ return tsubst.apply(subst.apply(var))->to<const AType*>();
+ }
+ return NULL;
}
- void def(const ASymbol* sym, AST* c, AType* t, CVal v) {
+ void def(const ASymbol* sym, AST* c, const AType* t, CVal v) {
code.def(sym, c);
tenv.def(sym, t);
vals.def(sym, v);
@@ -700,6 +736,14 @@ struct CEnv {
AST** rec = code.ref(sym);
return rec ? *rec : ast;
}
+ void setType(AST* ast, const AType* type) {
+ const AType* tvar = tenv.var();
+ tenv.vars.insert(make_pair(ast, tvar));
+ tsubst.add(tvar, type);
+ }
+ void setTypeSameAs(AST* ast, AST* typedAst) {
+ tenv.vars.insert(make_pair(ast, tenv.vars[typedAst]));
+ }
ostream& out;
ostream& err;
@@ -710,8 +754,39 @@ struct CEnv {
Env<const ASymbol*, AST*> code;
+ typedef map<AFn*, CFunc> Impls;
+ Impls impls;
+
+ CFunc findImpl(AFn* fn, const AType* type) {
+ Impls::iterator i = impls.find(fn);
+ return (i != impls.end()) ? i->second : NULL;
+ }
+
+ void addImpl(AFn* fn, CFunc impl) {
+ impls.insert(make_pair(fn, impl));
+ }
+
map<string,string> args;
+ CFunc currentFn; ///< Currently compiling function
+
+ struct FreeVars : public std::vector<ASymbol*> {
+ FreeVars(AFn* f, const std::string& n) : fn(f), implName(n) {}
+ AFn* const fn;
+ const std::string implName;
+ int32_t index(ASymbol* sym) {
+ const_iterator i = find(begin(), end(), sym);
+ if (i != end()) {
+ return i - begin() + 1;
+ } else {
+ push_back(sym);
+ return size();
+ }
+ }
+ };
+ typedef std::stack<FreeVars> LiftStack;
+ LiftStack liftStack;
+
private:
Engine* _engine;
};
diff --git a/src/unify.cpp b/src/unify.cpp
index 93dd78b..1b25861 100644
--- a/src/unify.cpp
+++ b/src/unify.cpp
@@ -26,7 +26,7 @@
* with a specific set of argument types
*/
Subst
-TEnv::buildSubst(AType* genericT, const AType& argsT)
+TEnv::buildSubst(const AType* genericT, const AType& argsT)
{
Subst subst;
@@ -56,7 +56,7 @@ TEnv::buildSubst(AType* genericT, const AType& argsT)
}
void
-Constraints::constrain(TEnv& tenv, const AST* o, AType* t)
+Constraints::constrain(TEnv& tenv, const AST* o, const AType* t)
{
assert(o);
assert(t);
@@ -64,15 +64,27 @@ Constraints::constrain(TEnv& tenv, const AST* o, AType* t)
push_back(Constraint(tenv.var(o), t, o->loc));
}
-static void
-substitute(ATuple* tup, const AST* from, AST* to)
+template<typename T, typename E>
+static const T*
+substitute(const T* tup, const E* from, const E* to)
{
- if (!tup) return;
- FOREACHP(ATuple::iterator, i, tup)
- if (**i == *from)
- *i = to;
- else if (*i != to)
- substitute((*i)->to<ATuple*>(), from, to);
+ if (!tup) return NULL;
+ T* ret = new T(*tup);
+ typename T::iterator ri = ret->begin();
+ FOREACHP(typename T::const_iterator, i, tup) {
+ if (**i == *from) {
+ *ri++ = const_cast<E*>(to);
+ } else if (static_cast<const E*>(*i) != static_cast<const E*>(to)) {
+ const T* subTup = dynamic_cast<const T*>(*i);
+ if (subTup)
+ *ri++ = const_cast<E*>(substitute(subTup, from, to));
+ else
+ *ri++ = *i;
+ } else {
+ ++ri;
+ }
+ }
+ return ret;
}
/// Compose two substitutions (TAPL 22.1.1)
@@ -92,15 +104,16 @@ Subst::compose(const Subst& delta, const Subst& gamma)
}
/// Replace all occurrences of @a s with @a t
-Constraints&
-Constraints::replace(AType* s, AType* t)
+Constraints
+Constraints::replace(const AType* s, const AType* t)
{
+ Constraints cp(*this);
for (Constraints::iterator c = begin(); c != end();) {
Constraints::iterator next = c; ++next;
if (*c->first == *s) c->first = t;
if (*c->second == *s) c->second = t;
- substitute(c->first, s, t);
- substitute(c->second, s, t);
+ c->first = substitute(c->first, s, t);
+ c->second = substitute(c->second, s, t);
c = next;
}
return *this;
@@ -114,8 +127,8 @@ unify(const Constraints& constraints)
return Subst();
Constraints::const_iterator i = constraints.begin();
- AType* s = i->first;
- AType* t = i->second;
+ const AType* s = i->first;
+ const AType* t = i->second;
Constraints cp(++i, constraints.end());
if (*s == *t) {
@@ -125,8 +138,8 @@ unify(const Constraints& constraints)
} else if (t->kind == AType::VAR && !s->contains(t)) {
return Subst::compose(unify(cp.replace(t, s)), Subst(t, s));
} else if (s->kind == AType::EXPR && t->kind == AType::EXPR) {
- AType::iterator si = s->begin() + 1;
- AType::iterator ti = t->begin() + 1;
+ AType::const_iterator si = s->begin() + 1;
+ AType::const_iterator ti = t->begin() + 1;
for (; si != s->end() && ti != t->end(); ++si, ++ti) {
AType* st = (*si)->as<AType*>();
AType* tt = (*ti)->as<AType*>();
diff --git a/test.sh b/test.sh
index bc54159..e800625 100755
--- a/test.sh
+++ b/test.sh
@@ -13,9 +13,13 @@ run() {
fi
}
-run './test/ack.resp' '8189 : Int'
-run './test/def.resp' '3 : Int'
-run './test/fac.resp' '720 : Int'
-run './test/poly.resp' '#t : Bool'
-run './test/nest.resp' '6 : Int'
-run './test/tup.resp' '5 : Int'
+run './test/ack.resp' '8189 : Int'
+run './test/closure.resp' '6 : Int'
+run './test/def.resp' '4 : Int'
+run './test/deffn.resp' '3 : Int'
+run './test/fac.resp' '720 : Int'
+run './test/inlinefn.resp' '2 : Int'
+run './test/nest.resp' '8 : Int'
+run './test/tup.resp' '5 : Int'
+
+#run './test/poly.resp' '#t : Bool'
diff --git a/test/closure.resp b/test/closure.resp
new file mode 100644
index 0000000..fb5a41d
--- /dev/null
+++ b/test/closure.resp
@@ -0,0 +1,5 @@
+(def (multiplier factor) (fn (x) (* (+ x 0) factor)))
+
+(def doubler (multiplier 2))
+
+(doubler 3)
diff --git a/test/def.resp b/test/def.resp
index dd9a0c8..52605b0 100644
--- a/test/def.resp
+++ b/test/def.resp
@@ -1,7 +1,7 @@
(def foo
(fn (x)
- (def y 2)
- (def z 3)
+ (def y x)
+ (def z (+ x 1))
z))
(foo 3)
diff --git a/test/deffn.resp b/test/deffn.resp
new file mode 100644
index 0000000..c413ecd
--- /dev/null
+++ b/test/deffn.resp
@@ -0,0 +1,3 @@
+(def f (fn (x) (+ x 1)))
+
+(f 2)
diff --git a/test/inlinefn.resp b/test/inlinefn.resp
new file mode 100644
index 0000000..2f055bd
--- /dev/null
+++ b/test/inlinefn.resp
@@ -0,0 +1 @@
+((fn (x) (+ x 1)) 1)
diff --git a/test/nest.resp b/test/nest.resp
index 3085737..c15c453 100644
--- a/test/nest.resp
+++ b/test/nest.resp
@@ -1,6 +1,6 @@
(def (f x)
(def (g y)
(* y 2))
- (g x))
+ (g (+ x 1)))
(f 3)