From f8e16165e6666fceaac66c777851a0f99ba8f5fc Mon Sep 17 00:00:00 2001 From: David Robillard Date: Sat, 20 Jun 2009 18:05:24 +0000 Subject: Better abstraction for low-level code generation stuff (make repl and eval not backend specific). git-svn-id: http://svn.drobilla.net/resp/tuplr@132 ad02d1e2-f140-0410-9f75-f8b11f17cedd --- llvm.cpp | 162 ++++++++++++++++++++++++++++++++------------------------------- 1 file changed, 82 insertions(+), 80 deletions(-) diff --git a/llvm.cpp b/llvm.cpp index baa7ed3..f472893 100644 --- a/llvm.cpp +++ b/llvm.cpp @@ -35,39 +35,28 @@ using namespace llvm; using namespace std; using boost::format; -inline Value* LLVal(CValue v) { return static_cast(v); } -inline Function* LLFunc(CFunction f) { return static_cast(f); } - -struct LLVMEngine { - LLVMEngine(); - Module* module; - ExecutionEngine* engine; - IRBuilder<> builder; -}; +inline Value* llVal(CValue v) { return static_cast(v); } +inline Function* llFunc(CFunction f) { return static_cast(f); } static const Type* -lltype(const AType* t) +llType(const AType* t) { - switch (t->kind) { - case AType::VAR: - throw Error(t->loc, (format("non-compilable type `%1%'") % t->str()).str()); - return NULL; - case AType::PRIM: + if (t->kind == AType::PRIM) { if (t->at(0)->str() == "Bool") return Type::Int1Ty; if (t->at(0)->str() == "Int") return Type::Int32Ty; if (t->at(0)->str() == "Float") return Type::FloatTy; throw Error(t->loc, string("Unknown primitive type `") + t->str() + "'"); - case AType::EXPR: - if (t->at(0)->str() == "Pair") { - vector types; - for (size_t i = 1; i < t->size(); ++i) - types.push_back(lltype(t->at(i)->to())); - return PointerType::get(StructType::get(types, false), 0); - } } - return NULL; // not reached + return NULL; // non-primitive type } +struct LLVMEngine { + LLVMEngine(); + Module* module; + ExecutionEngine* engine; + IRBuilder<> builder; +}; + static LLVMEngine* llengine(CEnv& cenv) { @@ -148,33 +137,33 @@ LITERAL(int32_t, "Int", ConstantInt::get(Type::Int32Ty, val, true)) LITERAL(float, "Float", ConstantFP::get(Type::FloatTy, val)) LITERAL(bool, "Bool", ConstantInt::get(Type::Int1Ty, val, false)) -static Function* -compileFunction(CEnv& cenv, const std::string& name, const Type* retT, const ATuple& protT, - Cursor loc, const vector argNames=vector()) +CFunction +startFunction(CEnv& cenv, const std::string& name, const AType* retT, const ATuple& argsT, + const vector argNames=vector()) { Function::LinkageTypes linkage = Function::ExternalLinkage; vector cprot; - for (size_t i = 0; i < protT.size(); ++i) { - AType* at = protT.at(i)->as(); - THROW_IF(!lltype(at), protT.at(i)->loc, "function parameter is untyped") - cprot.push_back(lltype(at)); + for (size_t i = 0; i < argsT.size(); ++i) { + AType* at = argsT.at(i)->as(); + THROW_IF(!llType(at), Cursor(), "function parameter is untyped") + cprot.push_back(llType(at)); } - THROW_IF(!retT, loc, "function return is untyped"); - FunctionType* fT = FunctionType::get(static_cast(retT), cprot, false); + THROW_IF(!llType(retT), Cursor(), "function return is untyped"); + FunctionType* fT = FunctionType::get(llType(retT), cprot, false); Function* f = Function::Create(fT, linkage, name, llengine(cenv)->module); if (f->getName() != name) { cenv.out << "DIFFERENT NAME: " << f->getName() << endl; /*f->eraseFromParent(); - throw Error(loc, (format("function `%1%' redefined") % name).str());*/ + throw Error(Cursor(), (format("function `%1%' redefined") % name).str());*/ } // Set argument names in generated code Function::arg_iterator a = f->arg_begin(); if (!argNames.empty()) - for (size_t i = 0; i != protT.size(); ++a, ++i) + for (size_t i = 0; i != argsT.size(); ++a, ++i) a->setName(argNames.at(i)); BasicBlock* bb = BasicBlock::Create("entry", f); @@ -183,6 +172,21 @@ compileFunction(CEnv& cenv, const std::string& name, const Type* retT, const ATu return f; } +void +finishFunction(CEnv& cenv, CFunction f, CValue ret) +{ + Value* retVal = llVal(ret); + llengine(cenv)->builder.CreateRet(retVal); + cenv.optimise(llFunc(f)); +} + +void +eraseFunction(CEnv& cenv, CFunction f) +{ + if (f) + llFunc(f)->eraseFromParent(); +} + /*************************************************************************** * Code Generation * @@ -248,11 +252,9 @@ AClosure::liftCall(CEnv& cenv, const AType& argsT) // Write function declaration const string name = (this->name == "") ? cenv.gensym("_fn") : this->name; - Function* f = compileFunction(cenv, name, - lltype(thisType->at(thisType->size()-1)->to()), - *protT, loc, argNames); - - llvm::IRBuilder<>& builder = llengine(cenv)->builder; + Function* f = llFunc(startFunction(cenv, name, + thisType->at(thisType->size()-1)->to(), + *protT, argNames)); cenv.push(); Subst oldSubst = cenv.tsubst; @@ -273,7 +275,7 @@ AClosure::liftCall(CEnv& cenv, const AType& argsT) size_t i = 0; for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p, ++i) { AType* t = protT->at(i)->as(); - const Type* lt = lltype(t); + const Type* lt = llType(t); THROW_IF(!lt, loc, "untyped parameter\n"); cenv.def((*p)->as(), *p, t, &*a); #ifdef EXPLICIT_STACK_FRAMES @@ -288,7 +290,7 @@ AClosure::liftCall(CEnv& cenv, const AType& argsT) for (size_t i = 0; i < size(); ++i) { ADefinition* def = at(i)->to(); if (def) { - const Type* lt = lltype(cenv.type(def->at(2))); + const Type* lt = llType(cenv.type(def->at(2))); THROW_IF(!lt, loc, "untyped definition\n"); types.push_back(lt); s += std::max(lt->getPrimitiveSizeInBits(), unsigned(8)); @@ -302,7 +304,7 @@ AClosure::liftCall(CEnv& cenv, const AType& argsT) Value* tag = ConstantInt::get(Type::Int8Ty, (uint8_t)GC::TAG_FRAME); Value* frameSize = ConstantInt::get(Type::Int32Ty, s / 8); - Value* frame = builder.CreateCall2(LLVal(cenv.alloc), frameSize, tag, "frame"); + Value* frame = builder.CreateCall2(llVal(cenv.alloc), frameSize, tag, "frame"); Value* framePtr = builder.CreateBitCast(frame, framePtrT, "frameptr"); @@ -322,8 +324,7 @@ AClosure::liftCall(CEnv& cenv, const AType& argsT) CValue retVal = NULL; for (size_t i = 2; i < size(); ++i) retVal = cenv.compile(at(i)); - builder.CreateRet(LLVal(retVal)); // Finish function - cenv.optimise(LLFunc(f)); + finishFunction(cenv, f, retVal); } catch (Error& e) { f->eraseFromParent(); // Error reading body, remove function cenv.pop(); @@ -372,7 +373,7 @@ ACall::compile(CEnv& cenv) vector types; for (size_t i = 1; i < size(); ++i) { protT.push_back(cenv.type(at(i))); - types.push_back(lltype(cenv.type(at(i)))); + types.push_back(llType(cenv.type(at(i)))); } TEnv::GenericTypes::const_iterator gt = cenv.tenv.genericTypes.find(c); @@ -383,7 +384,7 @@ ACall::compile(CEnv& cenv) vector params(size() - 1); for (size_t i = 0; i < types.size(); ++i) - params[i] = LLVal(cenv.compile(at(i+1))); + params[i] = llVal(cenv.compile(at(i+1))); return llengine(cenv)->builder.CreateCall(f, params.begin(), params.end()); } @@ -418,7 +419,7 @@ AIf::compile(CEnv& cenv) BasicBlock* nextBB = NULL; Branches branches; for (size_t i = 1; i < size() - 1; i += 2) { - Value* condV = LLVal(cenv.compile(at(i))); + Value* condV = llVal(cenv.compile(at(i))); BasicBlock* thenBB = BasicBlock::Create((format("then%1%") % ((i+1)/2)).str()); nextBB = BasicBlock::Create((format("else%1%") % ((i+1)/2)).str()); @@ -428,7 +429,7 @@ AIf::compile(CEnv& cenv) // Emit then block for this condition parent->getBasicBlockList().push_back(thenBB); llengine(cenv)->builder.SetInsertPoint(thenBB); - Value* thenV = LLVal(cenv.compile(at(i+1))); + Value* thenV = llVal(cenv.compile(at(i+1))); llengine(cenv)->builder.CreateBr(mergeBB); branches.push_back(make_pair(thenV, llengine(cenv)->builder.GetInsertBlock())); @@ -438,14 +439,14 @@ AIf::compile(CEnv& cenv) // Emit final else block llengine(cenv)->builder.SetInsertPoint(nextBB); - Value* elseV = LLVal(cenv.compile(at(size() - 1))); + Value* elseV = llVal(cenv.compile(at(size() - 1))); llengine(cenv)->builder.CreateBr(mergeBB); branches.push_back(make_pair(elseV, llengine(cenv)->builder.GetInsertBlock())); // Emit merge block (Phi node) parent->getBasicBlockList().push_back(mergeBB); llengine(cenv)->builder.SetInsertPoint(mergeBB); - PHINode* pn = llengine(cenv)->builder.CreatePHI(lltype(cenv.type(this)), "ifval"); + PHINode* pn = llengine(cenv)->builder.CreatePHI(llType(cenv.type(this)), "ifval"); FOREACH(Branches::iterator, i, branches) pn->addIncoming(i->first, i->second); @@ -456,8 +457,8 @@ AIf::compile(CEnv& cenv) CValue APrimitive::compile(CEnv& cenv) { - Value* a = LLVal(cenv.compile(at(1))); - Value* b = LLVal(cenv.compile(at(2))); + Value* a = llVal(cenv.compile(at(1))); + Value* b = llVal(cenv.compile(at(2))); bool isFloat = cenv.type(at(1))->str() == "Float"; const string n = at(0)->to()->str(); @@ -474,7 +475,7 @@ APrimitive::compile(CEnv& cenv) if (op != 0) { Value* val = llengine(cenv)->builder.CreateBinOp(op, a, b); for (size_t i = 3; i < size(); ++i) - val = llengine(cenv)->builder.CreateBinOp(op, val, LLVal(cenv.compile(at(i)))); + val = llengine(cenv)->builder.CreateBinOp(op, val, llVal(cenv.compile(at(i)))); return val; } @@ -502,14 +503,19 @@ APrimitive::compile(CEnv& cenv) ***************************************************************************/ const string -call(AType* retT, void* fp) +call(CEnv& cenv, CFunction f, AType* retT) { + void* fp = llengine(cenv)->engine->getPointerToFunction(llFunc(f)); + const Type* t = llType(retT); + THROW_IF(!fp, Cursor(), "unable to get function pointer"); + THROW_IF(!t, Cursor(), "function with non-primitive return type called"); + std::stringstream ss; - if (lltype(retT) == Type::Int32Ty) + if (t == Type::Int32Ty) ss << ((int32_t (*)())fp)(); - else if (lltype(retT) == Type::FloatTy) + else if (t == Type::FloatTy) ss << showpoint << ((float (*)())fp)(); - else if (lltype(retT) == Type::Int1Ty) + else if (t == Type::Int1Ty) ss << (((bool (*)())fp)() ? "#t" : "#f"); else ss << ((void* (*)())fp)(); @@ -548,23 +554,19 @@ eval(CEnv& cenv, const string& name, istream& is) } } - const Type* ctype = lltype(resultType); - THROW_IF(!ctype, cursor, "body has non-compilable type") + THROW_IF(!llType(resultType), cursor, "body has non-compilable type") // Create function for top-level of program - Function* f = compileFunction(cenv, "main", ctype, ATuple(cursor), cursor); + CFunction f = startFunction(cenv, "main", resultType, ATuple(cursor)); // Compile all expressions into it - Value* val = NULL; + CValue val = NULL; for (list< pair >::const_iterator i = exprs.begin(); i != exprs.end(); ++i) - val = LLVal(cenv.compile(i->second)); + val = cenv.compile(i->second); - // Finish function - llengine(cenv)->builder.CreateRet(val); - cenv.optimise(f); + finishFunction(cenv, f, val); - cenv.out << call(resultType, llengine(cenv)->engine->getPointerToFunction(f)) - << " : " << resultType << endl; + cenv.out << call(cenv, f, resultType) << " : " << resultType << endl; Object::pool.collect(Object::pool.roots()); @@ -603,20 +605,20 @@ repl(CEnv& cenv) body->lift(cenv); - if (lltype(bodyT)) { + CFunction f = NULL; + try { // Create anonymous function to insert code into - Function* f = compileFunction(cenv, cenv.gensym("_repl"), lltype(bodyT), ATuple(cursor), cursor); - try { - Value* retVal = LLVal(cenv.compile(body)); - llengine(cenv)->builder.CreateRet(retVal); // Finish function - cenv.optimise(f); - } catch (Error& e) { - f->eraseFromParent(); // Error reading body, remove function - throw e; - } - cenv.out << call(bodyT, llengine(cenv)->engine->getPointerToFunction(f)); - } else { - cenv.out << "; " << cenv.compile(body); + f = startFunction(cenv, cenv.gensym("_repl"), bodyT, ATuple(cursor)); + CValue retVal = cenv.compile(body); + finishFunction(cenv, f, retVal); + cenv.out << call(cenv, f, bodyT); + } catch (Error& e) { + ADefinition* def = body->to(); + if (def) + cenv.out << def->sym(); + else + cenv.out << "?"; + eraseFunction(cenv, f); } cenv.out << " : " << cenv.type(body) << endl; -- cgit v1.2.1