aboutsummaryrefslogtreecommitdiffstats
path: root/src/llvm.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/llvm.cpp')
-rw-r--r--src/llvm.cpp66
1 files changed, 38 insertions, 28 deletions
diff --git a/src/llvm.cpp b/src/llvm.cpp
index 30927d4..bb30cbc 100644
--- a/src/llvm.cpp
+++ b/src/llvm.cpp
@@ -52,14 +52,15 @@ llType(const AType* t)
if (t->head()->str() == "Float") return Type::FloatTy;
throw Error(t->loc, string("Unknown primitive type `") + t->str() + "'");
} else if (t->kind == AType::EXPR && t->head()->str() == "Fn") {
- const AType* retT = t->at(2)->as<const AType*>();
+ AType::const_iterator i = t->begin();
+ const ATuple* protT = (*++i)->to<const ATuple*>();
+ const AType* retT = (*i)->as<const AType*>();
if (!llType(retT))
return NULL;
vector<const Type*> cprot;
- const ATuple* prot = t->at(1)->to<const ATuple*>();
- for (size_t i = 0; i < prot->size(); ++i) {
- const AType* at = prot->at(i)->to<const AType*>();
+ FOREACHP(ATuple::const_iterator, i, protT) {
+ const AType* at = (*i)->to<const AType*>();
const Type* lt = llType(at);
if (!lt)
return NULL;
@@ -126,8 +127,8 @@ struct LLVMEngine : public Engine {
// Set argument names in generated code
Function::arg_iterator a = f->arg_begin();
if (!argNames.empty())
- for (size_t i = 0; i != argsT.size(); ++a, ++i)
- a->setName(argNames.at(i));
+ for (vector<string>::const_iterator i = argNames.begin(); i != argNames.end(); ++a, ++i)
+ a->setName(*i);
BasicBlock* bb = BasicBlock::Create("entry", f);
builder.SetInsertPoint(bb);
@@ -243,16 +244,16 @@ LLVMEngine::compileFunction(CEnv& cenv, AFn* fn, const AType& argsT)
if (f)
return f;
- ATuple* protT = thisType->at(1)->as<ATuple*>();
+ ATuple* protT = thisType->prot();
vector<string> argNames;
- for (size_t i = 0; i < fn->prot()->size(); ++i)
- argNames.push_back(fn->prot()->at(i)->str());
+ for (ATuple::const_iterator i = fn->prot()->begin(); i != fn->prot()->end(); ++i)
+ argNames.push_back((*i)->str());
// Write function declaration
const string name = (fn->name == "") ? cenv.penv.gensymstr("_fn") : fn->name;
f = llFunc(cenv.engine()->startFunction(cenv, name,
- thisType->at(thisType->size()-1)->to<AType*>(),
+ thisType->last()->to<AType*>(),
*protT, argNames));
cenv.push();
@@ -261,10 +262,10 @@ LLVMEngine::compileFunction(CEnv& cenv, AFn* fn, const AType& argsT)
// Bind argument values in CEnv
vector<Value*> args;
- AFn::const_iterator p = fn->prot()->begin();
- size_t i = 0;
- for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p, ++i) {
- AType* t = protT->at(i)->as<AType*>();
+ AFn::const_iterator p = fn->prot()->begin();
+ ATuple::const_iterator pT = protT->begin();
+ for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p, ++pT) {
+ AType* t = (*pT)->as<AType*>();
const Type* lt = llType(t);
THROW_IF(!lt, fn->loc, "untyped parameter\n");
cenv.def((*p)->as<ASymbol*>(), *p, t, &*a);
@@ -274,9 +275,9 @@ LLVMEngine::compileFunction(CEnv& cenv, AFn* fn, const AType& argsT)
try {
fn->impls.push_back(make_pair(thisType, f));
CVal retVal = NULL;
- for (size_t i = 2; i < fn->size(); ++i)
- retVal = fn->at(i)->compile(cenv);
- cenv.engine()->finishFunction(cenv, f, cenv.type(fn->at(fn->size() - 1)), retVal);
+ for (AFn::iterator i = fn->begin() + 2; i != fn->end(); ++i)
+ retVal = (*i)->compile(cenv);
+ cenv.engine()->finishFunction(cenv, f, cenv.type(fn->last()), retVal);
} catch (Error& e) {
f->eraseFromParent(); // Error reading body, remove function
cenv.pop();
@@ -296,28 +297,35 @@ LLVMEngine::compileIf(CEnv& cenv, AIf* aif)
BasicBlock* mergeBB = BasicBlock::Create("endif");
BasicBlock* nextBB = NULL;
Branches branches;
- for (size_t i = 1; i < aif->size() - 1; i += 2) {
- Value* condV = llVal(aif->at(i)->compile(cenv));
- BasicBlock* thenBB = BasicBlock::Create((format("then%1%") % ((i+1)/2)).str());
+ size_t idx = 1;
+ for (AIf::iterator i = aif->begin() + 1; ; ++i, idx += 2) {
+ AIf::iterator next = i;
+ if (++next == aif->end())
+ break;
- nextBB = BasicBlock::Create((format("else%1%") % ((i+1)/2)).str());
+ Value* condV = llVal((*i)->compile(cenv));
+ BasicBlock* thenBB = BasicBlock::Create((format("then%1%") % ((idx+1)/2)).str());
+
+ nextBB = BasicBlock::Create((format("else%1%") % ((idx+1)/2)).str());
engine->builder.CreateCondBr(condV, thenBB, nextBB);
// Emit then block for this condition
parent->getBasicBlockList().push_back(thenBB);
engine->builder.SetInsertPoint(thenBB);
- Value* thenV = llVal(aif->at(i + 1)->compile(cenv));
+ Value* thenV = llVal((*next)->compile(cenv));
engine->builder.CreateBr(mergeBB);
branches.push_back(make_pair(thenV, engine->builder.GetInsertBlock()));
parent->getBasicBlockList().push_back(nextBB);
engine->builder.SetInsertPoint(nextBB);
+
+ i = next; // jump 2 each iteration (to the next predicate)
}
// Emit final else block
engine->builder.SetInsertPoint(nextBB);
- Value* elseV = llVal(aif->at(aif->size() - 1)->compile(cenv));
+ Value* elseV = llVal(aif->last()->compile(cenv));
engine->builder.CreateBr(mergeBB);
branches.push_back(make_pair(elseV, engine->builder.GetInsertBlock()));
@@ -335,10 +343,12 @@ LLVMEngine::compileIf(CEnv& cenv, AIf* aif)
CVal
LLVMEngine::compilePrimitive(CEnv& cenv, APrimitive* prim)
{
+ APrimitive::iterator i = prim->begin();
+
LLVMEngine* engine = reinterpret_cast<LLVMEngine*>(cenv.engine());
- Value* a = llVal(prim->at(1)->compile(cenv));
- Value* b = llVal(prim->at(2)->compile(cenv));
- bool isFloat = cenv.type(prim->at(1))->str() == "Float";
+ bool isFloat = cenv.type(*++i)->str() == "Float";
+ Value* a = llVal((*i++)->compile(cenv));
+ Value* b = llVal((*i++)->compile(cenv));
const string n = prim->head()->to<ASymbol*>()->str();
// Binary arithmetic operations
@@ -353,8 +363,8 @@ LLVMEngine::compilePrimitive(CEnv& cenv, APrimitive* prim)
if (n == "%") op = isFloat ? Instruction::FRem : Instruction::SRem;
if (op != 0) {
Value* val = engine->builder.CreateBinOp(op, a, b);
- for (size_t i = 3; i < prim->size(); ++i)
- val = engine->builder.CreateBinOp(op, val, llVal(prim->at(i)->compile(cenv)));
+ while (i != prim->end())
+ val = engine->builder.CreateBinOp(op, val, llVal((*i++)->compile(cenv)));
return val;
}