diff options
Diffstat (limited to 'src/llvm.cpp')
-rw-r--r-- | src/llvm.cpp | 74 |
1 files changed, 67 insertions, 7 deletions
diff --git a/src/llvm.cpp b/src/llvm.cpp index 8b4deaa..24846cc 100644 --- a/src/llvm.cpp +++ b/src/llvm.cpp @@ -115,6 +115,7 @@ struct LLVMEngine : public Engine { return PointerType::get(FunctionType::get(llType(retT), cprot, false), 0); } else if (t->kind == AType::EXPR && isupper(t->head()->str()[0])) { vector<const Type*> ctypes; + ctypes.push_back(PointerType::get(Type::getInt8Ty(context), NULL)); // RTTI for (AType::const_iterator i = t->begin() + 1; i != t->end(); ++i) { const Type* lt = llType((*i)->to<const AType*>()); if (!lt) @@ -191,12 +192,13 @@ struct LLVMEngine : public Engine { return builder.CreateCall(llFunc(f), llArgs.begin(), llArgs.end()); } - CVal compileTup(CEnv& cenv, const AType* type, const vector<CVal>& fields); + CVal compileTup(CEnv& cenv, const AType* type, CVal rtti, const vector<CVal>& fields); CVal compileDot(CEnv& cenv, CVal tup, int32_t index); CVal compileLiteral(CEnv& cenv, const AST* lit); CVal compileString(CEnv& cenv, const char* str); CVal compilePrimitive(CEnv& cenv, const APrimitive* prim); CVal compileIf(CEnv& cenv, const AIf* aif); + CVal compileMatch(CEnv& cenv, const AMatch* match); CVal compileGlobal(CEnv& cenv, const AType* type, const string& sym, CVal val); CVal getGlobal(CEnv& cenv, const string& sym, CVal val); @@ -271,10 +273,10 @@ bitsToBytes(size_t bits) } CVal -LLVMEngine::compileTup(CEnv& cenv, const AType* type, const vector<CVal>& fields) +LLVMEngine::compileTup(CEnv& cenv, const AType* type, CVal rtti, const vector<CVal>& fields) { // Find size of memory required - size_t s = 0; + size_t s = engine->getTargetData()->getTypeSizeInBits(PointerType::get(Type::getInt8Ty(context), NULL)); assert(type->begin() != type->end()); for (AType::const_iterator i = type->begin() + 1; i != type->end(); ++i) s += engine->getTargetData()->getTypeSizeInBits(llType((*i)->as<AType*>())); @@ -285,7 +287,9 @@ LLVMEngine::compileTup(CEnv& cenv, const AType* type, const vector<CVal>& fields Value* structPtr = builder.CreateBitCast(mem, llType(type), "tup"); // Set struct fields - size_t i = 0; + if (rtti) + builder.CreateStore((Value*)rtti, builder.CreateStructGEP(structPtr, 0, "rtti")); + size_t i = 1; for (vector<CVal>::const_iterator f = fields.begin(); f != fields.end(); ++f, ++i) { builder.CreateStore(llVal(*f), builder.CreateStructGEP(structPtr, i, (format("tup%1%") % i).str().c_str())); @@ -297,7 +301,7 @@ LLVMEngine::compileTup(CEnv& cenv, const AType* type, const vector<CVal>& fields CVal LLVMEngine::compileDot(CEnv& cenv, CVal tup, int32_t index) { - Value* ptr = builder.CreateStructGEP(llVal(tup), index, "dotPtr"); + Value* ptr = builder.CreateStructGEP(llVal(tup), index + 1, "dotPtr"); // +1 to skip RTTI return builder.CreateLoad(ptr, 0, "dotVal"); } @@ -382,7 +386,6 @@ LLVMEngine::compileIf(CEnv& cenv, const AIf* aif) } // Emit final else block - engine->builder.SetInsertPoint(nextBB); Value* elseV = llVal(aif->last()->compile(cenv)); engine->builder.CreateBr(mergeBB); branches.push_back(make_pair(elseV, engine->builder.GetInsertBlock())); @@ -399,6 +402,61 @@ LLVMEngine::compileIf(CEnv& cenv, const AIf* aif) } CVal +LLVMEngine::compileMatch(CEnv& cenv, const AMatch* match) +{ + typedef vector< pair<Value*, BasicBlock*> > Branches; + Value* matchee = llVal((*(match->begin() + 1))->compile(cenv)); + Value* rttiPtr = builder.CreateStructGEP(matchee, 0, "matchRTTIPtr"); + Value* rtti = builder.CreateLoad(rttiPtr, 0, "matchRTTI"); + + LLVMEngine* engine = reinterpret_cast<LLVMEngine*>(cenv.engine()); + Function* parent = engine->builder.GetInsertBlock()->getParent(); + BasicBlock* mergeBB = BasicBlock::Create(context, "endmatch"); + BasicBlock* nextBB = NULL; + Branches branches; + + size_t idx = 1; + for (AMatch::const_iterator i = match->begin() + 2; i != match->end(); ++idx) { + const AST* pat = *i++; + const AST* body = *i++; + const ASymbol* sym = pat->to<const ATuple*>()->head()->as<const ASymbol*>(); + const AType* patT = tup<AType>(Cursor(), const_cast<ASymbol*>(sym), 0); + + Value* typeV = llVal(patT->compile(cenv)); + Value* condV = engine->builder.CreateICmp(CmpInst::ICMP_EQ, rtti, typeV); + BasicBlock* thenBB = BasicBlock::Create(context, (format("case%1%") % ((idx+1)/2)).str()); + + nextBB = BasicBlock::Create(context, (format("otherwise%1%") % ((idx+1)/2)).str()); + + engine->builder.CreateCondBr(condV, thenBB, nextBB); + + // Emit then block for this condition + parent->getBasicBlockList().push_back(thenBB); + engine->builder.SetInsertPoint(thenBB); + Value* thenV = llVal(body->compile(cenv)); + engine->builder.CreateBr(mergeBB); + branches.push_back(make_pair(thenV, engine->builder.GetInsertBlock())); + + parent->getBasicBlockList().push_back(nextBB); + engine->builder.SetInsertPoint(nextBB); + } + + // Emit final else block (FIXME: n/a, what to do here?) + engine->builder.CreateBr(mergeBB); + branches.push_back(make_pair(Constant::getNullValue(llType(cenv.type(match))), engine->builder.GetInsertBlock())); + + // Emit merge block (Phi node) + parent->getBasicBlockList().push_back(mergeBB); + engine->builder.SetInsertPoint(mergeBB); + PHINode* pn = engine->builder.CreatePHI(llType(cenv.type(match)), "mergeval"); + + FOREACH(Branches::iterator, i, branches) + pn->addIncoming(i->first, i->second); + + return pn; +} + +CVal LLVMEngine::compilePrimitive(CEnv& cenv, const APrimitive* prim) { APrimitive::const_iterator i = prim->begin(); @@ -451,7 +509,9 @@ LLVMEngine::compileGlobal(CEnv& cenv, const AType* type, const string& sym, CVal GlobalVariable* global = new GlobalVariable(*module, llType(type), false, GlobalValue::ExternalLinkage, Constant::getNullValue(llType(type)), sym); - engine->builder.CreateStore(llVal(val), global); + Value* valPtr = builder.CreateBitCast(llVal(val), llType(type), "globalPtr"); + + engine->builder.CreateStore(valPtr, global); return global; } |