aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/compile.cpp12
-rw-r--r--src/llvm.cpp110
2 files changed, 92 insertions, 30 deletions
diff --git a/src/compile.cpp b/src/compile.cpp
index 44017ec..ba8cc5e 100644
--- a/src/compile.cpp
+++ b/src/compile.cpp
@@ -167,11 +167,18 @@ compile_fn(CEnv& cenv, const ATuple* fn) throw()
static CVal
compile_if(CEnv& cenv, const ATuple* aif) throw()
{
+ const AST* cond = aif->list_ref(1);
+ const AST* then = aif->list_ref(2);
const AST* aelse = NULL;
if (*aif->list_last() != *cenv.penv.sym("__unreachable"))
aelse = aif->list_ref(3);
- return cenv.engine()->compileIf(cenv, aif->list_ref(1), aif->list_ref(2), aelse);
+ const AST* type = cenv.type(then);
+
+ cenv.engine()->compileIfStart(cenv, cond, type);
+ cenv.engine()->compileIfThen(cenv, resp_compile(cenv, then));
+ cenv.engine()->compileIfElse(cenv, resp_compile(cenv, aelse));
+ return cenv.engine()->compileIfEnd(cenv);
}
static CVal
@@ -214,6 +221,9 @@ compile_call(CEnv& cenv, const ATuple* call) throw()
CVal
resp_compile(CEnv& cenv, const AST* ast) throw()
{
+ if (!ast)
+ return NULL;
+
switch (ast->tag()) {
case T_UNKNOWN:
return NULL;
diff --git a/src/llvm.cpp b/src/llvm.cpp
index 2c80ba7..06a895a 100644
--- a/src/llvm.cpp
+++ b/src/llvm.cpp
@@ -26,6 +26,7 @@
#include <sstream>
#include <string>
#include <vector>
+#include <stack>
#include <boost/format.hpp>
@@ -52,6 +53,25 @@ using namespace llvm;
using namespace std;
using boost::format;
+struct IfRecord {
+ IfRecord(LLVMContext& context, unsigned labelIndex, const Type* t)
+ : type(t)
+ , mergeBB(BasicBlock::Create(context, "endif"))
+ , thenBB(BasicBlock::Create(context, (format("then%1%") % labelIndex).str()))
+ , elseBB(BasicBlock::Create(context, (format("else%1%") % labelIndex).str()))
+ , thenV(NULL)
+ , elseV(NULL)
+ {}
+
+ const Type* type;
+ BasicBlock* mergeBB;
+ BasicBlock* thenBB;
+ BasicBlock* elseBB;
+
+ Value* thenV;
+ Value* elseV;
+};
+
/** LLVM Engine (Compiler and JIT) */
struct LLVMEngine : public Engine {
LLVMEngine();
@@ -68,7 +88,10 @@ struct LLVMEngine : public Engine {
CVal compileDot(CEnv& cenv, CVal tup, int32_t index);
CVal compileGlobalSet(CEnv& cenv, const string& s, CVal v, const AST* t);
CVal compileGlobalGet(CEnv& cenv, const string& s, CVal v);
- CVal compileIf(CEnv& cenv, const AST* cond, const AST* then, const AST* aelse);
+ void compileIfStart(CEnv& cenv, const AST* cond, const AST* type);
+ void compileIfThen(CEnv& cenv, CVal thenV);
+ void compileIfElse(CEnv& cenv, CVal elseV);
+ CVal compileIfEnd(CEnv& cenv);
CVal compileLiteral(CEnv& cenv, const AST* lit);
CVal compilePrimitive(CEnv& cenv, const ATuple* prim);
CVal compileString(CEnv& cenv, const char* str);
@@ -102,6 +125,9 @@ private:
typedef std::map<const std::string, const Type*> CTypes;
CTypes compiledTypes;
+ typedef std::stack<IfRecord*> IfStack;
+ IfStack if_stack;
+
unsigned labelIndex;
};
@@ -165,7 +191,6 @@ LLVMEngine::llType(const AST* t, const char* name)
CTypes::const_iterator i = compiledTypes.find(sym);
if (i != compiledTypes.end())
return i->second;
-
}
if (!AType::is_expr(t))
@@ -411,45 +436,72 @@ LLVMEngine::eraseFn(CEnv& cenv, CFunc f)
llFunc(f)->eraseFromParent();
}
-CVal
-LLVMEngine::compileIf(CEnv& cenv, const AST* cond, const AST* then, const AST* aelse)
+void
+LLVMEngine::compileIfStart(CEnv& cenv, const AST* cond, const AST* type)
{
- LLVMEngine* engine = reinterpret_cast<LLVMEngine*>(cenv.engine());
- BasicBlock* mergeBB = BasicBlock::Create(context, "endif");
- Function* parent = engine->builder.GetInsertBlock()->getParent();
- BasicBlock* thenBB = BasicBlock::Create(context, (format("then%1%") % labelIndex).str());
- BasicBlock* nextBB = BasicBlock::Create(context, (format("else%1%") % labelIndex).str());
+ LLVMEngine* engine = reinterpret_cast<LLVMEngine*>(cenv.engine());
+ Function* parent = engine->builder.GetInsertBlock()->getParent();
+ IfRecord* rec = new IfRecord(context, ++labelIndex, llType(type));
- const AST* type = cenv.type(then);
-
- ++labelIndex;
+ if_stack.push(rec);
- engine->builder.CreateCondBr(llVal(resp_compile(cenv, cond)), thenBB, nextBB);
+ engine->builder.CreateCondBr(llVal(resp_compile(cenv, cond)),
+ rec->thenBB, rec->elseBB);
+
+ // Start then block
+ appendBlock(engine, parent, rec->thenBB);
+}
+
+void
+LLVMEngine::compileIfThen(CEnv& cenv, CVal thenV)
+{
+ LLVMEngine* engine = reinterpret_cast<LLVMEngine*>(cenv.engine());
+ Function* parent = engine->builder.GetInsertBlock()->getParent();
+ IfRecord* rec = if_stack.top();
- // Emit then block for this condition
- appendBlock(engine, parent, thenBB);
- Value* thenV = llVal(resp_compile(cenv, then));
- engine->builder.CreateBr(mergeBB);
+ rec->thenV = llVal(thenV);
- appendBlock(engine, parent, nextBB);
+ // Finish then block
+ engine->builder.CreateBr(rec->mergeBB);
+ rec->thenBB = engine->builder.GetInsertBlock();
- Value* elseV = NULL;
- if (aelse) {
- elseV = llVal(resp_compile(cenv, aelse));
- engine->builder.CreateBr(mergeBB);
+ // Start else block
+ appendBlock(engine, parent, rec->elseBB);
+}
+
+void
+LLVMEngine::compileIfElse(CEnv& cenv, CVal elseV)
+{
+ LLVMEngine* engine = reinterpret_cast<LLVMEngine*>(cenv.engine());
+ IfRecord* rec = if_stack.top();
+
+ rec->elseV = llVal(elseV);
+
+ if (elseV) {
+ engine->builder.CreateBr(rec->mergeBB);
} else {
engine->builder.CreateUnreachable();
}
- // Emit end of final else block
- BasicBlock* elseBB = engine->builder.GetInsertBlock();
+ rec->elseBB = engine->builder.GetInsertBlock();
+}
+
+CVal
+LLVMEngine::compileIfEnd(CEnv& cenv)
+{
+ LLVMEngine* engine = reinterpret_cast<LLVMEngine*>(cenv.engine());
+ Function* parent = engine->builder.GetInsertBlock()->getParent();
+ IfRecord* rec = if_stack.top();
// Emit merge block (Phi node)
- appendBlock(engine, parent, mergeBB);
- PHINode* pn = engine->builder.CreatePHI(llType(type), "ifval");
- pn->addIncoming(thenV, thenBB);
- if (elseV)
- pn->addIncoming(elseV, elseBB);
+ appendBlock(engine, parent, rec->mergeBB);
+ PHINode* pn = engine->builder.CreatePHI(rec->type, "ifval");
+ pn->addIncoming(rec->thenV, rec->thenBB);
+ if (rec->elseV)
+ pn->addIncoming(rec->elseV, rec->elseBB);
+
+ if_stack.pop();
+ delete rec;
return pn;
}