From 8675beae4f7a8415fc2e88451da95dc068719194 Mon Sep 17 00:00:00 2001 From: David Robillard Date: Tue, 13 Apr 2010 02:28:56 +0000 Subject: Restructure as a source translation based compiler. Implement support for closures (via lambda lifting phase). git-svn-id: http://svn.drobilla.net/resp/resp@254 ad02d1e2-f140-0410-9f75-f8b11f17cedd --- src/c.cpp | 64 +++++++------- src/compile.cpp | 30 +++---- src/constrain.cpp | 33 ++++---- src/lift.cpp | 188 ++++++++++++++++++++++++++++++++++------- src/llvm.cpp | 88 +++++++++----------- src/repl.cpp | 109 ++++++++++++++++-------- src/resp.cpp | 32 ++++--- src/resp.hpp | 245 +++++++++++++++++++++++++++++++++++------------------- src/unify.cpp | 49 +++++++---- 9 files changed, 544 insertions(+), 294 deletions(-) (limited to 'src') diff --git a/src/c.cpp b/src/c.cpp index 7dbb6c2..c456132 100644 --- a/src/c.cpp +++ b/src/c.cpp @@ -42,7 +42,9 @@ static inline Function* llFunc(CFunc f) { return static_cast(f); } static const Type* llType(const AType* t) { - if (t->kind == AType::PRIM) { + if (t == NULL) { + return NULL; + } else if (t->kind == AType::PRIM) { if (t->head()->str() == "Nothing") return new string("void"); if (t->head()->str() == "Bool") return new string("bool"); if (t->head()->str() == "Int") return new string("int"); @@ -65,9 +67,20 @@ llType(const AType* t) } *ret += ")"; + return ret; + } else if (t->kind == AType::EXPR && t->head()->str() == "Tup") { + Type* ret = new Type("struct { void* me; "); + for (AType::const_iterator i = t->begin() + 1; i != t->end(); ++i) { + const Type* lt = llType((*i)->to()); + if (!lt) + return NULL; + ret->append("; "); + ret->append(*lt); + } + ret->append("}*"); return ret; } - return NULL; // non-primitive type + return new Type("void*"); } @@ -116,7 +129,7 @@ struct CEngine : public Engine { return f; } - void finishFunction(CEnv& cenv, CFunc f, const AType* retT, CVal ret) { + void finishFunction(CEnv& cenv, CFunc f, CVal ret) { out += "return " + *(Value*)ret + ";\n}\n\n"; } @@ -124,7 +137,7 @@ struct CEngine : public Engine { cenv.err << "C backend does not support JIT (eraseFunction)" << endl; } - CVal compileCall(CEnv& cenv, CFunc func, const vector& args) { + CVal compileCall(CEnv& cenv, CFunc func, const AType* funcT, const vector& args) { Value* varname = new string(cenv.penv.gensymstr("x")); Function* f = llFunc(func); out += (format("const %s %s = %s(") % f->returnType % *varname % f->name).str(); @@ -134,21 +147,21 @@ struct CEngine : public Engine { return varname; } - CFunc compileFunction(CEnv& cenv, AFn* fn, const AType& argsT); + CFunc compileFunction(CEnv& cenv, AFn* fn, const AType* type); CVal compileTup(CEnv& cenv, const AType* type, const vector& fields); CVal compileDot(CEnv& cenv, CVal tup, int32_t index); CVal compileLiteral(CEnv& cenv, AST* lit); CVal compilePrimitive(CEnv& cenv, APrimitive* prim); CVal compileIf(CEnv& cenv, AIf* aif); - CVal compileGlobal(CEnv& cenv, AType* type, const string& name, CVal val); + CVal compileGlobal(CEnv& cenv, const AType* type, const string& name, CVal val); CVal getGlobal(CEnv& cenv, CVal val); void writeModule(CEnv& cenv, std::ostream& os) { os << out; } - const string call(CEnv& cenv, CFunc f, AType* retT) { + const string call(CEnv& cenv, CFunc f, const AType* retT) { cenv.err << "C backend does not support JIT (call)" << endl; return ""; } @@ -185,28 +198,14 @@ CEngine::compileLiteral(CEnv& cenv, AST* lit) } CFunc -CEngine::compileFunction(CEnv& cenv, AFn* fn, const AType& argsT) +CEngine::compileFunction(CEnv& cenv, AFn* fn, const AType* type) { - CEngine* engine = reinterpret_cast(cenv.engine()); - AType* genericType = cenv.type(fn); - AType* thisType = genericType; - Subst argsSubst; - - // Build and apply substitution to get concrete type for this call - if (!genericType->concrete()) { - argsSubst = cenv.tenv.buildSubst(genericType, argsT); - thisType = argsSubst.apply(genericType)->as(); - } - - THROW_IF(!thisType->concrete(), fn->loc, - string("call has non-concrete type %1%\n") + thisType->str()); - - Object::pool.addRoot(thisType); - CFunc f = fn->impls.find(thisType); - if (f) - return f; + assert(type->concrete()); - ATuple* protT = thisType->prot(); + CEngine* engine = reinterpret_cast(cenv.engine()); + const AType* argsT = type->prot()->as(); + const AType* retT = type->last()->as(); + Subst argsSubst = cenv.tenv.buildSubst(type, *argsT); vector argNames; for (ATuple::const_iterator i = fn->prot()->begin(); i != fn->prot()->end(); ++i) @@ -214,9 +213,7 @@ CEngine::compileFunction(CEnv& cenv, AFn* fn, const AType& argsT) // Write function declaration const string name = (fn->name == "") ? cenv.penv.gensymstr("_fn") : fn->name; - f = llFunc(cenv.engine()->startFunction(cenv, name, - thisType->last()->to(), - *protT, argNames)); + Function* f = llFunc(cenv.engine()->startFunction(cenv, name, retT, *argsT, argNames)); cenv.push(); Subst oldSubst = cenv.tsubst; @@ -225,7 +222,7 @@ CEngine::compileFunction(CEnv& cenv, AFn* fn, const AType& argsT) // Bind argument values in CEnv vector args; AFn::const_iterator p = fn->prot()->begin(); - ATuple::const_iterator pT = protT->begin(); + ATuple::const_iterator pT = argsT->begin(); for (; p != fn->prot()->end(); ++p, ++pT) { AType* t = (*pT)->as(); const Type* lt = llType(t); @@ -235,11 +232,10 @@ CEngine::compileFunction(CEnv& cenv, AFn* fn, const AType& argsT) // Write function body try { - fn->impls.push_back(make_pair(thisType, f)); CVal retVal = NULL; for (AFn::iterator i = fn->begin() + 2; i != fn->end(); ++i) retVal = (*i)->compile(cenv); - cenv.engine()->finishFunction(cenv, f, cenv.type(fn->last()), retVal); + cenv.engine()->finishFunction(cenv, f, retVal); } catch (Error& e) { cenv.pop(); throw e; @@ -314,7 +310,7 @@ CEngine::compilePrimitive(CEnv& cenv, APrimitive* prim) } CVal -CEngine::compileGlobal(CEnv& cenv, AType* type, const string& name, CVal val) +CEngine::compileGlobal(CEnv& cenv, const AType* type, const string& name, CVal val) { return NULL; } diff --git a/src/compile.cpp b/src/compile.cpp index 4f24994..ee92dbd 100644 --- a/src/compile.cpp +++ b/src/compile.cpp @@ -43,33 +43,29 @@ ASymbol::compile(CEnv& cenv) throw() CVal AFn::compile(CEnv& cenv) throw() { - return impls.find(cenv.type(this)); + const AType* type = cenv.type(this); + CFunc f = cenv.findImpl(this, type); + if (!f) { + f = cenv.engine()->compileFunction(cenv, this, type); + cenv.vals.def(cenv.penv.sym(name), f); + cenv.addImpl(this, f); + } + return f; } CVal ACall::compile(CEnv& cenv) throw() { - AFn* c = cenv.resolve(head())->to(); - - if (!c) return NULL; // Primitive - - AType protT(loc); - for (const_iterator i = begin() + 1; i != end(); ++i) - protT.push_back(cenv.type(*i)); - - AType fnT(loc); - fnT.push_back(cenv.tenv.Fn); - fnT.push_back(&protT); - fnT.push_back(cenv.type(this)); + CFunc f = (*begin())->compile(cenv); - CFunc f = c->impls.find(&fnT); - THROW_IF(!f, loc, (format("callee failed to compile for type %1%") % fnT.str()).str()); + if (!f) + f = cenv.currentFn; // Recursive call (callee defined as a stub) vector args; for (const_iterator e = begin() + 1; e != end(); ++e) args.push_back((*e)->compile(cenv)); - return cenv.engine()->compileCall(cenv, f, args); + return cenv.engine()->compileCall(cenv, f, cenv.type(head()), args); } CVal @@ -83,7 +79,7 @@ ADef::compile(CEnv& cenv) throw() cenv.lock(this); } cenv.vals.def(sym(), val); - return val; + return NULL; } CVal diff --git a/src/constrain.cpp b/src/constrain.cpp index c815bac..024dce4 100644 --- a/src/constrain.cpp +++ b/src/constrain.cpp @@ -42,7 +42,7 @@ AString::constrain(TEnv& tenv, Constraints& c) const throw(Error) void ASymbol::constrain(TEnv& tenv, Constraints& c) const throw(Error) { - AType** ref = tenv.ref(this); + const AType** ref = tenv.ref(this); THROW_IF(!ref, loc, (format("undefined symbol `%1%'") % cppstr).str()); c.constrain(tenv, this, *ref); } @@ -53,7 +53,7 @@ ATuple::constrain(TEnv& tenv, Constraints& c) const throw(Error) AType* t = tup(loc, NULL); FOREACHP(const_iterator, p, this) { (*p)->constrain(tenv, c); - t->push_back(tenv.var(*p)); + t->push_back(const_cast(tenv.var(*p))); } c.constrain(tenv, this, t); } @@ -72,9 +72,9 @@ AFn::constrain(TEnv& tenv, Constraints& c) const throw(Error) THROW_IF(defs.count(sym) != 0, sym->loc, (format("duplicate parameter `%1%'") % sym->str()).str()); defs.insert(sym); - AType* tvar = tenv.fresh(sym); + const AType* tvar = tenv.fresh(sym); frame.push_back(make_pair(sym, tvar)); - protT->push_back(tvar); + protT->push_back(const_cast(tvar)); } const_iterator i = begin() + 1; @@ -95,13 +95,12 @@ AFn::constrain(TEnv& tenv, Constraints& c) const throw(Error) tenv.push(frame); - c.constrain(tenv, this, tenv.var()); AST* exp = NULL; for (i = begin() + 2; i != end(); ++i) (exp = *i)->constrain(tenv, c); - AType* bodyT = tenv.var(exp); - AType* fnT = tup(loc, tenv.Fn, protT, bodyT, 0); + const AType* bodyT = tenv.var(exp); + const AType* fnT = tup(loc, tenv.Fn, protT, bodyT, 0); Object::pool.addRoot(fnT); tenv.pop(); @@ -127,10 +126,10 @@ ACall::constrain(TEnv& tenv, Constraints& c) const throw(Error) (format("expected %1% arguments, got %2%") % numArgs % (size() - 1)).str()); } - AType* retT = tenv.var(); - AType* argsT = tup(loc, 0); + const AType* retT = tenv.var(); + AType* argsT = tup(loc, 0); for (const_iterator i = begin() + 1; i != end(); ++i) - argsT->push_back(tenv.var(*i)); + argsT->push_back(const_cast(tenv.var(*i))); c.constrain(tenv, head(), tup(head()->loc, tenv.Fn, argsT, retT, 0)); c.constrain(tenv, this, retT); @@ -143,7 +142,7 @@ ADef::constrain(TEnv& tenv, Constraints& c) const throw(Error) const ASymbol* sym = this->sym(); THROW_IF(!sym, loc, "`def' has no symbol") - AType* tvar = tenv.var(body()); + const AType* tvar = tenv.var(body()); tenv.def(sym, tvar); body()->constrain(tenv, c); c.constrain(tenv, sym, tvar); @@ -157,7 +156,7 @@ AIf::constrain(TEnv& tenv, Constraints& c) const throw(Error) THROW_IF(size() % 2 != 0, loc, "`if' missing final else clause") for (const_iterator i = begin() + 1; i != end(); ++i) (*i)->constrain(tenv, c); - AType* retT = tenv.var(this); + const AType* retT = tenv.var(this); for (const_iterator i = begin() + 1; true; ++i) { const_iterator next = i; ++next; @@ -178,7 +177,7 @@ ACons::constrain(TEnv& tenv, Constraints& c) const throw(Error) AType* type = tup(loc, tenv.Tup, 0); for (const_iterator i = begin() + 1; i != end(); ++i) { (*i)->constrain(tenv, c); - type->push_back(tenv.var(*i)); + type->push_back(const_cast(tenv.var(*i))); } c.constrain(tenv, this, type); @@ -194,13 +193,13 @@ ADot::constrain(TEnv& tenv, Constraints& c) const throw(Error) THROW_IF(!idx, loc, "the 2nd argument to `.' must be a literal integer"); obj->constrain(tenv, c); - AType* retT = tenv.var(this); + const AType* retT = tenv.var(this); c.constrain(tenv, this, retT); AType* objT = tup(loc, tenv.Tup, 0); for (int i = 0; i < idx->val; ++i) - objT->push_back(tenv.var()); - objT->push_back(retT); + objT->push_back(const_cast(tenv.var())); + objT->push_back(const_cast(retT)); objT->push_back(new AType(obj->loc, AType::DOTS)); c.constrain(tenv, obj, objT); } @@ -228,7 +227,7 @@ APrimitive::constrain(TEnv& tenv, Constraints& c) const throw(Error) i = begin(); - AType* var = NULL; + const AType* var = NULL; switch (type) { case ARITHMETIC: if (size() < 3) diff --git a/src/lift.cpp b/src/lift.cpp index 94f70e0..6a53165 100644 --- a/src/lift.cpp +++ b/src/lift.cpp @@ -17,64 +17,194 @@ /** @file * @brief Lift functions (compilation pass 1) + * After this pass: + * - All function definitions are top-level + * - All references to functions are replaced with references to + * a closure (a tuple with the function and necessary context) */ #include "resp.hpp" using namespace std; -void -AFn::lift(CEnv& cenv) throw() +AST* +ASymbol::lift(CEnv& cenv, Code& code) throw() { + if (!cenv.liftStack.empty() && cppstr == cenv.liftStack.top().fn->name) { + return cenv.penv.sym("me"); // Reference to innermost function + } else if (!cenv.penv.handler(true, cppstr) + && !cenv.penv.handler(false, cppstr) + && !cenv.code.innermost(this)) { + + const int32_t index = cenv.liftStack.top().index(this); + + // Replace symbol with code to access free variable from closure + return tup(loc, cenv.penv.sym("."), + cenv.penv.sym("me"), + new ALiteral(index, Cursor()), + NULL); + } else { + return this; + } +} + +AST* +ATuple::lift(CEnv& cenv, Code& code) throw() +{ + ATuple* ret = new ATuple(*this); + iterator ri = ret->begin(); + FOREACHP(const_iterator, t, this) + *ri++ = (*t)->lift(cenv, code); + cenv.setTypeSameAs(ret, this); + return ret; +} + +AST* +AFn::lift(CEnv& cenv, Code& code) throw() +{ + AFn* impl = new AFn(this); + const string nameBase = cenv.penv.gensymstr(((name != "") ? name : "fn").c_str()); + impl->name = "_" + nameBase; + + cenv.liftStack.push(CEnv::FreeVars(this, impl->name)); + // Create a new stub environment frame for parameters cenv.push(); + AType::const_iterator tp = cenv.type(this)->prot()->begin(); for (const_iterator p = prot()->begin(); p != prot()->end(); ++p) - cenv.def((*p)->as(), *p, NULL, NULL); + cenv.def((*p)->as(), *p, (*tp++)->as(), NULL); + + /* Add closure parameter with dummy name (undefined symbol). + * The name of this parameter will be changed to the name of this + * function after lifting the body (so recursive references correctly + * refer to this function by the closure parameter). + */ + impl->prot()->push_front(cenv.penv.sym("_")); + // Lift body - for (iterator i = begin() + 2; i != end(); ++i) - (*i)->lift(cenv); + const AType* implRetT = NULL; + iterator ci = impl->begin() + 2; + for (const_iterator i = begin() + 2; i != end(); ++i, ++ci) { + *ci = (*i)->lift(cenv, code); + implRetT = cenv.type(*ci); + } cenv.pop(); - AType* type = cenv.type(this); - if (impls.find(type) || !type->concrete()) - return; + // Set name of closure parameter to actual name of this function + *impl->prot()->begin() = cenv.penv.sym("me"); - AType* protT = type->prot()->as(); - cenv.engine()->compileFunction(cenv, this, *protT); -} + // Create definition for implementation fn + ASymbol* implName = cenv.penv.sym(impl->name); + ADef* def = tup(loc, cenv.penv.sym("def"), implName, impl, NULL); + code.push_back(def); -void -ACall::lift(CEnv& cenv) throw() -{ - AFn* c = cenv.resolve(head())->to(); - AType argsT(loc); + AType* implT = new AType(*cenv.type(this)); // Type of the implementation function + AType* tupT = tup(loc, cenv.tenv.Tup, cenv.tenv.var(), NULL); + AType* consT = tup(loc, cenv.tenv.Tup, implT, NULL); + ACons* cons = tup(loc, cenv.penv.sym("cons"), implName, NULL); // Closure - // Lift arguments and build arguments type - for (iterator i = begin() + 1; i != end(); ++i) { - (*i)->lift(cenv); - argsT.push_back(cenv.type(*i)); + const CEnv::FreeVars& freeVars = cenv.liftStack.top(); + for (CEnv::FreeVars::const_iterator i = freeVars.begin(); i != freeVars.end(); ++i) { + cons->push_back(*i); + tupT->push_back(const_cast(cenv.type(*i))); + consT->push_back(const_cast(cenv.type(*i))); } + cenv.liftStack.pop(); - // Lift callee (if it's not a primitive) - if (c) - cenv.engine()->compileFunction(cenv, c, argsT); + implT->prot()->push_front(tupT); + *(implT->begin() + 2) = const_cast(implRetT); + + cenv.setType(impl, implT); + cenv.setType(cons, consT); + + cenv.def(implName, impl, implT, NULL); + if (name != "") + cenv.def(cenv.penv.sym(name), this, consT, NULL); + + return cons; } -void -ADot::lift(CEnv& cenv) throw() +AST* +ACall::lift(CEnv& cenv, Code& code) throw() { - (*(begin() + 1))->lift(cenv); + ACall* copy = new ACall(this); + ATuple::iterator ri = copy->begin(); + + // Lift all children (callee and arguments, recursively) + for (const_iterator i = begin(); i != end(); ++i) + *ri++ = (*i)->lift(cenv, code); + + ASymbol* sym = head()->to(); + if (sym && !cenv.liftStack.empty() && sym->cppstr == cenv.liftStack.top().fn->name) { + /* Recursive call to innermost function, call implementation directly, + * reusing the current "me" closure parameter (no cons or .). + */ + copy->push_front(cenv.penv.sym(cenv.liftStack.top().implName)); + } else if (head()->to()) { + /* Special case: ((fn ...) ...) + * Lifting (fn ...) yields: (cons _impl ...). + * We don't want ((cons _impl ...) (cons _impl ...) ...), + * so call the implementation function (_impl) directly: + * (_impl (cons _impl ...) ...) + */ + ACons* closure = (*copy->begin())->as(); + ASymbol* implSym = (*(closure->begin() + 1))->as(); + const AType* implT = cenv.type(cenv.resolve(implSym)); + copy->push_front(implSym); + cenv.setType(copy, (*(implT->begin() + 2))->as()); + } else { + // Call to a closure, prepend code to access implementation function + ADot* getFn = tup(loc, cenv.penv.sym("."), + copy->head(), + new ALiteral(0, Cursor()), NULL); + const AType* calleeT = cenv.type(copy->head()); + assert(**calleeT->begin() == *cenv.tenv.Tup); + const AType* implT = (*(calleeT->begin() + 1))->as(); + copy->push_front(getFn); + cenv.setType(getFn, implT); + cenv.setType(copy, (*(implT->begin() + 2))->as()); + } + + return copy; } -void -ADef::lift(CEnv& cenv) throw() +AST* +ADef::lift(CEnv& cenv, Code& code) throw() { // Define stub first for recursion cenv.def(sym(), body(), cenv.type(body()), NULL); AFn* c = body()->to(); if (c) c->name = sym()->str(); - body()->lift(cenv); + + ADef* copy = new ADef(ATuple::lift(cenv, code)->as()); + + if (copy->sym() == copy->body()) + return NULL; // Definition created by AFn::lift when body was lifted + + cenv.def(copy->sym(), copy->body(), cenv.type(copy->body()), NULL); + cenv.setTypeSameAs(copy, this); + return copy; +} + +template +AST* +lift_builtin_call(CEnv& cenv, T* call, Code& code) throw() +{ + ATuple* copy = new T(call); + ATuple::iterator ri = copy->begin() + 1; + + // Lift all arguments + for (typename T::const_iterator i = call->begin() + 1; i != call->end(); ++i) + *ri++ = (*i)->lift(cenv, code); + + cenv.setTypeSameAs(copy, call); + return copy; } + +AST* AIf::lift(CEnv& cenv, Code& code) throw() { return lift_builtin_call(cenv, this, code); } +AST* ACons::lift(CEnv& cenv, Code& code) throw() { return lift_builtin_call(cenv, this, code); } +AST* ADot::lift(CEnv& cenv, Code& code) throw() { return lift_builtin_call(cenv, this, code); } +AST* APrimitive::lift(CEnv& cenv, Code& code) throw() { return lift_builtin_call(cenv, this, code); } diff --git a/src/llvm.cpp b/src/llvm.cpp index 5b6d69f..6243919 100644 --- a/src/llvm.cpp +++ b/src/llvm.cpp @@ -123,7 +123,7 @@ struct LLVMEngine : public Engine { return PointerType::get(StructType::get(context, ctypes, false), 0); } - return NULL; // non-primitive type + return PointerType::get(Type::getVoidTy(context), NULL); } CFunc startFunction(CEnv& cenv, @@ -161,13 +161,12 @@ struct LLVMEngine : public Engine { return f; } - void finishFunction(CEnv& cenv, CFunc f, const AType* retT, CVal ret) { - if (retT->concrete()) - builder.CreateRet(llVal(ret)); - else - builder.CreateRetVoid(); - - verifyFunction(*static_cast(f)); + void finishFunction(CEnv& cenv, CFunc f, CVal ret) { + builder.CreateRet(llVal(ret)); + if (verifyFunction(*static_cast(f), llvm::PrintMessageAction)) { + module->dump(); + throw Error(Cursor(), "Broken module"); + } if (cenv.args.find("-g") == cenv.args.end()) opt->run(*static_cast(f)); } @@ -177,19 +176,23 @@ struct LLVMEngine : public Engine { llFunc(f)->eraseFromParent(); } - CVal compileCall(CEnv& cenv, CFunc f, const vector& args) { - const vector& llArgs = *reinterpret_cast*>(&args); + CVal compileCall(CEnv& cenv, CFunc f, const AType* funcT, const vector& args) { + vector llArgs(*reinterpret_cast*>(&args)); + Value* closure = builder.CreateBitCast(llArgs[0], + llType(funcT->prot()->head()->as()), + cenv.penv.gensymstr("you")); + llArgs[0] = closure; return builder.CreateCall(llFunc(f), llArgs.begin(), llArgs.end()); } - CFunc compileFunction(CEnv& cenv, AFn* fn, const AType& argsT); + CFunc compileFunction(CEnv& cenv, AFn* fn, const AType* type); CVal compileTup(CEnv& cenv, const AType* type, const vector& fields); CVal compileDot(CEnv& cenv, CVal tup, int32_t index); CVal compileLiteral(CEnv& cenv, AST* lit); CVal compilePrimitive(CEnv& cenv, APrimitive* prim); CVal compileIf(CEnv& cenv, AIf* aif); - CVal compileGlobal(CEnv& cenv, AType* type, const string& name, CVal val); + CVal compileGlobal(CEnv& cenv, const AType* type, const string& name, CVal val); CVal getGlobal(CEnv& cenv, CVal val); void writeModule(CEnv& cenv, std::ostream& os) { @@ -197,7 +200,7 @@ struct LLVMEngine : public Engine { module->print(os, &writer); } - const string call(CEnv& cenv, CFunc f, AType* retT) { + const string call(CEnv& cenv, CFunc f, const AType* retT) { void* fp = engine->getPointerToFunction(llFunc(f)); const Type* t = llType(retT); THROW_IF(!fp, Cursor(), "unable to get function pointer"); @@ -248,8 +251,9 @@ LLVMEngine::compileTup(CEnv& cenv, const AType* type, const vector& fields { // Find size of memory required size_t s = 0; + assert(type->begin() != type->end()); for (AType::const_iterator i = type->begin() + 1; i != type->end(); ++i) - s += llType((*i)->as())->getPrimitiveSizeInBits(); + s += engine->getTargetData()->getTypeSizeInBits(llType((*i)->as())); // Allocate struct Value* structSize = ConstantInt::get(Type::getInt32Ty(context), bitsToBytes(s)); @@ -259,8 +263,8 @@ LLVMEngine::compileTup(CEnv& cenv, const AType* type, const vector& fields // Set struct fields size_t i = 0; for (vector::const_iterator f = fields.begin(); f != fields.end(); ++f, ++i) { - Value* v = builder.CreateStructGEP(structPtr, i, (format("tup%1%") % i).str().c_str()); - builder.CreateStore(llVal(*f), v); + builder.CreateStore(llVal(*f), + builder.CreateStructGEP(structPtr, i, (format("tup%1%") % i).str().c_str())); } return structPtr; @@ -292,28 +296,13 @@ LLVMEngine::compileLiteral(CEnv& cenv, AST* lit) } CFunc -LLVMEngine::compileFunction(CEnv& cenv, AFn* fn, const AType& argsT) +LLVMEngine::compileFunction(CEnv& cenv, AFn* fn, const AType* type) { - LLVMEngine* engine = reinterpret_cast(cenv.engine()); - AType* genericType = cenv.type(cenv.resolve(fn)); - AType* thisType = genericType; - Subst argsSubst; - - // Build and apply substitution to get concrete type for this call - if (!genericType->concrete()) { - argsSubst = cenv.tenv.buildSubst(genericType, argsT); - thisType = argsSubst.apply(genericType)->as(); - } + assert(type->concrete()); - THROW_IF(!thisType->concrete(), fn->loc, - string("call has non-concrete type %1%\n") + thisType->str()); - - Object::pool.addRoot(thisType); - Function* f = (Function*)fn->impls.find(thisType); - if (f) - return f; - - ATuple* protT = thisType->prot(); + LLVMEngine* engine = reinterpret_cast(cenv.engine()); + const AType* argsT = type->prot()->as(); + const AType* retT = type->last()->as(); vector argNames; for (ATuple::const_iterator i = fn->prot()->begin(); i != fn->prot()->end(); ++i) @@ -321,39 +310,40 @@ LLVMEngine::compileFunction(CEnv& cenv, AFn* fn, const AType& argsT) // Write function declaration const string name = (fn->name == "") ? cenv.penv.gensymstr("_fn") : fn->name; - f = llFunc(cenv.engine()->startFunction(cenv, name, - thisType->last()->to(), - *protT, argNames)); + Function* f = llFunc(cenv.engine()->startFunction(cenv, name, retT, *argsT, argNames)); cenv.push(); - Subst oldSubst = cenv.tsubst; - cenv.tsubst = Subst::compose(cenv.tsubst, argsSubst); // Bind argument values in CEnv vector args; AFn::const_iterator p = fn->prot()->begin(); - ATuple::const_iterator pT = protT->begin(); + ATuple::const_iterator pT = argsT->begin(); + assert(fn->prot()->size() == argsT->size()); + assert(fn->prot()->size() == f->num_args()); for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p, ++pT) { - AType* t = (*pT)->as(); - const Type* lt = llType(t); + const AType* t = (*pT)->as(); + const Type* lt = llType(t); THROW_IF(!lt, fn->loc, "untyped parameter\n"); cenv.def((*p)->as(), *p, t, &*a); } + assert(!cenv.currentFn); + // Write function body try { - fn->impls.push_back(make_pair(thisType, f)); + cenv.currentFn = f; CVal retVal = NULL; for (AFn::iterator i = fn->begin() + 2; i != fn->end(); ++i) retVal = (*i)->compile(cenv); - cenv.engine()->finishFunction(cenv, f, cenv.type(fn->last()), retVal); + cenv.engine()->finishFunction(cenv, f, retVal); } catch (Error& e) { f->eraseFromParent(); // Error reading body, remove function cenv.pop(); + cenv.currentFn = NULL; throw e; } - cenv.tsubst = oldSubst; cenv.pop(); + cenv.currentFn = NULL; return f; } @@ -456,12 +446,12 @@ LLVMEngine::compilePrimitive(CEnv& cenv, APrimitive* prim) } CVal -LLVMEngine::compileGlobal(CEnv& cenv, AType* type, const string& name, CVal val) +LLVMEngine::compileGlobal(CEnv& cenv, const AType* type, const string& sym, CVal val) { LLVMEngine* engine = reinterpret_cast(cenv.engine()); Constant* init = Constant::getNullValue(llType(type)); GlobalVariable* global = new GlobalVariable(*module, llType(type), false, - GlobalValue::ExternalLinkage, Constant::getNullValue(llType(type)), name); + GlobalValue::ExternalLinkage, Constant::getNullValue(llType(type)), sym); engine->builder.CreateStore(llVal(val), global); return global; diff --git a/src/repl.cpp b/src/repl.cpp index 9fa5781..92fb621 100644 --- a/src/repl.cpp +++ b/src/repl.cpp @@ -27,17 +27,18 @@ using namespace std; static bool -readParseTypeLift(CEnv& cenv, Cursor& cursor, istream& is, AST*& exp, AST*& ast, AType*& type) +readParseType(CEnv& cenv, Cursor& cursor, istream& is, AST*& exp, AST*& ast) { exp = readExpression(cursor, is); if (exp->to() && exp->to()->empty()) return false; ast = cenv.penv.parse(exp); // Parse input - Constraints c; + + Constraints c(cenv.tsubst); ast->constrain(cenv.tenv, c); // Constrain types - cenv.tsubst = Subst::compose(cenv.tsubst, unify(c)); // Solve type constraints + cenv.tsubst = unify(c); // Solve type constraints // Add types in type substition as GC roots for (Subst::iterator i = cenv.tsubst.begin(); i != cenv.tsubst.end(); ++i) { @@ -45,16 +46,11 @@ readParseTypeLift(CEnv& cenv, Cursor& cursor, istream& is, AST*& exp, AST*& ast, Object::pool.addRoot(i->second); } - type = cenv.type(ast); - THROW_IF(!type, cursor, "call to untyped body"); - - ast->lift(cenv); // Lift functions - return true; } static void -callPrintCollect(CEnv& cenv, CFunc f, AST* result, AType* resultT, bool execute) +callPrintCollect(CEnv& cenv, CFunc f, AST* result, const AType* resultT, bool execute) { if (execute) cenv.out << cenv.engine()->call(cenv, f, resultT); @@ -64,43 +60,86 @@ callPrintCollect(CEnv& cenv, CFunc f, AST* result, AType* resultT, bool execute) cenv.out << " : " << resultT << endl; Object::pool.collect(Object::pool.roots()); - - if (cenv.args.find("-d") != cenv.args.end()) - cenv.engine()->writeModule(cenv, cenv.out); } /// Compile and evaluate code from @a is int eval(CEnv& cenv, const string& name, istream& is, bool execute) { - AST* exp = NULL; - AST* ast = NULL; - AType* type = NULL; - list< pair > exprs; + AST* exp = NULL; + AST* ast = NULL; + list parsed; Cursor cursor(name); try { - while (readParseTypeLift(cenv, cursor, is, exp, ast, type)) - exprs.push_back(make_pair(exp, ast)); + while (readParseType(cenv, cursor, is, exp, ast)) + parsed.push_back(ast); - //for (list< pair >::const_iterator i = exprs.begin(); i != exprs.end(); ++i) + //for (list< pair >::const_iterator i = parsed.begin(); i != parsed.end(); ++i) // pprint(cout, i->second->cps(cenv.tenv, cenv.penv.sym("cont"))); CVal val = NULL; CFunc f = NULL; - if (type->concrete()) { - // Create function for top-level of program - f = cenv.engine()->startFunction(cenv, "main", type, ATuple(cursor)); - // Compile all expressions into it - for (list< pair >::const_iterator i = exprs.begin(); i != exprs.end(); ++i) - val = i->second->compile(cenv); + /* + // De-poly all expressions + Code concrete; + for (list::iterator i = parsed.begin(); i != parsed.end(); ++i) { + AST* c = (*i)->depoly(cenv, concrete); + if (c) + concrete.push_back(c); + } - // Finish compilation - cenv.engine()->finishFunction(cenv, f, type, val); + cout << endl << ";;;; CONCRETE {" << endl << endl; + for (Code::iterator i = concrete.begin(); i != concrete.end(); ++i) + cout << *i << endl << endl; + cout << ";;;; } CONCRETE" << endl << endl;*/ + + // Lift all expressions + Code lifted; + for (list::iterator i = parsed.begin(); i != parsed.end(); ++i) { + AST* l = (*i)->lift(cenv, lifted); + if (l) + lifted.push_back(l); + } + + if (cenv.args.find("-d") != cenv.args.end()) { + cout << endl << ";;;; LIFTED {" << endl << endl; + for (Code::iterator i = lifted.begin(); i != lifted.end(); ++i) { + cout << *i << endl; + ADef* def = (*i)->to(); + if (def) + std::cout << " :: " << cenv.type(def->body()) << std::endl; + cout << endl; + } + cout << ";;;; } LIFTED" << endl << endl; } + // Compile top-level (lifted) functions + Code exprs; + for (Code::iterator i = lifted.begin(); i != lifted.end(); ++i) { + ADef* def = (*i)->to(); + if (def && (*(def->begin() + 2))->to()) { + val = def->compile(cenv); + } else { + exprs.push_back(*i); + } + } + + const AType* type = cenv.type(exprs.back()); + + // Create function for top-level of program + f = cenv.engine()->startFunction(cenv, "main", type, ATuple(cursor)); + + // Compile expressions (other than function definitions) into it + for (list::const_iterator i = exprs.begin(); i != exprs.end(); ++i) + val = (*i)->compile(cenv); + + // Finish compilation + cenv.engine()->finishFunction(cenv, f, val); + // Call and print ast - callPrintCollect(cenv, f, ast, type, execute); + if (cenv.args.find("-S") == cenv.args.end()) + callPrintCollect(cenv, f, ast, type, execute); } catch (Error& e) { cenv.err << e.what() << endl; @@ -113,9 +152,8 @@ eval(CEnv& cenv, const string& name, istream& is, bool execute) int repl(CEnv& cenv) { - AST* exp = NULL; - AST* ast = NULL; - AType* type = NULL; + AST* exp = NULL; + AST* ast = NULL; const string replFnName = cenv.penv.gensymstr("_repl"); while (1) { cenv.out << "() "; @@ -123,15 +161,20 @@ repl(CEnv& cenv) Cursor cursor("(stdin)"); try { - if (!readParseTypeLift(cenv, cursor, std::cin, exp, ast, type)) + if (!readParseType(cenv, cursor, std::cin, exp, ast)) break; + Code lifted; + ast = ast->lift(cenv, lifted); + const AType* type = cenv.type(ast); CFunc f = NULL; try { // Create function for this repl loop f = cenv.engine()->startFunction(cenv, replFnName, type, ATuple(cursor)); - cenv.engine()->finishFunction(cenv, f, type, ast->compile(cenv)); + cenv.engine()->finishFunction(cenv, f, ast->compile(cenv)); callPrintCollect(cenv, f, ast, type, true); + if (cenv.args.find("-d") != cenv.args.end()) + cenv.engine()->writeModule(cenv, cenv.out); } catch (Error& e) { cenv.out << e.msg << endl; cenv.engine()->eraseFunction(cenv, f); diff --git a/src/resp.cpp b/src/resp.cpp index 003a76c..89b77ef 100644 --- a/src/resp.cpp +++ b/src/resp.cpp @@ -35,14 +35,15 @@ print_usage(char* name, bool error) os << "Usage: " << name << " [OPTION]... [FILE]..." << endl; os << "Evaluate and/or compile Resp code" << endl; os << endl; - os << " -h Display this help and exit" << endl; - os << " -r Enter REPL after evaluating files" << endl; - os << " -p Pretty-print input only" << endl; - os << " -b BACKEND Backend (llvm or c)" << endl; - os << " -g Debug (disable optimisation)" << endl; - os << " -d Dump assembly output" << endl; - os << " -e EXPRESSION Evaluate EXPRESSION" << endl; - os << " -o FILE Compile output to FILE (don't run)" << endl; + os << " -h Display this help and exit" << endl; + os << " -r Enter REPL after evaluating files" << endl; + os << " -p Pretty-print input only" << endl; + os << " -b BACKEND Use backend (llvm or c)" << endl; + os << " -g Debug (disable optimisation)" << endl; + os << " -d Dump generated code during compilation" << endl; + os << " -S Stop after compilation (output assembler)" << endl; + os << " -e EXPRESSION Evaluate EXPRESSION" << endl; + os << " -o FILE Compile output to FILE (don't run)" << endl; return error ? 1 : 0; } @@ -60,7 +61,8 @@ main(int argc, char** argv) } else if (!strncmp(argv[i], "-r", 3) || !strncmp(argv[i], "-p", 3) || !strncmp(argv[i], "-g", 3) - || !strncmp(argv[i], "-d", 3)) { + || !strncmp(argv[i], "-d", 3) + || !strncmp(argv[i], "-S", 3)) { args.insert(make_pair(argv[i], "")); } else if (i == argc-1 || argv[i+1][0] == '-') { return print_usage(argv[0], true); @@ -131,15 +133,21 @@ main(int argc, char** argv) if (args.find("-r") != args.end() || (files.empty() && args.find("-e") == args.end())) ret = repl(*cenv); - if (output != "") { - ofstream os(output.c_str()); + if (cenv->args.find("-S") != cenv->args.end() || cenv->args.find("-d") != cenv->args.end()) { + ofstream fs; + if (output != "") + fs.open(output.c_str()); + + ostream& os = (output != "") ? fs : cout; + if (os.good()) { cenv->engine()->writeModule(*cenv, os); } else { cerr << argv[0] << ": " << a->second << ": " << strerror(errno) << endl; ++ret; } - os.close(); + if (output != "") + fs.close(); } delete cenv; diff --git a/src/resp.hpp b/src/resp.hpp index ade7257..f172548 100644 --- a/src/resp.hpp +++ b/src/resp.hpp @@ -30,6 +30,7 @@ #include #include #include +#include #include #include #include @@ -92,6 +93,12 @@ struct Env : public list< vector< pair > > { return true; return false; } + bool innermost(const K& key) const { + for (typename Frame::const_iterator b = this->front().begin(); b != this->front().end(); ++b) + if (b->first == key) + return true; + return false; + } }; template @@ -187,6 +194,8 @@ struct CEnv; ///< Compile-Time Environment struct AST; extern ostream& operator<<(ostream& out, const AST* ast); +typedef list Code; + /// Base class for all AST nodes struct AST : public Object { AST(Cursor c=Cursor()) : loc(c) {} @@ -196,7 +205,8 @@ struct AST : public Object { virtual bool contains(const AST* child) const { return false; } virtual void constrain(TEnv& tenv, Constraints& c) const throw(Error) {} virtual AST* cps(TEnv& tenv, AST* cont) const; - virtual void lift(CEnv& cenv) throw() {} + virtual AST* lift(CEnv& cenv, Code& code) throw() { return this; } + virtual AST* depoly(CEnv& cenv, Code& code) throw() { return this; } virtual CVal compile(CEnv& cenv) throw() = 0; string str() const { ostringstream ss; ss << this; return ss.str(); } template T to() { return dynamic_cast(this); } @@ -247,6 +257,7 @@ struct AString : public AST, public std::string { struct ASymbol : public AST { bool operator==(const AST& rhs) const { return this == &rhs; } void constrain(TEnv& tenv, Constraints& c) const throw(Error); + AST* lift(CEnv& cenv, Code& code) throw(); CVal compile(CEnv& cenv) throw(); const string cppstr; private: @@ -273,6 +284,12 @@ struct ATuple : public AST { newvec[_len++] = ast; _vec = newvec; } + void push_front(AST* ast) { + AST** newvec = (AST**)malloc(sizeof(AST*) * (_len + 1)); + newvec[0] = ast; + memcpy(newvec + 1, _vec, sizeof(AST*) * _len++); + _vec = newvec; + } const AST* head() const { assert(_len > 0); return _vec[0]; } AST* head() { assert(_len > 0); return _vec[0]; } const AST* last() const { return _vec[_len - 1]; } @@ -297,7 +314,7 @@ struct ATuple : public AST { return false; return true; } - bool contains(AST* child) const { + bool contains(const AST* child) const { if (*this == *child) return true; FOREACHP(const_iterator, p, this) if (**p == *child || (*p)->contains(child)) @@ -305,7 +322,8 @@ struct ATuple : public AST { return false; } void constrain(TEnv& tenv, Constraints& c) const throw(Error); - void lift(CEnv& cenv) throw() { FOREACHP(iterator, t, this) (*t)->lift(cenv); } + AST* lift(CEnv& cenv, Code& code) throw(); + AST* depoly(CEnv& cenv, Code& code) throw(); CVal compile(CEnv& cenv) throw() { return NULL; } @@ -354,63 +372,18 @@ struct AType : public ATuple { unsigned id; }; -/// Type substitution -struct Subst : public list< pair > { - Subst(AType* s=0, AType* t=0) { if (s && t) { assert(s != t); push_back(make_pair(s, t)); } } - static Subst compose(const Subst& delta, const Subst& gamma); - void add(const AType* from, AType* to) { push_back(make_pair(from, to)); } - const_iterator find(const AType* t) const { - for (const_iterator j = begin(); j != end(); ++j) - if (*j->first == *t) - return j; - return end(); - } - AType* apply(const AType* in) const { - if (in->kind == AType::EXPR) { - AType* out = tup(in->loc, NULL); - for (ATuple::const_iterator i = in->begin(); i != in->end(); ++i) - out->push_back(apply((*i)->as())); - return out; - } else { - const_iterator i = find(in); - if (i != end()) { - AType* out = i->second->as(); - if (out->kind == AType::EXPR && !out->concrete()) - out = apply(out->as()); - return out; - } else { - return new AType(*in); - } - } - } -}; - -inline ostream& operator<<(ostream& out, const Subst& s) { - for (Subst::const_iterator i = s.begin(); i != s.end(); ++i) - out << i->first << " => " << i->second << endl; - return out; -} - /// Fn (first-class function with captured lexical bindings) struct AFn : public ATuple { + AFn(const ATuple* exp) : ATuple(*exp) {} AFn(Cursor c, AST* ast, va_list args) : ATuple(c, ast, args) {} bool operator==(const AST& rhs) const { return this == &rhs; } void constrain(TEnv& tenv, Constraints& c) const throw(Error); AST* cps(TEnv& tenv, AST* cont) const; - void lift(CEnv& cenv) throw(); + AST* lift(CEnv& cenv, Code& code) throw(); + AST* depoly(CEnv& cenv, Code& code) throw(); CVal compile(CEnv& cenv) throw(); const ATuple* prot() const { return (*(begin() + 1))->to(); } ATuple* prot() { return (*(begin() + 1))->to(); } - /// System level implementations of this (polymorphic) fn - struct Impls : public list< pair > { - CFunc find(AType* type) const { - for (const_iterator f = begin(); f != end(); ++f) - if (*f->first == *type) - return f->second; - return NULL; - } - }; - Impls impls; string name; }; @@ -420,7 +393,8 @@ struct ACall : public ATuple { ACall(Cursor c, AST* ast, va_list args) : ATuple(c, ast, args) {} void constrain(TEnv& tenv, Constraints& c) const throw(Error); AST* cps(TEnv& tenv, AST* cont) const; - void lift(CEnv& cenv) throw(); + AST* lift(CEnv& cenv, Code& code) throw(); + AST* depoly(CEnv& cenv, Code& code) throw(); CVal compile(CEnv& cenv) throw(); }; @@ -442,7 +416,8 @@ struct ADef : public ACall { AST* body() { return *(begin() + 2); } void constrain(TEnv& tenv, Constraints& c) const throw(Error); AST* cps(TEnv& tenv, AST* cont) const; - void lift(CEnv& cenv) throw(); + AST* lift(CEnv& cenv, Code& code) throw(); + AST* depoly(CEnv& cenv, Code& code) throw(); CVal compile(CEnv& cenv) throw(); }; @@ -452,19 +427,26 @@ struct AIf : public ACall { AIf(Cursor c, AST* ast, va_list args) : ACall(c, ast, args) {} void constrain(TEnv& tenv, Constraints& c) const throw(Error); AST* cps(TEnv& tenv, AST* cont) const; + AST* lift(CEnv& cenv, Code& code) throw(); + AST* depoly(CEnv& cenv, Code& code) throw(); CVal compile(CEnv& cenv) throw(); }; struct ACons : public ACall { ACons(const ATuple* exp) : ACall(exp) {} + ACons(Cursor c, AST* ast, va_list args) : ACall(c, ast, args) {} void constrain(TEnv& tenv, Constraints& c) const throw(Error); + AST* lift(CEnv& cenv, Code& code) throw(); + AST* depoly(CEnv& cenv, Code& code) throw(); CVal compile(CEnv& cenv) throw(); }; struct ADot : public ACall { ADot(const ATuple* exp) : ACall(exp) {} + ADot(Cursor c, AST* ast, va_list args) : ACall(c, ast, args) {} void constrain(TEnv& tenv, Constraints& c) const throw(Error); - void lift(CEnv& cenv) throw(); + AST* lift(CEnv& cenv, Code& code) throw(); + AST* depoly(CEnv& cenv, Code& code) throw(); CVal compile(CEnv& cenv) throw(); }; @@ -480,6 +462,8 @@ struct APrimitive : public ACall { } void constrain(TEnv& tenv, Constraints& c) const throw(Error); AST* cps(TEnv& tenv, AST* cont) const; + AST* lift(CEnv& cenv, Code& code) throw(); + AST* depoly(CEnv& cenv, Code& code) throw(); CVal compile(CEnv& cenv) throw(); }; @@ -512,7 +496,7 @@ struct PEnv : private map { map::const_iterator i = macros.find(s); return (i != macros.end()) ? i->second : NULL; } - string gensymstr(const char* s="_") { return (format("%s%d") % s % symID++).str(); } + string gensymstr(const char* s="_") { return (format("%s_%d") % s % symID++).str(); } ASymbol* gensym(const char* s="_") { return sym(gensymstr(s)); } ASymbol* sym(const string& s, Cursor c=Cursor()) { const const_iterator i = find(s); @@ -571,17 +555,62 @@ struct PEnv : private map { ***************************************************************************/ /// Type constraint -struct Constraint : public pair { - Constraint(AType* a, AType* b, Cursor c) : pair(a, b), loc(c) {} +struct Constraint : public pair { + Constraint(const AType* a, const AType* b, Cursor c) + : pair(a, b), loc(c) {} Cursor loc; }; +/// Type substitution +struct Subst : public list { + Subst(const AType* s=0, const AType* t=0) { + if (s && t) { assert(s != t); push_back(Constraint(s, t, t->loc)); } + } + static Subst compose(const Subst& delta, const Subst& gamma); + void add(const AType* from, const AType* to) { push_back(Constraint(from, to, Cursor())); } + const_iterator find(const AType* t) const { + for (const_iterator j = begin(); j != end(); ++j) + if (*j->first == *t) + return j; + return end(); + } + const AType* apply(const AType* in) const { + if (in->kind == AType::EXPR) { + AType* out = tup(in->loc, NULL); + for (ATuple::const_iterator i = in->begin(); i != in->end(); ++i) + out->push_back(const_cast(apply((*i)->as()))); + return out; + } else { + const_iterator i = find(in); + if (i != end()) { + const AType* out = i->second->as(); + if (out->kind == AType::EXPR && !out->concrete()) + out = const_cast(apply(out->as())); + return out; + } else { + return new AType(*in); + } + } + } +}; + +inline ostream& operator<<(ostream& out, const Subst& s) { + for (Subst::const_iterator i = s.begin(); i != s.end(); ++i) + out << i->first << " => " << i->second << endl; + return out; +} + /// Type constraint set struct Constraints : public list { Constraints() : list() {} + Constraints(const Subst& subst) : list() { + FOREACH(Subst::const_iterator, i, subst) { + push_back(Constraint(new AType(*i->first), new AType(*i->second), Cursor())); + } + } Constraints(const_iterator begin, const_iterator end) : list(begin, end) {} - void constrain(TEnv& tenv, const AST* o, AType* t); - Constraints& replace(AType* s, AType* t); + void constrain(TEnv& tenv, const AST* o, const AType* t); + Constraints replace(const AType* s, const AType* t); }; inline ostream& operator<<(ostream& out, const Constraints& c) { @@ -591,7 +620,7 @@ inline ostream& operator<<(ostream& out, const Constraints& c) { } /// Type-Time Environment -struct TEnv : public Env { +struct TEnv : public Env { TEnv(PEnv& p) : penv(p) , varID(1) @@ -600,10 +629,10 @@ struct TEnv : public Env { { Object::pool.addRoot(Fn); } - AType* fresh(const ASymbol* sym) { + const AType* fresh(const ASymbol* sym) { return def(sym, new AType(sym->loc, varID++)); } - AType* var(const AST* ast=0) { + const AType* var(const AST* ast=0) { if (!ast) return new AType(Cursor(), varID++); @@ -617,12 +646,12 @@ struct TEnv : public Env { return (vars[ast] = new AType(ast->loc, varID++)); } - AType* named(const string& name) { + const AType* named(const string& name) { return *ref(penv.sym(name)); } - static Subst buildSubst(AType* fnT, const AType& argsT); + static Subst buildSubst(const AType* fnT, const AType& argsT); - typedef map Vars; + typedef map Vars; Vars vars; PEnv& penv; @@ -650,20 +679,22 @@ struct Engine { const ATuple& argsT, const vector argNames=vector()) = 0; - virtual void finishFunction(CEnv& cenv, CFunc f, const AType* retT, CVal ret) = 0; - virtual void eraseFunction(CEnv& cenv, CFunc f) = 0; - virtual CFunc compileFunction(CEnv& cenv, AFn* fn, const AType& argsT) = 0; - virtual CVal compileTup(CEnv& cenv, const AType* t, const vector& f) = 0; - virtual CVal compileDot(CEnv& cenv, CVal tup, int32_t index) = 0; - virtual CVal compileLiteral(CEnv& cenv, AST* lit) = 0; - virtual CVal compileCall(CEnv& cenv, CFunc f, const vector& args) = 0; - virtual CVal compilePrimitive(CEnv& cenv, APrimitive* prim) = 0; - virtual CVal compileIf(CEnv& cenv, AIf* aif) = 0; - virtual CVal compileGlobal(CEnv& cenv, AType* t, const string& sym, CVal val) = 0; - virtual CVal getGlobal(CEnv& cenv, CVal val) = 0; - virtual void writeModule(CEnv& cenv, std::ostream& os) = 0; - - virtual const string call(CEnv& cenv, CFunc f, AType* retT) = 0; + typedef const vector ValVec; + + virtual void finishFunction(CEnv& cenv, CFunc f, CVal ret) = 0; + virtual void eraseFunction(CEnv& cenv, CFunc f) = 0; + virtual CFunc compileFunction(CEnv& cenv, AFn* fn, const AType* type) = 0; + virtual CVal compileTup(CEnv& cenv, const AType* t, ValVec& f) = 0; + virtual CVal compileDot(CEnv& cenv, CVal tup, int32_t index) = 0; + virtual CVal compileLiteral(CEnv& cenv, AST* lit) = 0; + virtual CVal compileCall(CEnv& cenv, CFunc f, const AType* fT, ValVec& args) = 0; + virtual CVal compilePrimitive(CEnv& cenv, APrimitive* prim) = 0; + virtual CVal compileIf(CEnv& cenv, AIf* aif) = 0; + virtual CVal compileGlobal(CEnv& cenv, const AType* t, const string& sym, CVal val) = 0; + virtual CVal getGlobal(CEnv& cenv, CVal val) = 0; + virtual void writeModule(CEnv& cenv, std::ostream& os) = 0; + + virtual const string call(CEnv& cenv, CFunc f, const AType* retT) = 0; }; Engine* resp_new_llvm_engine(); @@ -683,14 +714,19 @@ struct CEnv { void push() { code.push(); tenv.push(); vals.push(); } void pop() { code.pop(); tenv.pop(); vals.pop(); } void lock(AST* ast) { Object::pool.addRoot(ast); Object::pool.addRoot(type(ast)); } - AType* type(AST* ast, const Subst& subst = Subst()) const { + const AType* type(AST* ast, const Subst& subst = Subst()) const { ASymbol* sym = ast->to(); - if (sym) - return *tenv.ref(sym); - assert(tenv.vars[ast]); - return tsubst.apply(subst.apply(tenv.vars[ast]))->to(); + if (sym) { + const AType** rec = tenv.ref(sym); + return rec ? *rec : NULL; + } + const AType* var = tenv.vars[ast]; + if (var) { + return tsubst.apply(subst.apply(var))->to(); + } + return NULL; } - void def(const ASymbol* sym, AST* c, AType* t, CVal v) { + void def(const ASymbol* sym, AST* c, const AType* t, CVal v) { code.def(sym, c); tenv.def(sym, t); vals.def(sym, v); @@ -700,6 +736,14 @@ struct CEnv { AST** rec = code.ref(sym); return rec ? *rec : ast; } + void setType(AST* ast, const AType* type) { + const AType* tvar = tenv.var(); + tenv.vars.insert(make_pair(ast, tvar)); + tsubst.add(tvar, type); + } + void setTypeSameAs(AST* ast, AST* typedAst) { + tenv.vars.insert(make_pair(ast, tenv.vars[typedAst])); + } ostream& out; ostream& err; @@ -710,8 +754,39 @@ struct CEnv { Env code; + typedef map Impls; + Impls impls; + + CFunc findImpl(AFn* fn, const AType* type) { + Impls::iterator i = impls.find(fn); + return (i != impls.end()) ? i->second : NULL; + } + + void addImpl(AFn* fn, CFunc impl) { + impls.insert(make_pair(fn, impl)); + } + map args; + CFunc currentFn; ///< Currently compiling function + + struct FreeVars : public std::vector { + FreeVars(AFn* f, const std::string& n) : fn(f), implName(n) {} + AFn* const fn; + const std::string implName; + int32_t index(ASymbol* sym) { + const_iterator i = find(begin(), end(), sym); + if (i != end()) { + return i - begin() + 1; + } else { + push_back(sym); + return size(); + } + } + }; + typedef std::stack LiftStack; + LiftStack liftStack; + private: Engine* _engine; }; diff --git a/src/unify.cpp b/src/unify.cpp index 93dd78b..1b25861 100644 --- a/src/unify.cpp +++ b/src/unify.cpp @@ -26,7 +26,7 @@ * with a specific set of argument types */ Subst -TEnv::buildSubst(AType* genericT, const AType& argsT) +TEnv::buildSubst(const AType* genericT, const AType& argsT) { Subst subst; @@ -56,7 +56,7 @@ TEnv::buildSubst(AType* genericT, const AType& argsT) } void -Constraints::constrain(TEnv& tenv, const AST* o, AType* t) +Constraints::constrain(TEnv& tenv, const AST* o, const AType* t) { assert(o); assert(t); @@ -64,15 +64,27 @@ Constraints::constrain(TEnv& tenv, const AST* o, AType* t) push_back(Constraint(tenv.var(o), t, o->loc)); } -static void -substitute(ATuple* tup, const AST* from, AST* to) +template +static const T* +substitute(const T* tup, const E* from, const E* to) { - if (!tup) return; - FOREACHP(ATuple::iterator, i, tup) - if (**i == *from) - *i = to; - else if (*i != to) - substitute((*i)->to(), from, to); + if (!tup) return NULL; + T* ret = new T(*tup); + typename T::iterator ri = ret->begin(); + FOREACHP(typename T::const_iterator, i, tup) { + if (**i == *from) { + *ri++ = const_cast(to); + } else if (static_cast(*i) != static_cast(to)) { + const T* subTup = dynamic_cast(*i); + if (subTup) + *ri++ = const_cast(substitute(subTup, from, to)); + else + *ri++ = *i; + } else { + ++ri; + } + } + return ret; } /// Compose two substitutions (TAPL 22.1.1) @@ -92,15 +104,16 @@ Subst::compose(const Subst& delta, const Subst& gamma) } /// Replace all occurrences of @a s with @a t -Constraints& -Constraints::replace(AType* s, AType* t) +Constraints +Constraints::replace(const AType* s, const AType* t) { + Constraints cp(*this); for (Constraints::iterator c = begin(); c != end();) { Constraints::iterator next = c; ++next; if (*c->first == *s) c->first = t; if (*c->second == *s) c->second = t; - substitute(c->first, s, t); - substitute(c->second, s, t); + c->first = substitute(c->first, s, t); + c->second = substitute(c->second, s, t); c = next; } return *this; @@ -114,8 +127,8 @@ unify(const Constraints& constraints) return Subst(); Constraints::const_iterator i = constraints.begin(); - AType* s = i->first; - AType* t = i->second; + const AType* s = i->first; + const AType* t = i->second; Constraints cp(++i, constraints.end()); if (*s == *t) { @@ -125,8 +138,8 @@ unify(const Constraints& constraints) } else if (t->kind == AType::VAR && !s->contains(t)) { return Subst::compose(unify(cp.replace(t, s)), Subst(t, s)); } else if (s->kind == AType::EXPR && t->kind == AType::EXPR) { - AType::iterator si = s->begin() + 1; - AType::iterator ti = t->begin() + 1; + AType::const_iterator si = s->begin() + 1; + AType::const_iterator ti = t->begin() + 1; for (; si != s->end() && ti != t->end(); ++si, ++ti) { AType* st = (*si)->as(); AType* tt = (*ti)->as(); -- cgit v1.2.1