diff options
Diffstat (limited to 'src/llvm.cpp')
-rw-r--r-- | src/llvm.cpp | 66 |
1 files changed, 38 insertions, 28 deletions
diff --git a/src/llvm.cpp b/src/llvm.cpp index 30927d4..bb30cbc 100644 --- a/src/llvm.cpp +++ b/src/llvm.cpp @@ -52,14 +52,15 @@ llType(const AType* t) if (t->head()->str() == "Float") return Type::FloatTy; throw Error(t->loc, string("Unknown primitive type `") + t->str() + "'"); } else if (t->kind == AType::EXPR && t->head()->str() == "Fn") { - const AType* retT = t->at(2)->as<const AType*>(); + AType::const_iterator i = t->begin(); + const ATuple* protT = (*++i)->to<const ATuple*>(); + const AType* retT = (*i)->as<const AType*>(); if (!llType(retT)) return NULL; vector<const Type*> cprot; - const ATuple* prot = t->at(1)->to<const ATuple*>(); - for (size_t i = 0; i < prot->size(); ++i) { - const AType* at = prot->at(i)->to<const AType*>(); + FOREACHP(ATuple::const_iterator, i, protT) { + const AType* at = (*i)->to<const AType*>(); const Type* lt = llType(at); if (!lt) return NULL; @@ -126,8 +127,8 @@ struct LLVMEngine : public Engine { // Set argument names in generated code Function::arg_iterator a = f->arg_begin(); if (!argNames.empty()) - for (size_t i = 0; i != argsT.size(); ++a, ++i) - a->setName(argNames.at(i)); + for (vector<string>::const_iterator i = argNames.begin(); i != argNames.end(); ++a, ++i) + a->setName(*i); BasicBlock* bb = BasicBlock::Create("entry", f); builder.SetInsertPoint(bb); @@ -243,16 +244,16 @@ LLVMEngine::compileFunction(CEnv& cenv, AFn* fn, const AType& argsT) if (f) return f; - ATuple* protT = thisType->at(1)->as<ATuple*>(); + ATuple* protT = thisType->prot(); vector<string> argNames; - for (size_t i = 0; i < fn->prot()->size(); ++i) - argNames.push_back(fn->prot()->at(i)->str()); + for (ATuple::const_iterator i = fn->prot()->begin(); i != fn->prot()->end(); ++i) + argNames.push_back((*i)->str()); // Write function declaration const string name = (fn->name == "") ? cenv.penv.gensymstr("_fn") : fn->name; f = llFunc(cenv.engine()->startFunction(cenv, name, - thisType->at(thisType->size()-1)->to<AType*>(), + thisType->last()->to<AType*>(), *protT, argNames)); cenv.push(); @@ -261,10 +262,10 @@ LLVMEngine::compileFunction(CEnv& cenv, AFn* fn, const AType& argsT) // Bind argument values in CEnv vector<Value*> args; - AFn::const_iterator p = fn->prot()->begin(); - size_t i = 0; - for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p, ++i) { - AType* t = protT->at(i)->as<AType*>(); + AFn::const_iterator p = fn->prot()->begin(); + ATuple::const_iterator pT = protT->begin(); + for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p, ++pT) { + AType* t = (*pT)->as<AType*>(); const Type* lt = llType(t); THROW_IF(!lt, fn->loc, "untyped parameter\n"); cenv.def((*p)->as<ASymbol*>(), *p, t, &*a); @@ -274,9 +275,9 @@ LLVMEngine::compileFunction(CEnv& cenv, AFn* fn, const AType& argsT) try { fn->impls.push_back(make_pair(thisType, f)); CVal retVal = NULL; - for (size_t i = 2; i < fn->size(); ++i) - retVal = fn->at(i)->compile(cenv); - cenv.engine()->finishFunction(cenv, f, cenv.type(fn->at(fn->size() - 1)), retVal); + for (AFn::iterator i = fn->begin() + 2; i != fn->end(); ++i) + retVal = (*i)->compile(cenv); + cenv.engine()->finishFunction(cenv, f, cenv.type(fn->last()), retVal); } catch (Error& e) { f->eraseFromParent(); // Error reading body, remove function cenv.pop(); @@ -296,28 +297,35 @@ LLVMEngine::compileIf(CEnv& cenv, AIf* aif) BasicBlock* mergeBB = BasicBlock::Create("endif"); BasicBlock* nextBB = NULL; Branches branches; - for (size_t i = 1; i < aif->size() - 1; i += 2) { - Value* condV = llVal(aif->at(i)->compile(cenv)); - BasicBlock* thenBB = BasicBlock::Create((format("then%1%") % ((i+1)/2)).str()); + size_t idx = 1; + for (AIf::iterator i = aif->begin() + 1; ; ++i, idx += 2) { + AIf::iterator next = i; + if (++next == aif->end()) + break; - nextBB = BasicBlock::Create((format("else%1%") % ((i+1)/2)).str()); + Value* condV = llVal((*i)->compile(cenv)); + BasicBlock* thenBB = BasicBlock::Create((format("then%1%") % ((idx+1)/2)).str()); + + nextBB = BasicBlock::Create((format("else%1%") % ((idx+1)/2)).str()); engine->builder.CreateCondBr(condV, thenBB, nextBB); // Emit then block for this condition parent->getBasicBlockList().push_back(thenBB); engine->builder.SetInsertPoint(thenBB); - Value* thenV = llVal(aif->at(i + 1)->compile(cenv)); + Value* thenV = llVal((*next)->compile(cenv)); engine->builder.CreateBr(mergeBB); branches.push_back(make_pair(thenV, engine->builder.GetInsertBlock())); parent->getBasicBlockList().push_back(nextBB); engine->builder.SetInsertPoint(nextBB); + + i = next; // jump 2 each iteration (to the next predicate) } // Emit final else block engine->builder.SetInsertPoint(nextBB); - Value* elseV = llVal(aif->at(aif->size() - 1)->compile(cenv)); + Value* elseV = llVal(aif->last()->compile(cenv)); engine->builder.CreateBr(mergeBB); branches.push_back(make_pair(elseV, engine->builder.GetInsertBlock())); @@ -335,10 +343,12 @@ LLVMEngine::compileIf(CEnv& cenv, AIf* aif) CVal LLVMEngine::compilePrimitive(CEnv& cenv, APrimitive* prim) { + APrimitive::iterator i = prim->begin(); + LLVMEngine* engine = reinterpret_cast<LLVMEngine*>(cenv.engine()); - Value* a = llVal(prim->at(1)->compile(cenv)); - Value* b = llVal(prim->at(2)->compile(cenv)); - bool isFloat = cenv.type(prim->at(1))->str() == "Float"; + bool isFloat = cenv.type(*++i)->str() == "Float"; + Value* a = llVal((*i++)->compile(cenv)); + Value* b = llVal((*i++)->compile(cenv)); const string n = prim->head()->to<ASymbol*>()->str(); // Binary arithmetic operations @@ -353,8 +363,8 @@ LLVMEngine::compilePrimitive(CEnv& cenv, APrimitive* prim) if (n == "%") op = isFloat ? Instruction::FRem : Instruction::SRem; if (op != 0) { Value* val = engine->builder.CreateBinOp(op, a, b); - for (size_t i = 3; i < prim->size(); ++i) - val = engine->builder.CreateBinOp(op, val, llVal(prim->at(i)->compile(cenv))); + while (i != prim->end()) + val = engine->builder.CreateBinOp(op, val, llVal((*i++)->compile(cenv))); return val; } |