diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/compile.cpp | 12 | ||||
-rw-r--r-- | src/llvm.cpp | 110 |
2 files changed, 92 insertions, 30 deletions
diff --git a/src/compile.cpp b/src/compile.cpp index 44017ec..ba8cc5e 100644 --- a/src/compile.cpp +++ b/src/compile.cpp @@ -167,11 +167,18 @@ compile_fn(CEnv& cenv, const ATuple* fn) throw() static CVal compile_if(CEnv& cenv, const ATuple* aif) throw() { + const AST* cond = aif->list_ref(1); + const AST* then = aif->list_ref(2); const AST* aelse = NULL; if (*aif->list_last() != *cenv.penv.sym("__unreachable")) aelse = aif->list_ref(3); - return cenv.engine()->compileIf(cenv, aif->list_ref(1), aif->list_ref(2), aelse); + const AST* type = cenv.type(then); + + cenv.engine()->compileIfStart(cenv, cond, type); + cenv.engine()->compileIfThen(cenv, resp_compile(cenv, then)); + cenv.engine()->compileIfElse(cenv, resp_compile(cenv, aelse)); + return cenv.engine()->compileIfEnd(cenv); } static CVal @@ -214,6 +221,9 @@ compile_call(CEnv& cenv, const ATuple* call) throw() CVal resp_compile(CEnv& cenv, const AST* ast) throw() { + if (!ast) + return NULL; + switch (ast->tag()) { case T_UNKNOWN: return NULL; diff --git a/src/llvm.cpp b/src/llvm.cpp index 2c80ba7..06a895a 100644 --- a/src/llvm.cpp +++ b/src/llvm.cpp @@ -26,6 +26,7 @@ #include <sstream> #include <string> #include <vector> +#include <stack> #include <boost/format.hpp> @@ -52,6 +53,25 @@ using namespace llvm; using namespace std; using boost::format; +struct IfRecord { + IfRecord(LLVMContext& context, unsigned labelIndex, const Type* t) + : type(t) + , mergeBB(BasicBlock::Create(context, "endif")) + , thenBB(BasicBlock::Create(context, (format("then%1%") % labelIndex).str())) + , elseBB(BasicBlock::Create(context, (format("else%1%") % labelIndex).str())) + , thenV(NULL) + , elseV(NULL) + {} + + const Type* type; + BasicBlock* mergeBB; + BasicBlock* thenBB; + BasicBlock* elseBB; + + Value* thenV; + Value* elseV; +}; + /** LLVM Engine (Compiler and JIT) */ struct LLVMEngine : public Engine { LLVMEngine(); @@ -68,7 +88,10 @@ struct LLVMEngine : public Engine { CVal compileDot(CEnv& cenv, CVal tup, int32_t index); CVal compileGlobalSet(CEnv& cenv, const string& s, CVal v, const AST* t); CVal compileGlobalGet(CEnv& cenv, const string& s, CVal v); - CVal compileIf(CEnv& cenv, const AST* cond, const AST* then, const AST* aelse); + void compileIfStart(CEnv& cenv, const AST* cond, const AST* type); + void compileIfThen(CEnv& cenv, CVal thenV); + void compileIfElse(CEnv& cenv, CVal elseV); + CVal compileIfEnd(CEnv& cenv); CVal compileLiteral(CEnv& cenv, const AST* lit); CVal compilePrimitive(CEnv& cenv, const ATuple* prim); CVal compileString(CEnv& cenv, const char* str); @@ -102,6 +125,9 @@ private: typedef std::map<const std::string, const Type*> CTypes; CTypes compiledTypes; + typedef std::stack<IfRecord*> IfStack; + IfStack if_stack; + unsigned labelIndex; }; @@ -165,7 +191,6 @@ LLVMEngine::llType(const AST* t, const char* name) CTypes::const_iterator i = compiledTypes.find(sym); if (i != compiledTypes.end()) return i->second; - } if (!AType::is_expr(t)) @@ -411,45 +436,72 @@ LLVMEngine::eraseFn(CEnv& cenv, CFunc f) llFunc(f)->eraseFromParent(); } -CVal -LLVMEngine::compileIf(CEnv& cenv, const AST* cond, const AST* then, const AST* aelse) +void +LLVMEngine::compileIfStart(CEnv& cenv, const AST* cond, const AST* type) { - LLVMEngine* engine = reinterpret_cast<LLVMEngine*>(cenv.engine()); - BasicBlock* mergeBB = BasicBlock::Create(context, "endif"); - Function* parent = engine->builder.GetInsertBlock()->getParent(); - BasicBlock* thenBB = BasicBlock::Create(context, (format("then%1%") % labelIndex).str()); - BasicBlock* nextBB = BasicBlock::Create(context, (format("else%1%") % labelIndex).str()); + LLVMEngine* engine = reinterpret_cast<LLVMEngine*>(cenv.engine()); + Function* parent = engine->builder.GetInsertBlock()->getParent(); + IfRecord* rec = new IfRecord(context, ++labelIndex, llType(type)); - const AST* type = cenv.type(then); - - ++labelIndex; + if_stack.push(rec); - engine->builder.CreateCondBr(llVal(resp_compile(cenv, cond)), thenBB, nextBB); + engine->builder.CreateCondBr(llVal(resp_compile(cenv, cond)), + rec->thenBB, rec->elseBB); + + // Start then block + appendBlock(engine, parent, rec->thenBB); +} + +void +LLVMEngine::compileIfThen(CEnv& cenv, CVal thenV) +{ + LLVMEngine* engine = reinterpret_cast<LLVMEngine*>(cenv.engine()); + Function* parent = engine->builder.GetInsertBlock()->getParent(); + IfRecord* rec = if_stack.top(); - // Emit then block for this condition - appendBlock(engine, parent, thenBB); - Value* thenV = llVal(resp_compile(cenv, then)); - engine->builder.CreateBr(mergeBB); + rec->thenV = llVal(thenV); - appendBlock(engine, parent, nextBB); + // Finish then block + engine->builder.CreateBr(rec->mergeBB); + rec->thenBB = engine->builder.GetInsertBlock(); - Value* elseV = NULL; - if (aelse) { - elseV = llVal(resp_compile(cenv, aelse)); - engine->builder.CreateBr(mergeBB); + // Start else block + appendBlock(engine, parent, rec->elseBB); +} + +void +LLVMEngine::compileIfElse(CEnv& cenv, CVal elseV) +{ + LLVMEngine* engine = reinterpret_cast<LLVMEngine*>(cenv.engine()); + IfRecord* rec = if_stack.top(); + + rec->elseV = llVal(elseV); + + if (elseV) { + engine->builder.CreateBr(rec->mergeBB); } else { engine->builder.CreateUnreachable(); } - // Emit end of final else block - BasicBlock* elseBB = engine->builder.GetInsertBlock(); + rec->elseBB = engine->builder.GetInsertBlock(); +} + +CVal +LLVMEngine::compileIfEnd(CEnv& cenv) +{ + LLVMEngine* engine = reinterpret_cast<LLVMEngine*>(cenv.engine()); + Function* parent = engine->builder.GetInsertBlock()->getParent(); + IfRecord* rec = if_stack.top(); // Emit merge block (Phi node) - appendBlock(engine, parent, mergeBB); - PHINode* pn = engine->builder.CreatePHI(llType(type), "ifval"); - pn->addIncoming(thenV, thenBB); - if (elseV) - pn->addIncoming(elseV, elseBB); + appendBlock(engine, parent, rec->mergeBB); + PHINode* pn = engine->builder.CreatePHI(rec->type, "ifval"); + pn->addIncoming(rec->thenV, rec->thenBB); + if (rec->elseV) + pn->addIncoming(rec->elseV, rec->elseBB); + + if_stack.pop(); + delete rec; return pn; } |