From fd332f979c4216a560925639ae99f2155ab56044 Mon Sep 17 00:00:00 2001 From: David Robillard Date: Fri, 10 Dec 2010 18:28:49 +0000 Subject: Simplify if into nested 2-branch (scheme style) ifs at simplify stage. git-svn-id: http://svn.drobilla.net/resp/resp@346 ad02d1e2-f140-0410-9f75-f8b11f17cedd --- src/c.cpp | 27 ++++++---------------- src/compile.cpp | 17 +++----------- src/llvm.cpp | 70 +++++++++++++++++++------------------------------------- src/resp.hpp | 27 +++++++++++----------- src/simplify.cpp | 33 ++++++++++++++++++++++++-- 5 files changed, 79 insertions(+), 95 deletions(-) diff --git a/src/c.cpp b/src/c.cpp index 5bd3036..ae0d5e1 100644 --- a/src/c.cpp +++ b/src/c.cpp @@ -51,9 +51,7 @@ struct CEngine : public Engine { CVal compileDot(CEnv& cenv, CVal tup, int32_t index); CVal compileGlobalSet(CEnv& cenv, const string& s, CVal v, const AType* t); CVal compileGlobalGet(CEnv& cenv, const string& s, CVal v); - IfState compileIfStart(CEnv& cenv); - void compileIfBranch(CEnv& cenv, IfState state, CVal condV, const AST* then); - CVal compileIfEnd(CEnv& cenv, IfState state, CVal elseV, const AType* type); + CVal compileIf(CEnv& cenv, const AST* cond, const AST* then, const AST* aelse); CVal compileIsA(CEnv& cenv, CVal rtti, const ASymbol* tag); CVal compileLiteral(CEnv& cenv, const AST* lit); CVal compilePrimitive(CEnv& cenv, const ATuple* prim); @@ -228,6 +226,12 @@ CEngine::eraseFn(CEnv& cenv, CFunc f) cenv.err << "C backend does not support JIT (eraseFn)" << endl; } +CVal +CEngine::compileIf(CEnv& cenv, const AST* cond, const AST* then, const AST* aelse) +{ + return NULL; +} + #if 0 CVal CEngine::compileIf(CEnv& cenv, const ATuple* aif) @@ -262,23 +266,6 @@ CEngine::compileIf(CEnv& cenv, const ATuple* aif) } #endif -IfState -CEngine::compileIfStart(CEnv& cenv) -{ - return NULL; -} - -void -CEngine::compileIfBranch(CEnv& cenv, IfState s, CVal condV, const AST* then) -{ -} - -CVal -CEngine::compileIfEnd(CEnv& cenv, IfState s, CVal elseV, const AType* type) -{ - return NULL; -} - CVal CEngine::compileIsA(CEnv& cenv, CVal rtti, const ASymbol* tag) { diff --git a/src/compile.cpp b/src/compile.cpp index 123b403..4c7980e 100644 --- a/src/compile.cpp +++ b/src/compile.cpp @@ -109,22 +109,11 @@ compile_fn(CEnv& cenv, const ATuple* fn) throw() static CVal compile_if(CEnv& cenv, const ATuple* aif) throw() { - IfState state = cenv.engine()->compileIfStart(cenv); - for (ATuple::const_iterator i = aif->iter_at(1); ; ++i) { - ATuple::const_iterator next = i; - if (++next == aif->end()) - break; - - cenv.engine()->compileIfBranch(cenv, state, resp_compile(cenv, *i), *next); - - i = next; // jump 2 each iteration (to the next predicate) - } - - CVal elseV = NULL; + const AST* aelse = NULL; if (*aif->list_last() != *cenv.penv.sym("__unreachable")) - elseV = resp_compile(cenv, aif->list_last()); + aelse = aif->list_ref(3); - return cenv.engine()->compileIfEnd(cenv, state, elseV, cenv.type(aif)); + return cenv.engine()->compileIf(cenv, aif->list_ref(1), aif->list_ref(2), aelse); } static CVal diff --git a/src/llvm.cpp b/src/llvm.cpp index de53baa..69d421a 100644 --- a/src/llvm.cpp +++ b/src/llvm.cpp @@ -66,9 +66,7 @@ struct LLVMEngine : public Engine { CVal compileDot(CEnv& cenv, CVal tup, int32_t index); CVal compileGlobalSet(CEnv& cenv, const string& s, CVal v, const AType* t); CVal compileGlobalGet(CEnv& cenv, const string& s, CVal v); - IfState compileIfStart(CEnv& cenv); - void compileIfBranch(CEnv& cenv, IfState state, CVal condV, const AST* then); - CVal compileIfEnd(CEnv& cenv, IfState state, CVal elseV, const AType* type); + CVal compileIf(CEnv& cenv, const AST* cond, const AST* then, const AST* aelse); CVal compileIsA(CEnv& cenv, CVal rtti, const ASymbol* tag); CVal compileLiteral(CEnv& cenv, const AST* lit); CVal compilePrimitive(CEnv& cenv, const ATuple* prim); @@ -79,16 +77,6 @@ struct LLVMEngine : public Engine { const string call(CEnv& cenv, CFunc f, const AType* retT); private: - typedef pair IfBranch; - typedef vector IfBranches; - - struct LLVMIfState { - LLVMIfState(BasicBlock* m, Function* p) : mergeBB(m), parent(p) {} - BasicBlock* mergeBB; - Function* parent; - IfBranches branches; - }; - void appendBlock(LLVMEngine* engine, Function* function, BasicBlock* block) { function->getBasicBlockList().push_back(block); engine->builder.SetInsertPoint(block); @@ -336,53 +324,43 @@ LLVMEngine::eraseFn(CEnv& cenv, CFunc f) llFunc(f)->eraseFromParent(); } -IfState -LLVMEngine::compileIfStart(CEnv& cenv) -{ - LLVMEngine* engine = reinterpret_cast(cenv.engine()); - return new LLVMIfState(BasicBlock::Create(context, "endif"), - engine->builder.GetInsertBlock()->getParent()); -} - -void -LLVMEngine::compileIfBranch(CEnv& cenv, IfState s, CVal condV, const AST* then) +CVal +LLVMEngine::compileIf(CEnv& cenv, const AST* cond, const AST* then, const AST* aelse) { - LLVMIfState* state = (LLVMIfState*)s; - LLVMEngine* engine = reinterpret_cast(cenv.engine()); - BasicBlock* thenBB = BasicBlock::Create(context, (format("then%1%") % labelIndex).str()); - BasicBlock* nextBB = BasicBlock::Create(context, (format("else%1%") % labelIndex).str()); - + LLVMEngine* engine = reinterpret_cast(cenv.engine()); + BasicBlock* mergeBB = BasicBlock::Create(context, "endif"); + Function* parent = engine->builder.GetInsertBlock()->getParent(); + BasicBlock* thenBB = BasicBlock::Create(context, (format("then%1%") % labelIndex).str()); + BasicBlock* nextBB = BasicBlock::Create(context, (format("else%1%") % labelIndex).str()); + + const AType* type = cenv.type(then); + ++labelIndex; - engine->builder.CreateCondBr(llVal(condV), thenBB, nextBB); + engine->builder.CreateCondBr(llVal(resp_compile(cenv, cond)), thenBB, nextBB); // Emit then block for this condition - appendBlock(engine, state->parent, thenBB); + appendBlock(engine, parent, thenBB); Value* thenV = llVal(resp_compile(cenv, then)); - engine->builder.CreateBr(state->mergeBB); - state->branches.push_back(make_pair(thenV, thenBB)); - - appendBlock(engine, state->parent, nextBB); -} + engine->builder.CreateBr(mergeBB); -CVal -LLVMEngine::compileIfEnd(CEnv& cenv, IfState s, CVal elseV, const AType* type) -{ - LLVMIfState* state = (LLVMIfState*)s; - LLVMEngine* engine = reinterpret_cast(cenv.engine()); + appendBlock(engine, parent, nextBB); - if (!elseV) + Value* elseV = NULL; + if (aelse) + elseV = llVal(resp_compile(cenv, aelse)); + else elseV = Constant::getNullValue(llType(type)); // Emit end of final else block - engine->builder.CreateBr(state->mergeBB); - state->branches.push_back(make_pair(llVal(elseV), engine->builder.GetInsertBlock())); + engine->builder.CreateBr(mergeBB); + BasicBlock* elseBB = engine->builder.GetInsertBlock(); // Emit merge block (Phi node) - appendBlock(engine, state->parent, state->mergeBB); + appendBlock(engine, parent, mergeBB); PHINode* pn = engine->builder.CreatePHI(llType(type), "ifval"); - FOREACH(IfBranches::iterator, i, state->branches) - pn->addIncoming(i->first, i->second); + pn->addIncoming(thenV, thenBB); + pn->addIncoming(elseV, elseBB); return pn; } diff --git a/src/resp.hpp b/src/resp.hpp index a2b5cf8..439d00b 100644 --- a/src/resp.hpp +++ b/src/resp.hpp @@ -587,7 +587,10 @@ struct Subst : public list { if (s && t) { assert(s != t); push_back(Constraint(s, t)); } } static Subst compose(const Subst& delta, const Subst& gamma); - void add(const AType* from, const AType* to) { push_back(Constraint(from, to)); } + void add(const AType* from, const AType* to) { + assert(from && to); + push_back(Constraint(from, to)); + } const_iterator find(const AType* t) const { for (const_iterator j = begin(); j != end(); ++j) if (*j->first == *t) @@ -723,18 +726,16 @@ struct Engine { virtual void finishFn(CEnv& cenv, CFunc f, CVal ret) = 0; virtual void eraseFn(CEnv& cenv, CFunc f) = 0; - virtual CVal compileCall(CEnv& cenv, CFunc f, const AType* fT, CVals& args) = 0; - virtual CVal compileCons(CEnv& cenv, const AType* t, CVal rtti, CVals& f) = 0; - virtual CVal compileDot(CEnv& cenv, CVal tup, int32_t index) = 0; - virtual CVal compileGlobalSet(CEnv& cenv, const string& s, CVal v, const AType* t) = 0; - virtual CVal compileGlobalGet(CEnv& cenv, const string& s, CVal v) = 0; - virtual IfState compileIfStart(CEnv& cenv) = 0; - virtual void compileIfBranch(CEnv& cenv, IfState state, CVal condV, const AST* then) = 0; - virtual CVal compileIfEnd(CEnv& cenv, IfState state, CVal elseV, const AType* type) = 0; - virtual CVal compileIsA(CEnv& cenv, CVal rtti, const ASymbol* tag) = 0; - virtual CVal compileLiteral(CEnv& cenv, const AST* lit) = 0; - virtual CVal compilePrimitive(CEnv& cenv, const ATuple* prim) = 0; - virtual CVal compileString(CEnv& cenv, const char* str) = 0; + virtual CVal compileCall(CEnv& cenv, CFunc f, const AType* fT, CVals& args) = 0; + virtual CVal compileCons(CEnv& cenv, const AType* t, CVal rtti, CVals& f) = 0; + virtual CVal compileDot(CEnv& cenv, CVal tup, int32_t index) = 0; + virtual CVal compileGlobalSet(CEnv& cenv, const string& s, CVal v, const AType* t) = 0; + virtual CVal compileGlobalGet(CEnv& cenv, const string& s, CVal v) = 0; + virtual CVal compileIf(CEnv& cenv, const AST* cond, const AST* then, const AST* aelse) = 0; + virtual CVal compileIsA(CEnv& cenv, CVal rtti, const ASymbol* tag) = 0; + virtual CVal compileLiteral(CEnv& cenv, const AST* lit) = 0; + virtual CVal compilePrimitive(CEnv& cenv, const ATuple* prim) = 0; + virtual CVal compileString(CEnv& cenv, const char* str) = 0; virtual void writeModule(CEnv& cenv, std::ostream& os) = 0; diff --git a/src/simplify.cpp b/src/simplify.cpp index 31eb83a..35178ac 100644 --- a/src/simplify.cpp +++ b/src/simplify.cpp @@ -25,6 +25,34 @@ using namespace std; +static const AST* +simplify_if(CEnv& cenv, const ATuple* aif) throw() +{ + List copy(aif->loc, cenv.penv.sym("if"), NULL); + copy.push_back(aif->list_ref(1)); + copy.push_back(aif->list_ref(2)); + + ATuple* tail = copy.tail; + ATuple::const_iterator i = aif->iter_at(3); + for (; ; ++i) { + ATuple::const_iterator next = i; + if (++next == aif->end()) + break; + + List inner_if((*i)->loc, cenv.penv.sym("if"), *i, *next, NULL); + tail->last(new ATuple(inner_if.head, NULL, Cursor())); + tail = inner_if.tail; + + cenv.setTypeSameAs(inner_if, aif); + + i = next; // jump 2 elements (to the next predicate) + } + + tail->last(new ATuple(*i, NULL, Cursor())); + cenv.setTypeSameAs(copy, aif); + return copy; +} + static const AST* simplify_match(CEnv& cenv, const ATuple* match) throw() { @@ -57,9 +85,8 @@ simplify_match(CEnv& cenv, const ATuple* match) throw() copyIf.push_back(resp_simplify(cenv, body)); } copyIf.push_back(cenv.penv.sym("__unreachable")); - copy.push_back(copyIf); - cenv.setTypeSameAs(copyIf, match); + copy.push_back(simplify_if(cenv, copyIf)); cenv.setTypeSameAs(copy, match); return copy; @@ -89,6 +116,8 @@ resp_simplify(CEnv& cenv, const AST* ast) throw() if (form == "match") return simplify_match(cenv, list); + else if (form == "if") + return simplify_if(cenv, list); else return simplify_list(cenv, list); } -- cgit v1.2.1