From cf6c924f9cb10a583edbab2560773f2500a86323 Mon Sep 17 00:00:00 2001 From: David Robillard Date: Thu, 2 Dec 2010 08:03:47 +0000 Subject: Work towards removing different classes for each type of expression. git-svn-id: http://svn.drobilla.net/resp/resp@278 ad02d1e2-f140-0410-9f75-f8b11f17cedd --- src/c.cpp | 12 ++-- src/compile.cpp | 197 ++++++++++++++++++++++++++++++------------------------ src/constrain.cpp | 8 +-- src/lift.cpp | 169 ++++++++++++++++++++++++++++------------------ src/llvm.cpp | 18 ++--- src/repl.cpp | 10 +-- src/resp.hpp | 78 ++++++--------------- src/unify.cpp | 4 +- 8 files changed, 257 insertions(+), 239 deletions(-) diff --git a/src/c.cpp b/src/c.cpp index b1eafe5..4d53436 100644 --- a/src/c.cpp +++ b/src/c.cpp @@ -243,16 +243,16 @@ CEngine::compileIf(CEnv& cenv, const AIf* aif) if (idx > 1) out += "else {\n"; - Value* condV = llVal((*i)->compile(cenv)); + Value* condV = llVal(resp_compile(cenv, *i)); out += (format("if (%s) {\n") % *condV).str(); - Value* thenV = llVal((*next)->compile(cenv)); + Value* thenV = llVal(resp_compile(cenv, *next)); out += (format("%s = %s;\n}\n") % *varname % *thenV).str(); } // Emit final else block out += "else {\n"; - Value* elseV = llVal(aif->list_last()->compile(cenv)); + Value* elseV = llVal(resp_compile(cenv, aif->list_last())); out += (format("%s = %s;\n}\n") % *varname % *elseV).str(); for (size_t i = 1; i < idx / 2; ++i) @@ -273,8 +273,8 @@ CEngine::compilePrimitive(CEnv& cenv, const APrimitive* prim) APrimitive::const_iterator i = prim->begin(); ++i; - Value* a = llVal((*i++)->compile(cenv)); - Value* b = llVal((*i++)->compile(cenv)); + Value* a = llVal(resp_compile(cenv, *i++)); + Value* b = llVal(resp_compile(cenv, *i++)); const string n = prim->head()->to()->str(); string op = n; @@ -289,7 +289,7 @@ CEngine::compilePrimitive(CEnv& cenv, const APrimitive* prim) string val("("); val += *a + op + *b; while (i != prim->end()) - val += op + *llVal((*i++)->compile(cenv)); + val += op + *llVal(resp_compile(cenv, *i++)); val += ")"; Value* varname = new string(cenv.penv.gensymstr("x")); diff --git a/src/compile.cpp b/src/compile.cpp index da9683e..ccde2ca 100644 --- a/src/compile.cpp +++ b/src/compile.cpp @@ -23,62 +23,36 @@ using namespace std; -#define COMPILE_LITERAL(CT) \ -template<> CVal ALiteral::compile(CEnv& cenv) const throw() { \ - return cenv.engine()->compileLiteral(cenv, this); \ -} -COMPILE_LITERAL(int32_t); -COMPILE_LITERAL(float); -COMPILE_LITERAL(bool); - -CVal -AString::compile(CEnv& cenv) const throw() -{ - return cenv.engine()->compileString(cenv, c_str()); -} - -CVal -AQuote::compile(CEnv& cenv) const throw() -{ - return list_ref(1)->compile(cenv); -} - -CVal -ALexeme::compile(CEnv& cenv) const throw() +static CVal +compile_symbol(CEnv& cenv, const ASymbol* sym) throw() { - return cenv.engine()->compileString(cenv, c_str()); -} - -CVal -ASymbol::compile(CEnv& cenv) const throw() -{ - if (cenv.vals.topLevel(this) && cenv.type(this)->head()->str() != "Fn") { - return cenv.engine()->getGlobal(cenv, cppstr, *cenv.vals.ref(this)); + if (cenv.vals.topLevel(sym) && cenv.type(sym)->head()->str() != "Fn") { + return cenv.engine()->getGlobal(cenv, sym->cppstr, *cenv.vals.ref(sym)); } else { - return *cenv.vals.ref(this); + return *cenv.vals.ref(sym); } } -CVal -AFn::compile(CEnv& cenv) const throw() +static CVal +compile_fn(CEnv& cenv, const AFn* fn) throw() { - const AType* type = cenv.type(this); - CFunc f = cenv.findImpl(this, type); + const AType* type = cenv.type(fn); + CFunc f = cenv.findImpl(fn, type); if (f) return f; // Write function declaration - f = cenv.engine()->startFunction(cenv, name, prot(), type); + f = cenv.engine()->startFunction(cenv, fn->name, fn->prot(), type); // Create a new environment frame and bind argument values - cenv.engine()->pushFunctionArgs(cenv, this, type, f); + cenv.engine()->pushFunctionArgs(cenv, fn, type, f); assert(!cenv.currentFn); cenv.currentFn = f; // Write function body CVal retVal = NULL; - for (AFn::const_iterator i = iter_at(2); i != end(); ++i) - retVal = (*i)->compile(cenv); + for (AFn::const_iterator i = fn->iter_at(2); i != fn->end(); ++i) + retVal = resp_compile(cenv, *i); // Write function conclusion cenv.engine()->finishFunction(cenv, f, retVal); @@ -87,97 +61,144 @@ AFn::compile(CEnv& cenv) const throw() cenv.pop(); cenv.currentFn = NULL; - cenv.vals.def(cenv.penv.sym(name), f); - cenv.addImpl(this, f); + cenv.vals.def(cenv.penv.sym(fn->name), f); + cenv.addImpl(fn, f); return f; } -CVal -ACall::compile(CEnv& cenv) const throw() +static CVal +compile_call(CEnv& cenv, const ACall* call) throw() { - CFunc f = (*begin())->compile(cenv); + CFunc f = resp_compile(cenv, *call->begin()); if (!f) f = cenv.currentFn; // Recursive call (callee defined as a stub) vector args; - for (const_iterator e = iter_at(1); e != end(); ++e) - args.push_back((*e)->compile(cenv)); + for (ACall::const_iterator e = call->iter_at(1); e != call->end(); ++e) + args.push_back(resp_compile(cenv, *e)); - return cenv.engine()->compileCall(cenv, f, cenv.type(head()), args); + return cenv.engine()->compileCall(cenv, f, cenv.type(call->head()), args); } -CVal -ADef::compile(CEnv& cenv) const throw() +static CVal +compile_def(CEnv& cenv, const ADef* def) throw() { - cenv.def(sym(), body(), cenv.type(body()), NULL); // define stub first for recursion - CVal val = body()->compile(cenv); - if (cenv.vals.size() == 1 && cenv.type(body())->head()->str() != "Fn") { + const ASymbol* const sym = def->list_ref(1)->as(); + cenv.def(sym, def->body(), cenv.type(def->body()), NULL); // define stub first for recursion + CVal val = resp_compile(cenv, def->body()); + if (cenv.vals.size() == 1 && cenv.type(def->body())->head()->str() != "Fn") { val = cenv.engine()->compileGlobal( - cenv, cenv.type(body()), sym()->str(), val); - cenv.lock(this); + cenv, cenv.type(def->body()), sym->str(), val); + cenv.lock(def); } - cenv.vals.def(sym(), val); + cenv.vals.def(sym, val); return NULL; } -CVal -AIf::compile(CEnv& cenv) const throw() +static CVal +compile_cons(CEnv& cenv, const ACons* cons) throw() { - return cenv.engine()->compileIf(cenv, this); -} - -CVal -ACons::compile(CEnv& cenv) const throw() -{ - return ATuple::compile(cenv); -} - -CVal -ATuple::compile(CEnv& cenv) const throw() -{ - AType* type = new AType(const_cast(head()->as()), NULL, Cursor()); + AType* type = new AType(const_cast(cons->head()->as()), NULL, Cursor()); TList tlist(type); vector fields; - for (const_iterator i = iter_at(1); i != end(); ++i) { + for (ACons::const_iterator i = cons->iter_at(1); i != cons->end(); ++i) { tlist.push_back(const_cast(cenv.type(*i))); - fields.push_back((*i)->compile(cenv)); + fields.push_back(resp_compile(cenv, *i)); } - return cenv.engine()->compileTup(cenv, type, type->compile(cenv), fields); + return cenv.engine()->compileTup(cenv, type, resp_compile(cenv, type), fields); } -CVal -AType::compile(CEnv& cenv) const throw() +static CVal +compile_type(CEnv& cenv, const AType* type) throw() { - const ASymbol* sym = head()->as(); + const ASymbol* sym = type->head()->as(); CVal* existing = cenv.vals.ref(sym); if (existing) { return *existing; } else { - CVal compiled = cenv.engine()->compileString(cenv, (string("__T_") + head()->str()).c_str()); + CVal compiled = cenv.engine()->compileString( + cenv, (string("__T_") + type->head()->str()).c_str()); cenv.vals.def(sym, compiled); return compiled; } } -CVal -ADot::compile(CEnv& cenv) const throw() +static CVal +compile_dot(CEnv& cenv, const ADot* dot) throw() { - const_iterator i = begin(); + ATuple::const_iterator i = dot->begin(); const AST* tup = *++i; const ALiteral* index = (*++i)->as*>(); - CVal tupVal = tup->compile(cenv); + CVal tupVal = resp_compile(cenv, tup); return cenv.engine()->compileDot(cenv, tupVal, index->val); } CVal -APrimitive::compile(CEnv& cenv) const throw() +resp_compile(CEnv& cenv, const AST* ast) throw() { - return cenv.engine()->compilePrimitive(cenv, this); -} + if (ast->to*>() + || ast->to*>() + || ast->to*>()) + return cenv.engine()->compileLiteral(cenv, ast); + + const AString* str = ast->to(); + if (str) + return cenv.engine()->compileString(cenv, str->c_str()); + + const AQuote* quote = ast->to(); + if (quote) + return resp_compile(cenv, quote->list_ref(1)); + + const ALexeme* lexeme = ast->to(); + if (lexeme) + return cenv.engine()->compileString(cenv, lexeme->c_str()); + + const ASymbol* sym = ast->to(); + if (sym) + return compile_symbol(cenv, sym); + + const AFn* fn = ast->to(); + if (fn) + return compile_fn(cenv, fn); + + const ADef* def = ast->to(); + if (def) + return compile_def(cenv, def); -CVal -AMatch::compile(CEnv& cenv) const throw() -{ - return cenv.engine()->compileMatch(cenv, this); + const AIf* aif = ast->to(); + if (aif) + return cenv.engine()->compileIf(cenv, aif); + + const ACons* cons = ast->to(); + if (cons) + return compile_cons(cenv, cons); + + const APrimitive* prim = ast->to(); + if (prim) + return cenv.engine()->compilePrimitive(cenv, prim); + + const AMatch* match = ast->to(); + if (match) + return cenv.engine()->compileMatch(cenv, match); + + const AType* type = ast->to(); + if (type) + return compile_type(cenv, type); + + const ADot* dot = ast->to(); + if (dot) + return compile_dot(cenv, dot); + + const ADefType* deftype = ast->to(); + if (deftype) + return NULL; + + const ACall* call = ast->to(); + if (call) + return compile_call(cenv, call); + + cenv.err << "Attempt to compile unknown type" << endl; + assert(false); + return NULL; } diff --git a/src/constrain.cpp b/src/constrain.cpp index b9ce471..bd43b8e 100644 --- a/src/constrain.cpp +++ b/src/constrain.cpp @@ -100,11 +100,11 @@ AFn::constrain(TEnv& tenv, Constraints& c) const throw(Error) const AST* exp = *i; const ADef* def = exp->to(); if (def) { - const ASymbol* sym = def->sym(); + const ASymbol* sym = def->list_ref(1)->as(); THROW_IF(defs.count(sym) != 0, def->loc, (format("`%1%' defined twice") % sym->str()).str()); - defs.insert(def->sym()); - frame.push_back(make_pair(def->sym(), (AType*)NULL)); + defs.insert(sym); + frame.push_back(make_pair(sym, (AType*)NULL)); } } @@ -156,7 +156,7 @@ void ADef::constrain(TEnv& tenv, Constraints& c) const throw(Error) { THROW_IF(list_len() != 3, loc, "`def' requires exactly 2 arguments"); - const ASymbol* sym = this->sym(); + const ASymbol* sym = this->list_ref(1)->as(); THROW_IF(!sym, loc, "`def' has no symbol") const AType* tvar = tenv.var(body()); diff --git a/src/lift.cpp b/src/lift.cpp index 81b157d..754118a 100644 --- a/src/lift.cpp +++ b/src/lift.cpp @@ -27,56 +27,44 @@ using namespace std; -AST* -ASymbol::lift(CEnv& cenv, Code& code) throw() +static AST* +lift_symbol(CEnv& cenv, Code& code, ASymbol* sym) throw() { + const std::string& cppstr = sym->cppstr; if (!cenv.liftStack.empty() && cppstr == cenv.liftStack.top().fn->name) { return cenv.penv.sym("_me"); // Reference to innermost function } else if (!cenv.penv.handler(true, cppstr) && !cenv.penv.handler(false, cppstr) - && !cenv.code.innermost(this)) { + && !cenv.code.innermost(sym)) { - const int32_t index = cenv.liftStack.top().index(this); + const int32_t index = cenv.liftStack.top().index(sym); // Replace symbol with code to access free variable from closure - return tup(loc, cenv.penv.sym("."), + return tup(sym->loc, cenv.penv.sym("."), cenv.penv.sym("_me"), new ALiteral(index, Cursor()), NULL); } else { - return this; + return sym; } } -AST* -AQuote::lift(CEnv& cenv, Code& code) throw() -{ - return this; -} - -AST* -ATuple::lift(CEnv& cenv, Code& code) throw() -{ - assert(false); - return NULL; -} - -AST* -AFn::lift(CEnv& cenv, Code& code) throw() +static AST* +lift_fn(CEnv& cenv, Code& code, AFn* fn) throw() { - AFn* impl = new AFn(this); - const string nameBase = cenv.penv.gensymstr(((name != "") ? name : "fn").c_str()); + AFn* impl = new AFn(fn); + const string nameBase = cenv.penv.gensymstr(((fn->name != "") ? fn->name : "fn").c_str()); impl->name = "_" + nameBase; - cenv.liftStack.push(CEnv::FreeVars(this, impl->name)); + cenv.liftStack.push(CEnv::FreeVars(fn, impl->name)); // Create a new stub environment frame for parameters cenv.push(); - const AType* type = cenv.type(this); + const AType* type = cenv.type(fn); AType::const_iterator tp = type->prot()->begin(); AType* implProtT = new AType(*type->prot()->as()); ATuple::iterator ip = implProtT->begin(); - for (const_iterator p = prot()->begin(); p != prot()->end(); ++p) { + for (AFn::const_iterator p = fn->prot()->begin(); p != fn->prot()->end(); ++p) { const AType* paramType = (*tp++)->as(); if (paramType->kind == AType::EXPR && *paramType->head() == *cenv.tenv.Fn) { AType* fnType = new AType(*paramType); @@ -96,9 +84,9 @@ AFn::lift(CEnv& cenv, Code& code) throw() // Lift body const AType* implRetT = NULL; - iterator ci = impl->iter_at(2); - for (iterator i = iter_at(2); i != end(); ++i, ++ci) { - *ci = (*i)->lift(cenv, code); + AFn::iterator ci = impl->iter_at(2); + for (AFn::iterator i = fn->iter_at(2); i != fn->end(); ++i, ++ci) { + *ci = resp_lift(cenv, code, *i); implRetT = cenv.type(*ci); } @@ -109,13 +97,13 @@ AFn::lift(CEnv& cenv, Code& code) throw() // Create definition for implementation fn ASymbol* implName = cenv.penv.sym(impl->name); - ADef* def = tup(loc, cenv.penv.sym("def"), implName, impl, NULL); + ADef* def = tup(fn->loc, cenv.penv.sym("def"), implName, impl, NULL); code.push_back(def); AType* implT = new AType(*type); // Type of the implementation function - TList tupT(loc, cenv.tenv.Tup, cenv.tenv.var(), NULL); - TList consT(loc, cenv.tenv.Tup, implT, NULL); - List cons(loc, cenv.penv.sym("Closure"), implName, NULL); + TList tupT(fn->loc, cenv.tenv.Tup, cenv.tenv.var(), NULL); + TList consT(fn->loc, cenv.tenv.Tup, implT, NULL); + List cons(fn->loc, cenv.penv.sym("Closure"), implName, NULL); implT->list_ref(1) = implProtT; @@ -134,32 +122,32 @@ AFn::lift(CEnv& cenv, Code& code) throw() cenv.setType(cons, consT); cenv.def(implName, impl, implT, NULL); - if (name != "") - cenv.def(cenv.penv.sym(name), this, consT, NULL); + if (fn->name != "") + cenv.def(cenv.penv.sym(fn->name), fn, consT, NULL); return cons; } -AST* -ACall::lift(CEnv& cenv, Code& code) throw() +static AST* +lift_call(CEnv& cenv, Code& code, ACall* call) throw() { List copy; // Lift all children (callee and arguments, recursively) - for (iterator i = begin(); i != end(); ++i) - copy.push_back((*i)->lift(cenv, code)); + for (ATuple::iterator i = call->begin(); i != call->end(); ++i) + copy.push_back(resp_lift(cenv, code, *i)); - copy.head->loc = loc; + copy.head->loc = call->loc; const AType* copyT = NULL; - ASymbol* sym = head()->to(); + ASymbol* sym = call->head()->to(); if (sym && !cenv.liftStack.empty() && sym->cppstr == cenv.liftStack.top().fn->name) { /* Recursive call to innermost function, call implementation directly, * reusing the current "_me" closure parameter (no cons or .). */ copy.push_front(cenv.penv.sym(cenv.liftStack.top().implName)); - } else if (head()->to()) { + } else if (call->head()->to()) { /* Special case: ((fn ...) ...) * Lifting (fn ...) yields: (Fn _impl ...). * We don't want ((Fn _impl ...) (Fn _impl ...) ...), @@ -174,7 +162,7 @@ ACall::lift(CEnv& cenv, Code& code) throw() copyT = implT->list_ref(2)->as(); } else { // Call to a closure, prepend code to access implementation function - ADot* getFn = tup(loc, cenv.penv.sym("."), + ADot* getFn = tup(call->loc, cenv.penv.sym("."), copy.head->head(), new ALiteral(0, Cursor()), NULL); const AType* calleeT = cenv.type(copy.head->head()); @@ -189,47 +177,96 @@ ACall::lift(CEnv& cenv, Code& code) throw() return copy; } -AST* -ADef::lift(CEnv& cenv, Code& code) throw() +static AST* +lift_def(CEnv& cenv, Code& code, ADef* def) throw() { // Define stub first for recursion - cenv.def(sym(), body(), cenv.type(body()), NULL); - AFn* c = body()->to(); + const ASymbol* const sym = def->list_ref(1)->as(); + cenv.def(sym, def->body(), cenv.type(def->body()), NULL); + AFn* c = def->body()->to(); if (c) - c->name = sym()->str(); + c->name = sym->str(); - assert(list_ref(1)->to()); + assert(def->list_ref(1)->to()); List copy; - copy.push_back(head()); - copy.push_back(list_ref(1)->lift(cenv, code)); - for (iterator t = iter_at(2); t != end(); ++t) - copy.push_back((*t)->lift(cenv, code)); + copy.push_back(def->head()); + copy.push_back(resp_lift(cenv, code, def->list_ref(1))); + for (ADef::iterator t = def->iter_at(2); t != def->end(); ++t) + copy.push_back(resp_lift(cenv, code, *t)); - cenv.setTypeSameAs(copy, this); + cenv.setTypeSameAs(copy, def); - if (copy.head->sym() == copy.head->body()) + if (copy.head->list_ref(1) == copy.head->list_ref(2)) return NULL; // Definition created by AFn::lift when body was lifted - cenv.def(copy.head->sym(), copy.head->body(), cenv.type(copy.head->body()), NULL); + cenv.def(copy.head->list_ref(1)->as(), + copy.head->list_ref(2), + cenv.type(copy.head->list_ref(2)), + NULL); return copy; } template -AST* -lift_builtin_call(CEnv& cenv, T* call, Code& code) throw() +static AST* +lift_builtin_call(CEnv& cenv, Code& code, ACall* call) throw() { - ATuple* copy = new T(call); - ATuple::iterator ri = copy->iter_at(1); + List copy; + copy.push_back(call->head()); // Lift all arguments - for (typename T::iterator i = call->iter_at(1); i != call->end(); ++i) - *ri++ = (*i)->lift(cenv, code); + for (ATuple::iterator i = call->iter_at(1); i != call->end(); ++i) + copy.push_back(resp_lift(cenv, code, *i)); cenv.setTypeSameAs(copy, call); return copy; } -AST* AIf::lift(CEnv& cenv, Code& code) throw() { return lift_builtin_call(cenv, this, code); } -AST* ACons::lift(CEnv& cenv, Code& code) throw() { return lift_builtin_call(cenv, this, code); } -AST* ADot::lift(CEnv& cenv, Code& code) throw() { return lift_builtin_call(cenv, this, code); } -AST* APrimitive::lift(CEnv& cenv, Code& code) throw() { return lift_builtin_call(cenv, this, code); } +AST* +resp_lift(CEnv& cenv, Code& code, AST* ast) throw() +{ + ASymbol* const sym = ast->to(); + if (sym) + return lift_symbol(cenv, code, sym); + + ADef* const def = ast->to(); + if (def) + return lift_def(cenv, code, def); + + AFn* const fn = ast->to(); + if (fn) + return lift_fn(cenv, code, fn); + + AIf* const aif = ast->to(); + if (aif) + return lift_builtin_call(cenv, code, aif); + + ACons* const cons = ast->to(); + if (cons) + return lift_builtin_call(cenv, code, cons); + + ADot* const dot = ast->to(); + if (dot) + return lift_builtin_call(cenv, code, dot); + + AQuote* const quote = ast->to(); + if (quote) + return lift_builtin_call(cenv, code, quote); + + AMatch* const match = ast->to(); + if (match) + return match; // FIXME + + APrimitive* const prim = ast->to(); + if (prim) + return lift_builtin_call(cenv, code, prim); + + ADefType* const defType = ast->to(); + if (defType) + return defType; + + ACall* const call = ast->to(); + if (call) + return lift_call(cenv, code, call); + + return ast; +} diff --git a/src/llvm.cpp b/src/llvm.cpp index 3e55b98..3aaf21b 100644 --- a/src/llvm.cpp +++ b/src/llvm.cpp @@ -365,7 +365,7 @@ LLVMEngine::compileIf(CEnv& cenv, const AIf* aif) if (++next == aif->end()) break; - Value* condV = llVal((*i)->compile(cenv)); + Value* condV = llVal(resp_compile(cenv, *i)); BasicBlock* thenBB = BasicBlock::Create(context, (format("then%1%") % ((idx+1)/2)).str()); nextBB = BasicBlock::Create(context, (format("else%1%") % ((idx+1)/2)).str()); @@ -375,7 +375,7 @@ LLVMEngine::compileIf(CEnv& cenv, const AIf* aif) // Emit then block for this condition parent->getBasicBlockList().push_back(thenBB); engine->builder.SetInsertPoint(thenBB); - Value* thenV = llVal((*next)->compile(cenv)); + Value* thenV = llVal(resp_compile(cenv, *next)); engine->builder.CreateBr(mergeBB); branches.push_back(make_pair(thenV, engine->builder.GetInsertBlock())); @@ -386,7 +386,7 @@ LLVMEngine::compileIf(CEnv& cenv, const AIf* aif) } // Emit final else block - Value* elseV = llVal(aif->list_last()->compile(cenv)); + Value* elseV = llVal(resp_compile(cenv, aif->list_last())); engine->builder.CreateBr(mergeBB); branches.push_back(make_pair(elseV, engine->builder.GetInsertBlock())); @@ -405,7 +405,7 @@ CVal LLVMEngine::compileMatch(CEnv& cenv, const AMatch* match) { typedef vector< pair > Branches; - Value* matchee = llVal(match->list_ref(1)->compile(cenv)); + Value* matchee = llVal(resp_compile(cenv, match->list_ref(1))); Value* rttiPtr = builder.CreateStructGEP(matchee, 0, "matchRTTIPtr"); Value* rtti = builder.CreateLoad(rttiPtr, 0, "matchRTTI"); @@ -422,7 +422,7 @@ LLVMEngine::compileMatch(CEnv& cenv, const AMatch* match) const ASymbol* sym = pat->to()->head()->as(); const AType* patT = tup(Cursor(), const_cast(sym), 0); - Value* typeV = llVal(patT->compile(cenv)); + Value* typeV = llVal(resp_compile(cenv, patT)); Value* condV = engine->builder.CreateICmp(CmpInst::ICMP_EQ, rtti, typeV); BasicBlock* thenBB = BasicBlock::Create(context, (format("case%1%") % ((idx+1)/2)).str()); @@ -433,7 +433,7 @@ LLVMEngine::compileMatch(CEnv& cenv, const AMatch* match) // Emit then block for this condition parent->getBasicBlockList().push_back(thenBB); engine->builder.SetInsertPoint(thenBB); - Value* thenV = llVal(body->compile(cenv)); + Value* thenV = llVal(resp_compile(cenv, body)); engine->builder.CreateBr(mergeBB); branches.push_back(make_pair(thenV, engine->builder.GetInsertBlock())); @@ -463,8 +463,8 @@ LLVMEngine::compilePrimitive(CEnv& cenv, const APrimitive* prim) LLVMEngine* engine = reinterpret_cast(cenv.engine()); bool isFloat = cenv.type(*++i)->str() == "Float"; - Value* a = llVal((*i++)->compile(cenv)); - Value* b = llVal((*i++)->compile(cenv)); + Value* a = llVal(resp_compile(cenv, *i++)); + Value* b = llVal(resp_compile(cenv, *i++)); const string n = prim->head()->to()->str(); // Binary arithmetic operations @@ -480,7 +480,7 @@ LLVMEngine::compilePrimitive(CEnv& cenv, const APrimitive* prim) if (op != 0) { Value* val = engine->builder.CreateBinOp(op, a, b); while (i != prim->end()) - val = engine->builder.CreateBinOp(op, val, llVal((*i++)->compile(cenv))); + val = engine->builder.CreateBinOp(op, val, llVal(resp_compile(cenv, *i++))); return val; } diff --git a/src/repl.cpp b/src/repl.cpp index d0a251b..923c26d 100644 --- a/src/repl.cpp +++ b/src/repl.cpp @@ -100,7 +100,7 @@ eval(CEnv& cenv, Cursor& cursor, istream& is, bool execute) // Lift all expressions Code lifted; for (list::iterator i = parsed.begin(); i != parsed.end(); ++i) { - AST* l = (*i)->lift(cenv, lifted); + AST* l = resp_lift(cenv, lifted, *i); if (l) lifted.push_back(l); } @@ -116,7 +116,7 @@ eval(CEnv& cenv, Cursor& cursor, istream& is, bool execute) for (Code::const_iterator i = lifted.begin(); i != lifted.end(); ++i) { const ADef* def = (*i)->to(); if (def && def->list_ref(2)->to()) { - val = def->compile(cenv); + val = resp_compile(cenv, def); } else { assert(*i); ATuple* tup = (*i)->to(); @@ -133,7 +133,7 @@ eval(CEnv& cenv, Cursor& cursor, istream& is, bool execute) // Compile expressions (other than function definitions) into it for (list::const_iterator i = exprs.begin(); i != exprs.end(); ++i) - val = (*i)->compile(cenv); + val = resp_compile(cenv, *i); // Finish compilation cenv.engine()->finishFunction(cenv, f, val); @@ -170,14 +170,14 @@ repl(CEnv& cenv) break; Code lifted; - ast = ast->lift(cenv, lifted); + ast = resp_lift(cenv, lifted, ast); const AType* type = cenv.type(ast); const AType* fnT = tup(cursor, cenv.tenv.Fn, new AType(cursor), type, 0); CFunc f = NULL; try { // Create function for this repl loop f = cenv.engine()->startFunction(cenv, replFnName, new ATuple(cursor), fnT); - cenv.engine()->finishFunction(cenv, f, ast->compile(cenv)); + cenv.engine()->finishFunction(cenv, f, resp_compile(cenv, ast)); callPrintCollect(cenv, f, ast, type, true); if (cenv.args.find("-d") != cenv.args.end()) cenv.engine()->writeModule(cenv, cenv.out); diff --git a/src/resp.hpp b/src/resp.hpp index 26038b6..632e1a2 100644 --- a/src/resp.hpp +++ b/src/resp.hpp @@ -200,12 +200,8 @@ typedef list Code; struct AST : public Object { AST(Cursor c=Cursor()) : loc(c) {} virtual ~AST() {} - virtual bool value() const { return true; } virtual bool operator==(const AST& o) const = 0; - virtual bool contains(const AST* child) const { return false; } virtual void constrain(TEnv& tenv, Constraints& c) const throw(Error) {} - virtual AST* lift(CEnv& cenv, Code& code) throw() { return this; } - virtual CVal compile(CEnv& env) const throw() = 0; string str() const { ostringstream ss; ss << this; return ss.str(); } template T to() { return dynamic_cast(this); } template T const to() const { return dynamic_cast(this); } @@ -239,7 +235,6 @@ struct ALiteral : public AST { return (r && (val == r->val)); } void constrain(TEnv& tenv, Constraints& c) const throw(Error); - CVal compile(CEnv& env) const throw(); const T val; }; @@ -248,7 +243,6 @@ struct ALexeme : public AST, public std::string { ALexeme(Cursor c, const string& s) : AST(c), std::string(s) {} bool operator==(const AST& rhs) const { return this == &rhs; } void constrain(TEnv& tenv, Constraints& c) const throw(Error); - CVal compile(CEnv& env) const throw(); }; /// String, e.g. ""a"" @@ -256,15 +250,12 @@ struct AString : public AST, public std::string { AString(Cursor c, const string& s) : AST(c), std::string(s) {} bool operator==(const AST& rhs) const { return this == &rhs; } void constrain(TEnv& tenv, Constraints& c) const throw(Error); - CVal compile(CEnv& env) const throw(); }; /// Symbol, e.g. "a" struct ASymbol : public AST { bool operator==(const AST& rhs) const { return this == &rhs; } void constrain(TEnv& tenv, Constraints& c) const throw(Error); - AST* lift(CEnv& cenv, Code& code) throw(); - CVal compile(CEnv& env) const throw(); const string cppstr; private: friend class PEnv; @@ -416,7 +407,6 @@ struct ATuple : public AST { AST*& list_ref(unsigned index) { return *iter_at(index); } const AST* list_ref(unsigned index) const { return *iter_at(index); } - bool value() const { return false; } bool operator==(const AST& rhs) const { const ATuple* rt = rhs.to(); if (!rt || rt->tup_len() != tup_len()) return false; @@ -426,22 +416,29 @@ struct ATuple : public AST { return false; return true; } - bool contains(const AST* child) const { - if (*this == *child) return true; - FOREACHP(const_iterator, p, this) - if (**p == *child || (*p)->contains(child)) - return true; - return false; - } void constrain(TEnv& tenv, Constraints& c) const throw(Error); - AST* lift(CEnv& cenv, Code& code) throw(); - CVal compile(CEnv& env) const throw(); private: size_t _len; AST** _vec; }; +static bool +list_contains(const ATuple* head, const AST* child) { + if (*head == *child) + return true; + + FOREACHP(ATuple::const_iterator, p, head) { + if (**p == *child) + return true; + + const ATuple* tup = (*p)->to(); + if (tup && list_contains(tup, child)) + return true; + } + + return false; +} /// Type Expression, e.g. "Int", "(Fn (Int Int) Float)" struct AType : public ATuple { @@ -452,7 +449,6 @@ struct AType : public ATuple { AType(Cursor c, AST* ast, va_list args) : ATuple(c, ast, args), kind(EXPR), id(0) {} AType(AST* first, AST* rest, Cursor c) : ATuple(first, rest, c), kind(EXPR), id(0) {} AType(const AType& copy) : ATuple(copy), kind(copy.kind), id(copy.id) {} - CVal compile(CEnv& cenv) const throw(); const ATuple* prot() const { assert(kind == EXPR); return list_ref(1)->to(); } ATuple* prot() { assert(kind == EXPR); return list_ref(1)->to(); } void prot(ATuple* prot) { assert(kind == EXPR); *iter_at(1) = prot; } @@ -534,8 +530,6 @@ struct AFn : public ATuple { AFn(Cursor c, AST* ast, va_list args) : ATuple(c, ast, args) {} bool operator==(const AST& rhs) const { return this == &rhs; } void constrain(TEnv& tenv, Constraints& c) const throw(Error); - AST* lift(CEnv& cenv, Code& code) throw(); - CVal compile(CEnv& env) const throw(); const ATuple* prot() const { return list_ref(1)->to(); } ATuple* prot() { return list_ref(1)->to(); } void prot(ATuple* prot) { *iter_at(1) = prot; } @@ -548,8 +542,6 @@ struct ACall : public ATuple { ACall(Cursor c, AST* ast, va_list args) : ATuple(c, ast, args) {} ACall(AST* first, AST* rest, Cursor c) : ATuple(first, rest, c) {} void constrain(TEnv& tenv, Constraints& c) const throw(Error); - AST* lift(CEnv& cenv, Code& code) throw(); - CVal compile(CEnv& env) const throw(); }; /// Definition special form, e.g. "(def x 2)" @@ -557,31 +549,16 @@ struct ADef : public ACall { ADef(const ATuple* exp) : ACall(exp) {} ADef(Cursor c, AST* ast, va_list args) : ACall(c, ast, args) {} ADef(AST* first, AST* rest, Cursor c) : ACall(first, rest, c) {} - const ASymbol* sym() const { - const AST* name = list_ref(1); - const ASymbol* sym = name->to(); - if (!sym) { - const ATuple* tup = name->to(); - if (tup && !tup->empty()) - return tup->head()->to(); - } - return sym; - } const AST* body() const { return list_ref(2); } AST* body() { return list_ref(2); } void constrain(TEnv& tenv, Constraints& c) const throw(Error); - AST* lift(CEnv& cenv, Code& code) throw(); - CVal compile(CEnv& env) const throw(); }; struct ADefType : public ACall { ADefType(const ATuple* exp) : ACall(exp) {} ADefType(Cursor c, AST* ast, va_list args) : ACall(c, ast, args) {} ADefType(AST* first, AST* rest, Cursor c) : ACall(first, rest, c) {} - const ASymbol* sym() const { return list_ref(1)->as(); } void constrain(TEnv& tenv, Constraints& c) const throw(Error); - AST* lift(CEnv& cenv, Code& code) throw() { return this; } - CVal compile(CEnv& env) const throw() { return NULL; } }; struct AMatch : public ACall { @@ -589,8 +566,6 @@ struct AMatch : public ACall { AMatch(Cursor c, AST* ast, va_list args) : ACall(c, ast, args) {} AMatch(AST* first, AST* rest, Cursor c) : ACall(first, rest, c) {} void constrain(TEnv& tenv, Constraints& c) const throw(Error); - AST* lift(CEnv& cenv, Code& code) throw() { return this; } - CVal compile(CEnv& env) const throw(); }; /// Conditional special form, e.g. "(if cond thenexp elseexp)" @@ -599,8 +574,6 @@ struct AIf : public ACall { AIf(Cursor c, AST* ast, va_list args) : ACall(c, ast, args) {} AIf(AST* first, AST* rest, Cursor c) : ACall(first, rest, c) {} void constrain(TEnv& tenv, Constraints& c) const throw(Error); - AST* lift(CEnv& cenv, Code& code) throw(); - CVal compile(CEnv& env) const throw(); }; struct ACons : public ACall { @@ -608,8 +581,6 @@ struct ACons : public ACall { ACons(Cursor c, AST* ast, va_list args) : ACall(c, ast, args) {} ACons(AST* first, AST* rest, Cursor c) : ACall(first, rest, c) {} void constrain(TEnv& tenv, Constraints& c) const throw(Error); - AST* lift(CEnv& cenv, Code& code) throw(); - CVal compile(CEnv& env) const throw(); }; struct ADot : public ACall { @@ -617,32 +588,19 @@ struct ADot : public ACall { ADot(Cursor c, AST* ast, va_list args) : ACall(c, ast, args) {} ADot(AST* first, AST* rest, Cursor c) : ACall(first, rest, c) {} void constrain(TEnv& tenv, Constraints& c) const throw(Error); - AST* lift(CEnv& cenv, Code& code) throw(); - CVal compile(CEnv& env) const throw(); }; /// Primitive (builtin arithmetic function), e.g. "(+ 2 3)" struct APrimitive : public ACall { APrimitive(const ATuple* exp) : ACall(exp) {} APrimitive(AST* first, AST* rest, Cursor c) : ACall(first, rest, c) {} - bool value() const { - ATuple::const_iterator i = begin(); - for (++i; i != end(); ++i) - if (!(*i)->value()) - return false;; - return true; - } void constrain(TEnv& tenv, Constraints& c) const throw(Error); - AST* lift(CEnv& cenv, Code& code) throw(); - CVal compile(CEnv& env) const throw(); }; struct AQuote : public ACall { AQuote(const ATuple* exp) : ACall(exp) {} AQuote(AST* first, AST* rest, Cursor c) : ACall(first, rest, c) {} void constrain(TEnv& tenv, Constraints& c) const throw(Error); - AST* lift(CEnv& cenv, Code& code) throw(); - CVal compile(CEnv& env) const throw(); }; @@ -737,7 +695,7 @@ struct Subst : public list { if (find(type) != end()) return true; FOREACHP(const_iterator, j, this) - if (*j->second == *type || j->second->contains(type)) + if (*j->second == *type || list_contains(j->second, type)) return true; return false; } @@ -954,5 +912,7 @@ void pprint(std::ostream& out, const AST* ast, CEnv* cenv, bool types); void initLang(PEnv& penv, TEnv& tenv); int eval(CEnv& cenv, Cursor& cursor, istream& is, bool execute); int repl(CEnv& cenv); +AST* resp_lift(CEnv& cenv, Code& code, AST* ast) throw(); +CVal resp_compile(CEnv& cenv, const AST* ast) throw(); #endif // RESP_HPP diff --git a/src/unify.cpp b/src/unify.cpp index 9165d49..ec5aa9a 100644 --- a/src/unify.cpp +++ b/src/unify.cpp @@ -140,9 +140,9 @@ unify(const Constraints& constraints) if (*s == *t) { return unify(cp); - } else if (s->kind == AType::VAR && !t->contains(s)) { + } else if (s->kind == AType::VAR && !list_contains(t, s)) { return Subst::compose(unify(cp.replace(s, t)), Subst(s, t)); - } else if (t->kind == AType::VAR && !s->contains(t)) { + } else if (t->kind == AType::VAR && !list_contains(s, t)) { return Subst::compose(unify(cp.replace(t, s)), Subst(t, s)); } else if (s->kind == AType::EXPR && t->kind == AType::EXPR) { AType::const_iterator si = s->begin(); -- cgit v1.2.1