aboutsummaryrefslogtreecommitdiffstats
path: root/src/llvm.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/llvm.cpp')
-rw-r--r--src/llvm.cpp424
1 files changed, 424 insertions, 0 deletions
diff --git a/src/llvm.cpp b/src/llvm.cpp
new file mode 100644
index 0000000..e2f7f1a
--- /dev/null
+++ b/src/llvm.cpp
@@ -0,0 +1,424 @@
+/* Tuplr: A programming language
+ * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net>
+ *
+ * Tuplr is free software: you can redistribute it and/or modify it under
+ * the terms of the GNU Affero General Public License as published by the
+ * Free Software Foundation, either version 3 of the License, or (at your
+ * option) any later version.
+ *
+ * Tuplr is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
+ * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General
+ * Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with Tuplr. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+/** @file
+ * @brief Compile AST to LLVM IR
+ *
+ * Compilation pass functions (lift/compile) that require direct use of LLVM
+ * specific things are implemented here. Generic compilation pass functions
+ * are implemented in compile.cpp.
+ */
+
+#include <map>
+#include <sstream>
+#include <boost/format.hpp>
+#include "llvm/Analysis/Verifier.h"
+#include "llvm/Assembly/AsmAnnotationWriter.h"
+#include "llvm/DerivedTypes.h"
+#include "llvm/ExecutionEngine/ExecutionEngine.h"
+#include "llvm/Instructions.h"
+#include "llvm/Module.h"
+#include "llvm/ModuleProvider.h"
+#include "llvm/PassManager.h"
+#include "llvm/Support/IRBuilder.h"
+#include "llvm/Target/TargetData.h"
+#include "llvm/Transforms/Scalar.h"
+#include "tuplr.hpp"
+
+using namespace llvm;
+using namespace std;
+using boost::format;
+
+/// Cast an engine-neutral CValue handle to an LLVM Value pointer.
+static inline Value*
+llVal(CValue v)
+{
+    return static_cast<Value*>(v);
+}
+/// Cast an engine-neutral CFunction handle to an LLVM Function pointer.
+static inline Function*
+llFunc(CFunction f)
+{
+    return static_cast<Function*>(f);
+}
+
+/** Translate a Tuplr type to the corresponding LLVM type.
+ *
+ * Primitive types map to LLVM primitives ("Nothing" => void, "Bool" => i1,
+ * "Int" => i32, "Float" => float).  An expression type headed by "Fn" maps
+ * to a pointer to a non-variadic LLVM function type built from the
+ * prototype at index 1 and the return type at index 2.
+ *
+ * @param t Type to translate.
+ * @return The LLVM type, or NULL if @a t has no LLVM equivalent here (this
+ * includes function types whose return or parameter types have none).
+ * @throw Error if @a t is a primitive type with an unrecognised name.
+ */
+static const Type*
+llType(const AType* t)
+{
+    if (t->kind == AType::PRIM) {
+        if (t->at(0)->str() == "Nothing") return Type::VoidTy;
+        if (t->at(0)->str() == "Bool")    return Type::Int1Ty;
+        if (t->at(0)->str() == "Int")     return Type::Int32Ty;
+        if (t->at(0)->str() == "Float")   return Type::FloatTy;
+        throw Error(t->loc, string("Unknown primitive type `") + t->str() + "'");
+    } else if (t->kind == AType::EXPR && t->at(0)->str() == "Fn") {
+        // Translate the return type once and reuse the result below
+        const AType* retT  = t->at(2)->as<const AType*>();
+        const Type*  lretT = llType(retT);
+        if (!lretT)
+            return NULL;
+
+        // Translate every parameter type; give up on the first failure
+        vector<const Type*> cprot;
+        const ATuple* prot = t->at(1)->to<const ATuple*>();
+        for (size_t i = 0; i < prot->size(); ++i) {
+            const AType* at = prot->at(i)->to<const AType*>();
+            const Type*  lt = llType(at);
+            if (lt)
+                cprot.push_back(lt);
+            else
+                return NULL;
+        }
+
+        FunctionType* fT = FunctionType::get(lretT, cprot, false);
+        return PointerType::get(fT, 0);
+    }
+    return NULL; // non-primitive type
+}
+
+
+/***************************************************************************
+ * LLVM Engine *
+ ***************************************************************************/
+
+/** Engine implementation that emits LLVM IR into a single module and runs
+ * compiled functions through an LLVM ExecutionEngine.
+ */
+struct LLVMEngine : public Engine {
+    /** Create the "tuplr" module and its execution engine, build the
+     * per-function optimisation pipeline, and declare the external
+     * allocation function "tuplr_gc_allocate".
+     *
+     * Members are initialised in declaration order (module, engine, builder,
+     * alloc, emp, opt), not in the order written in the initialiser list.
+     */
+    LLVMEngine()
+        : module(new Module("tuplr"))
+        , engine(ExecutionEngine::create(module))
+        , emp(module)
+        , opt(&emp)
+    {
+        // Set up optimiser pipeline
+        const TargetData* target = engine->getTargetData();
+        opt.add(new TargetData(*target));           // Register target arch
+        opt.add(createInstructionCombiningPass());  // Simple optimizations
+        opt.add(createReassociatePass());           // Reassociate expressions
+        opt.add(createGVNPass());                   // Eliminate Common Subexpressions
+        opt.add(createCFGSimplificationPass());     // Simplify control flow
+
+        // Declare host provided allocation primitive
+        std::vector<const Type*> argsT(1, Type::Int32Ty); // unsigned size
+        argsT.push_back(Type::Int8Ty);                    // char tag
+        FunctionType* funcT = FunctionType::get(PointerType::get(Type::Int8Ty, 0), argsT, false);
+        alloc = Function::Create(funcT, Function::ExternalLinkage,
+                                 "tuplr_gc_allocate", module);
+    }
+
+    /** Declare a new function in the module and start emitting its body.
+     *
+     * Creates the function with external linkage, names its arguments (when
+     * @a argNames is non-empty), appends an "entry" block and points the
+     * builder at it.
+     *
+     * @param cenv     Compilation environment (not used here).
+     * @param name     Requested function name.
+     * @param retT     Return type; must translate to an LLVM type.
+     * @param argsT    Parameter types; each must translate to an LLVM type.
+     * @param argNames Argument names, or empty to leave arguments unnamed.
+     * @return The newly created function.
+     * @throw Error (via THROW_IF) if a parameter or the return type has no
+     * LLVM equivalent.
+     */
+    CFunction startFunction(CEnv& cenv,
+            const std::string& name, const AType* retT, const ATuple& argsT,
+            const vector<string> argNames)
+    {
+        Function::LinkageTypes linkage = Function::ExternalLinkage;
+
+        // Translate each parameter type once, rejecting non-concrete ones
+        vector<const Type*> cprot;
+        FOREACH(ATuple::const_iterator, i, argsT) {
+            AType*      at = (*i)->as<AType*>();
+            const Type* lt = llType(at);
+            THROW_IF(!lt, Cursor(), string("non-concrete parameter :: ")
+                    + at->str());
+            cprot.push_back(lt);
+        }
+
+        const Type* lretT = llType(retT);
+        THROW_IF(!lretT, Cursor(), "return has non-concrete type");
+        FunctionType* fT = FunctionType::get(lretT, cprot, false);
+        Function*     f  = Function::Create(fT, linkage, name, module);
+
+        // Note f->getName() may differ from name.  However LLVM chooses to
+        // mangle it is fine, since callers keep the returned pointer.
+
+        // Set argument names in generated code
+        Function::arg_iterator a = f->arg_begin();
+        if (!argNames.empty())
+            for (size_t i = 0; i != argsT.size(); ++a, ++i)
+                a->setName(argNames.at(i));
+
+        BasicBlock* bb = BasicBlock::Create("entry", f);
+        builder.SetInsertPoint(bb);
+
+        return f;
+    }
+
+    /** Terminate the function being emitted, verify it and optimise it.
+     *
+     * Emits "ret <ret>" when @a retT is concrete, and "ret void" otherwise.
+     * The optimisation passes are skipped when "-g" is present in cenv.args.
+     */
+    void finishFunction(CEnv& cenv, CFunction f, const AType* retT, CValue ret) {
+        if (retT->concrete())
+            builder.CreateRet(llVal(ret));
+        else
+            builder.CreateRetVoid();
+
+        Function* const func = llFunc(f);
+        verifyFunction(*func);
+        if (cenv.args.find("-g") == cenv.args.end())
+            opt.run(*func);
+    }
+
+    /// Remove @a f from the module, if it is non-null.
+    void eraseFunction(CEnv& cenv, CFunction f) {
+        if (f)
+            llFunc(f)->eraseFromParent();
+    }
+
+    /** Emit a call to @a f with the given argument values.
+     *
+     * The arguments are converted one by one rather than reinterpreting the
+     * whole vector as a vector<Value*>, which would be undefined behaviour.
+     *
+     * @return The call instruction (the value of the call).
+     */
+    CValue compileCall(CEnv& cenv, CFunction f, const vector<CValue>& args) {
+        vector<Value*> llArgs;
+        llArgs.reserve(args.size());
+        for (vector<CValue>::const_iterator i = args.begin(); i != args.end(); ++i)
+            llArgs.push_back(llVal(*i));
+        return builder.CreateCall(llFunc(f), llArgs.begin(), llArgs.end());
+    }
+
+    void liftCall(CEnv& cenv, AFn* fn, const AType& argsT);
+
+    CValue compileLiteral(CEnv& cenv, AST* lit);
+    CValue compilePrimitive(CEnv& cenv, APrimitive* prim);
+    CValue compileIf(CEnv& cenv, AIf* aif);
+
+    /// Print the module's IR to @a os.
+    void writeModule(CEnv& cenv, std::ostream& os) {
+        AssemblyAnnotationWriter writer;
+        module->print(os, &writer);
+    }
+
+    /** Run the nullary function @a f and return its result as text.
+     *
+     * i32 results are printed as integers, float results with showpoint,
+     * i1 results as "#t" or "#f", and void results as the empty string.
+     * A result of any other type is printed as a pointer.
+     *
+     * @throw Error (via THROW_IF) if no function pointer can be obtained or
+     * @a retT has no LLVM equivalent.
+     */
+    const string call(CEnv& cenv, CFunction f, AType* retT) {
+        void*       fp = engine->getPointerToFunction(llFunc(f));
+        const Type* t  = llType(retT);
+        THROW_IF(!fp, Cursor(), "unable to get function pointer");
+        THROW_IF(!t, Cursor(), "function with non-concrete return type called");
+
+        std::stringstream ss;
+        if (t == Type::Int32Ty)
+            ss << ((int32_t (*)())fp)();
+        else if (t == Type::FloatTy)
+            ss << showpoint << ((float (*)())fp)();
+        else if (t == Type::Int1Ty)
+            ss << (((bool (*)())fp)() ? "#t" : "#f");
+        else if (t != Type::VoidTy)
+            ss << ((void* (*)())fp)();
+        return ss.str();
+    }
+
+    Module*                module;  ///< Module receiving all generated code
+    ExecutionEngine*       engine;  ///< Execution engine created for module
+    IRBuilder<>            builder; ///< Instruction builder for the current function
+    Function*              alloc;   ///< Declaration of "tuplr_gc_allocate"
+    ExistingModuleProvider emp;     ///< Module provider handed to the pass manager
+    FunctionPassManager    opt;     ///< Per-function optimisation pipeline
+};
+
+/// Allocate a new LLVMEngine and hand it back as a generic Engine pointer.
+Engine*
+tuplr_new_llvm_engine()
+{
+    Engine* const llvmEngine = new LLVMEngine();
+    return llvmEngine;
+}
+
+/***************************************************************************
+ * Code Generation *
+ ***************************************************************************/
+
+/** Compile a literal AST node to an LLVM constant.
+ *
+ * @param cenv Compilation environment (not used here).
+ * @param lit  Literal node; an ALiteral of int32_t, float or bool.
+ * @return A signed i32 constant for integers, a float constant for floats,
+ * or an i1 constant for booleans (matching the Bool => i1 type mapping).
+ * @throw Error if @a lit is none of the supported literal types.
+ */
+CValue
+LLVMEngine::compileLiteral(CEnv& cenv, AST* lit)
+{
+    ALiteral<int32_t>* ilit = dynamic_cast<ALiteral<int32_t>*>(lit);
+    if (ilit)
+        return ConstantInt::get(Type::Int32Ty, ilit->val, true);
+
+    ALiteral<float>* flit = dynamic_cast<ALiteral<float>*>(lit);
+    if (flit)
+        return ConstantFP::get(Type::FloatTy, flit->val);
+
+    // Booleans are 1-bit integers, not floating point values
+    ALiteral<bool>* blit = dynamic_cast<ALiteral<bool>*>(lit);
+    if (blit)
+        return ConstantInt::get(Type::Int1Ty, blit->val);
+
+    throw Error(lit->loc, "Unknown literal type");
+}
+
+/** Generate an implementation of @a fn specialised for a particular call.
+ *
+ * Looks up the generic type recorded for @a fn, substitutes @a argsT into it
+ * to obtain a concrete type, and, unless fn->impls already has an entry for
+ * that type, compiles the function body under a fresh environment scope with
+ * the parameters bound to the LLVM function arguments.
+ *
+ * On a compilation error the partially built function is erased, the scope
+ * and type substitution are restored, and the error is rethrown.
+ *
+ * @param cenv  Compilation environment.
+ * @param fn    Function to lift.
+ * @param argsT Argument types of the call.
+ * @throw Error if the call's type is not concrete, or if compiling the body
+ * fails.
+ */
+void
+LLVMEngine::liftCall(CEnv& cenv, AFn* fn, const AType& argsT)
+{
+    TEnv::GenericTypes::const_iterator gt = cenv.tenv.genericTypes.find(fn);
+    assert(gt != cenv.tenv.genericTypes.end());
+    LLVMEngine* engine      = reinterpret_cast<LLVMEngine*>(cenv.engine());
+    AType*      genericType = new AType(*gt->second);
+    AType*      thisType    = genericType;
+    Subst       argsSubst;
+
+    // Build and apply substitution to get concrete type for this call
+    if (!genericType->concrete()) {
+        argsSubst = cenv.tenv.buildSubst(genericType, argsT);
+        thisType  = argsSubst.apply(genericType)->as<AType*>();
+    }
+
+    // Use boost::format so the %1% placeholder is actually substituted
+    THROW_IF(!thisType->concrete(), fn->loc,
+            (format("call has non-concrete type %1%\n") % thisType->str()).str());
+
+    Object::pool.addRoot(thisType);
+    if (fn->impls.find(thisType))
+        return; // Already lifted for this concrete type
+
+    ATuple* protT = thisType->at(1)->as<ATuple*>();
+
+    vector<string> argNames;
+    for (size_t i = 0; i < fn->prot()->size(); ++i)
+        argNames.push_back(fn->prot()->at(i)->str());
+
+    // Write function declaration
+    const string name = (fn->name == "") ? cenv.penv.gensymstr("_fn") : fn->name;
+    Function* f = llFunc(cenv.engine()->startFunction(cenv, name,
+            thisType->at(thisType->size()-1)->to<AType*>(),
+            *protT, argNames));
+
+    cenv.push();
+    Subst oldSubst = cenv.tsubst;
+    cenv.tsubst = Subst::compose(cenv.tsubst, Subst::compose(argsSubst, fn->subst));
+
+//#define EXPLICIT_STACK_FRAMES 1
+
+#ifdef EXPLICIT_STACK_FRAMES
+    vector<const Type*> types;
+    types.push_back(Type::Int8Ty);
+    types.push_back(Type::Int8Ty);
+    size_t s = 16; // stack frame size in bits
+#endif
+
+    // Bind argument values in CEnv
+    AFn::const_iterator p = fn->prot()->begin();
+    size_t i = 0;
+    for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p, ++i) {
+        AType* t = protT->at(i)->as<AType*>();
+        const Type* lt = llType(t);
+        THROW_IF(!lt, fn->loc, "untyped parameter\n");
+        cenv.def((*p)->as<ASymbol*>(), *p, t, &*a);
+#ifdef EXPLICIT_STACK_FRAMES
+        types.push_back(lt);
+        s += std::max(lt->getPrimitiveSizeInBits(), unsigned(8));
+#endif
+    }
+
+#ifdef EXPLICIT_STACK_FRAMES
+    IRBuilder<> builder = engine->builder;
+
+    // Scan out definitions (of the function being lifted)
+    for (size_t i = 0; i < fn->size(); ++i) {
+        ADef* def = fn->at(i)->to<ADef*>();
+        if (def) {
+            const Type* lt = llType(cenv.type(def->at(2)));
+            THROW_IF(!lt, fn->loc, "untyped definition\n");
+            types.push_back(lt);
+            s += std::max(lt->getPrimitiveSizeInBits(), unsigned(8));
+        }
+    }
+
+    // Create stack frame
+    StructType* frameT    = StructType::get(types, false);
+    Value*      tag       = ConstantInt::get(Type::Int8Ty, GC::TAG_FRAME);
+    Value*      frameSize = ConstantInt::get(Type::Int32Ty, s / 8);
+    Value*      frame     = builder.CreateCall2(engine->alloc, frameSize, tag, "frame");
+    Value*      framePtr  = builder.CreateBitCast(frame, PointerType::get(frameT, 0));
+
+    // Bind parameter values in stack frame
+    i = 2;
+    for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++i) {
+        Value* v = builder.CreateStructGEP(framePtr, i, "arg");
+        builder.CreateStore(&*a, v);
+    }
+#endif
+
+    // Write function body
+    try {
+        // Define value first for recursion
+        cenv.precompile(fn, f);
+        fn->impls.push_back(make_pair(thisType, f));
+        CValue retVal = NULL;
+        for (size_t i = 2; i < fn->size(); ++i)
+            retVal = cenv.compile(fn->at(i));
+        cenv.engine()->finishFunction(cenv, f, cenv.type(fn->at(fn->size() - 1)), retVal);
+    } catch (Error&) {
+        f->eraseFromParent(); // Error reading body, remove function
+        cenv.tsubst = oldSubst; // Undo the substitution installed above
+        cenv.pop();
+        throw; // Rethrow the original exception without copying it
+    }
+    cenv.tsubst = oldSubst;
+    cenv.pop();
+}
+
+/** Compile a multi-way conditional.
+ *
+ * Children 1 .. size()-2 of @a aif are taken pairwise as (condition, value).
+ * The last child is the final else value.  Each pair gets a "thenN" block
+ * and an "elseN" block; every value block branches to a shared "endif"
+ * block, where a PHI node named "ifval" selects the result.
+ *
+ * @return The PHI node holding the value of the whole expression.
+ */
+CValue
+LLVMEngine::compileIf(CEnv& cenv, AIf* aif)
+{
+    typedef pair<Value*, BasicBlock*> Incoming;
+    typedef vector<Incoming>          IncomingList;
+
+    LLVMEngine* const self   = reinterpret_cast<LLVMEngine*>(cenv.engine());
+    IRBuilder<>&      irb    = self->builder;
+    Function* const   func   = irb.GetInsertBlock()->getParent();
+    BasicBlock* const joinBB = BasicBlock::Create("endif");
+    BasicBlock*       elseBB = NULL;
+    IncomingList      incoming;
+
+    for (size_t c = 1; c < aif->size() - 1; c += 2) {
+        const size_t n = (c + 1) / 2; // 1-based number of this branch
+
+        Value* const      cond   = llVal(cenv.compile(aif->at(c)));
+        BasicBlock* const thenBB = BasicBlock::Create((format("then%1%") % n).str());
+
+        elseBB = BasicBlock::Create((format("else%1%") % n).str());
+
+        irb.CreateCondBr(cond, thenBB, elseBB);
+
+        // Emit then block for this condition
+        func->getBasicBlockList().push_back(thenBB);
+        irb.SetInsertPoint(thenBB);
+        Value* const thenVal = llVal(cenv.compile(aif->at(c + 1)));
+        irb.CreateBr(joinBB);
+        // Record the block the builder ended in, which compiling the value
+        // may have changed from thenBB
+        incoming.push_back(make_pair(thenVal, irb.GetInsertBlock()));
+
+        // Following conditions (or the final else) go in the else block
+        func->getBasicBlockList().push_back(elseBB);
+        irb.SetInsertPoint(elseBB);
+    }
+
+    // Emit final else block
+    irb.SetInsertPoint(elseBB);
+    Value* const elseVal = llVal(cenv.compile(aif->at(aif->size() - 1)));
+    irb.CreateBr(joinBB);
+    incoming.push_back(make_pair(elseVal, irb.GetInsertBlock()));
+
+    // Emit merge block (Phi node)
+    func->getBasicBlockList().push_back(joinBB);
+    irb.SetInsertPoint(joinBB);
+    PHINode* const phi = irb.CreatePHI(llType(cenv.type(aif)), "ifval");
+
+    for (size_t k = 0; k < incoming.size(); ++k)
+        phi->addIncoming(incoming[k].first, incoming[k].second);
+
+    return phi;
+}
+
+/** Compile a call to a built-in arithmetic, logical or comparison operator.
+ *
+ * Arithmetic and logical operators ("+", "-", "*", "and", "or", "xor", "/",
+ * "%") are folded left to right over all operands.  Comparison operators
+ * ("=", "!=", ">", ">=", "<", "<=") compare the first two operands only.
+ * Float variants are chosen when the type of the first operand prints as
+ * "Float".
+ *
+ * @return The resulting LLVM value.
+ * @throw Error if the operator name is not recognised.
+ */
+CValue
+LLVMEngine::compilePrimitive(CEnv& cenv, APrimitive* prim)
+{
+    LLVMEngine* const self     = reinterpret_cast<LLVMEngine*>(cenv.engine());
+    Value* const      lhs      = llVal(cenv.compile(prim->at(1)));
+    Value* const      rhs      = llVal(cenv.compile(prim->at(2)));
+    const bool        floating = (cenv.type(prim->at(1))->str() == "Float");
+    const string      name     = prim->at(0)->to<ASymbol*>()->str();
+    IRBuilder<>&      irb      = self->builder;
+
+    // Binary arithmetic operations (0 means "not a binary operator")
+    Instruction::BinaryOps binOp = static_cast<Instruction::BinaryOps>(0);
+    if (name == "+")
+        binOp = Instruction::Add;
+    else if (name == "-")
+        binOp = Instruction::Sub;
+    else if (name == "*")
+        binOp = Instruction::Mul;
+    else if (name == "and")
+        binOp = Instruction::And;
+    else if (name == "or")
+        binOp = Instruction::Or;
+    else if (name == "xor")
+        binOp = Instruction::Xor;
+    else if (name == "/")
+        binOp = floating ? Instruction::FDiv : Instruction::SDiv;
+    else if (name == "%")
+        binOp = floating ? Instruction::FRem : Instruction::SRem;
+
+    if (binOp != 0) {
+        // Fold any further operands into the running result
+        Value* acc = irb.CreateBinOp(binOp, lhs, rhs);
+        size_t idx = 3;
+        while (idx < prim->size()) {
+            Value* const operand = llVal(cenv.compile(prim->at(idx)));
+            acc = irb.CreateBinOp(binOp, acc, operand);
+            ++idx;
+        }
+        return acc;
+    }
+
+    // Comparison operations (0 means "not a comparison operator")
+    CmpInst::Predicate cmp = static_cast<CmpInst::Predicate>(0);
+    if (name == "=")
+        cmp = floating ? CmpInst::FCMP_OEQ : CmpInst::ICMP_EQ;
+    else if (name == "!=")
+        cmp = floating ? CmpInst::FCMP_ONE : CmpInst::ICMP_NE;
+    else if (name == ">")
+        cmp = floating ? CmpInst::FCMP_OGT : CmpInst::ICMP_SGT;
+    else if (name == ">=")
+        cmp = floating ? CmpInst::FCMP_OGE : CmpInst::ICMP_SGE;
+    else if (name == "<")
+        cmp = floating ? CmpInst::FCMP_OLT : CmpInst::ICMP_SLT;
+    else if (name == "<=")
+        cmp = floating ? CmpInst::FCMP_OLE : CmpInst::ICMP_SLE;
+
+    if (cmp == 0)
+        throw Error(prim->loc, "unknown primitive");
+
+    if (floating)
+        return irb.CreateFCmp(cmp, lhs, rhs);
+    return irb.CreateICmp(cmp, lhs, rhs);
+}