aboutsummaryrefslogtreecommitdiffstats
path: root/src/llvm.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/llvm.cpp')
-rw-r--r--src/llvm.cpp544
1 files changed, 544 insertions, 0 deletions
diff --git a/src/llvm.cpp b/src/llvm.cpp
new file mode 100644
index 0000000..b5e397e
--- /dev/null
+++ b/src/llvm.cpp
@@ -0,0 +1,544 @@
+/* Tuplr: A programming language
+ * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net>
+ *
+ * Tuplr is free software: you can redistribute it and/or modify it under
+ * the terms of the GNU Affero General Public License as published by the
+ * Free Software Foundation, either version 3 of the License, or (at your
+ * option) any later version.
+ *
+ * Tuplr is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
+ * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General
+ * Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with Tuplr. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include <map>
+#include <sstream>
+#include <boost/format.hpp>
+#include "llvm/Analysis/Verifier.h"
+#include "llvm/Assembly/AsmAnnotationWriter.h"
+#include "llvm/DerivedTypes.h"
+#include "llvm/ExecutionEngine/ExecutionEngine.h"
+#include "llvm/Instructions.h"
+#include "llvm/Module.h"
+#include "llvm/ModuleProvider.h"
+#include "llvm/PassManager.h"
+#include "llvm/Support/IRBuilder.h"
+#include "llvm/Target/TargetData.h"
+#include "llvm/Transforms/Scalar.h"
+#include "tuplr.hpp"
+
+using namespace llvm;
+using namespace std;
+using boost::format;
+
+static inline Value* llVal(CValue v) { return static_cast<Value*>(v); }
+static inline Function* llFunc(CFunction f) { return static_cast<Function*>(f); }
+
+static const Type*
+llType(const AType* t)
+{
+ if (t->kind == AType::PRIM) {
+ if (t->at(0)->str() == "Nothing") return Type::VoidTy;
+ if (t->at(0)->str() == "Bool") return Type::Int1Ty;
+ if (t->at(0)->str() == "Int") return Type::Int32Ty;
+ if (t->at(0)->str() == "Float") return Type::FloatTy;
+ throw Error(t->loc, string("Unknown primitive type `") + t->str() + "'");
+ /*} else if (t->kind == AType::EXPR && t->at(0)->str() == "Fn") {
+ AType* retT = t->at(2)->as<AType*>();
+ if (!llType(retT))
+ return NULL;
+
+ vector<const Type*> cprot;
+ const ATuple* prot = t->at(1)->to<ATuple*>();
+ for (size_t i = 0; i < prot->size(); ++i) {
+ AType* at = prot->at(i)->to<AType*>();
+ const Type* lt = llType(at);
+ if (lt)
+ cprot.push_back(lt);
+ else
+ return NULL;
+ }
+
+ FunctionType* fT = FunctionType::get(llType(retT), cprot, false);
+ return PointerType::get(fT, 0);
+ */
+ }
+ return NULL; // non-primitive type
+}
+
+
+/***************************************************************************
+ * LLVM Engine *
+ ***************************************************************************/
+
+struct LLVMEngine : public Engine {
+ LLVMEngine()
+ : module(new Module("tuplr"))
+ , engine(ExecutionEngine::create(module))
+ , emp(module)
+ , opt(&emp)
+ {
+ // Set up optimiser pipeline
+ const TargetData* target = engine->getTargetData();
+ opt.add(new TargetData(*target)); // Register target arch
+ opt.add(createInstructionCombiningPass()); // Simple optimizations
+ opt.add(createReassociatePass()); // Reassociate expressions
+ opt.add(createGVNPass()); // Eliminate Common Subexpressions
+ opt.add(createCFGSimplificationPass()); // Simplify control flow
+
+ // Declare host provided allocation primitive
+ std::vector<const Type*> argsT(1, Type::Int32Ty); // unsigned size
+ argsT.push_back(Type::Int8Ty); // char tag
+ FunctionType* funcT = FunctionType::get(PointerType::get(Type::Int8Ty, 0), argsT, false);
+ alloc = Function::Create(funcT, Function::ExternalLinkage,
+ "tuplr_gc_allocate", module);
+ }
+
+ CFunction startFunction(CEnv& cenv,
+ const std::string& name, const AType* retT, const ATuple& argsT,
+ const vector<string> argNames)
+ {
+ Function::LinkageTypes linkage = Function::ExternalLinkage;
+
+ vector<const Type*> cprot;
+ for (size_t i = 0; i < argsT.size(); ++i) {
+ AType* at = argsT.at(i)->as<AType*>();
+ THROW_IF(!llType(at), Cursor(), string("parameter has non-concrete type ")
+ + at->str())
+ cprot.push_back(llType(at));
+ }
+
+ THROW_IF(!llType(retT), Cursor(), "return has non-concrete type");
+ FunctionType* fT = FunctionType::get(llType(retT), cprot, false);
+ Function* f = Function::Create(fT, linkage, name, module);
+
+ // Note f->getName() may be different from name
+ // however LLVM chooses to mangle is fine, we keep a pointer
+
+ // Set argument names in generated code
+ Function::arg_iterator a = f->arg_begin();
+ if (!argNames.empty())
+ for (size_t i = 0; i != argsT.size(); ++a, ++i)
+ a->setName(argNames.at(i));
+
+ BasicBlock* bb = BasicBlock::Create("entry", f);
+ builder.SetInsertPoint(bb);
+
+ return f;
+ }
+
+ void finishFunction(CEnv& cenv, CFunction f, const AType* retT, CValue ret) {
+ if (retT->concrete()) {
+ Value* retVal = llVal(ret);
+ builder.CreateRet(retVal);
+ } else {
+ builder.CreateRetVoid();
+ }
+
+ /*std::cerr << "MODULE {" << endl;
+ module->dump();
+ std::cerr << "}" << endl;*/
+ verifyFunction(*static_cast<Function*>(f));
+ if (cenv.args.find("-g") == cenv.args.end())
+ opt.run(*static_cast<Function*>(f));
+ }
+
+ void eraseFunction(CEnv& cenv, CFunction f) {
+ if (f)
+ llFunc(f)->eraseFromParent();
+ }
+
+ void writeModule(CEnv& cenv, std::ostream& os) {
+ AssemblyAnnotationWriter writer;
+ module->print(os, &writer);
+ }
+
+ const string call(CEnv& cenv, CFunction f, AType* retT) {
+ void* fp = engine->getPointerToFunction(llFunc(f));
+ const Type* t = llType(retT);
+ THROW_IF(!fp, Cursor(), "unable to get function pointer");
+ THROW_IF(!t, Cursor(), "function with non-concrete return type called");
+
+ std::stringstream ss;
+ if (t == Type::Int32Ty)
+ ss << ((int32_t (*)())fp)();
+ else if (t == Type::FloatTy)
+ ss << showpoint << ((float (*)())fp)();
+ else if (t == Type::Int1Ty)
+ ss << (((bool (*)())fp)() ? "#t" : "#f");
+ else
+ ss << ((void* (*)())fp)();
+ return ss.str();
+ }
+
+ Module* module;
+ ExecutionEngine* engine;
+ IRBuilder<> builder;
+ CFunction alloc;
+ ExistingModuleProvider emp;
+ FunctionPassManager opt;
+};
+
+static LLVMEngine*
+llEngine(CEnv& cenv)
+{
+ return reinterpret_cast<LLVMEngine*>(cenv.engine());
+}
+
+/// Shared library entry point
+Engine*
+tuplr_new_engine()
+{
+ return new LLVMEngine();
+}
+
+/// Shared library entry point
+void
+tuplr_free_engine(Engine* engine)
+{
+ delete (LLVMEngine*)engine;
+}
+
+
+/***************************************************************************
+ * Code Generation *
+ ***************************************************************************/
+
+#define LITERAL(CT, NAME, COMPILED) \
+template<> CValue ALiteral<CT>::compile(CEnv& cenv) { return (COMPILED); } \
+template<> void \
+ALiteral<CT>::constrain(TEnv& tenv, Constraints& c) const { \
+ c.constrain(tenv, this, tenv.named(NAME)); \
+}
+
+/// Literal template instantiations
+LITERAL(int32_t, "Int", ConstantInt::get(Type::Int32Ty, val, true))
+LITERAL(float, "Float", ConstantFP::get(Type::FloatTy, val))
+LITERAL(bool, "Bool", ConstantInt::get(Type::Int1Ty, val, false))
+
+CValue
+ASymbol::compile(CEnv& cenv)
+{
+ return cenv.vals.ref(this);
+}
+
+void
+AFn::lift(CEnv& cenv)
+{
+ cenv.push();
+ for (const_iterator p = prot()->begin(); p != prot()->end(); ++p)
+ cenv.def((*p)->as<ASymbol*>(), *p, NULL, NULL);
+
+ // Lift body
+ for (size_t i = 2; i < size(); ++i)
+ at(i)->lift(cenv);
+
+ cenv.pop();
+
+ AType* type = cenv.type(this);
+ if (impls.find(type) || !type->concrete())
+ return;
+
+ AType* protT = type->at(1)->as<AType*>();
+ liftCall(cenv, *protT);
+}
+
+void
+AFn::liftCall(CEnv& cenv, const AType& argsT)
+{
+ TEnv::GenericTypes::const_iterator gt = cenv.tenv.genericTypes.find(this);
+ assert(gt != cenv.tenv.genericTypes.end());
+ AType* genericType = new AType(*gt->second);
+ AType* thisType = genericType;
+ Subst argsSubst;
+
+ if (!genericType->concrete()) {
+ // Build substitution to apply to generic type
+ assert(argsT.size() == prot()->size());
+ ATuple* genericProtT = gt->second->at(1)->as<ATuple*>();
+ for (size_t i = 0; i < argsT.size(); ++i) {
+ const AType* genericArgT = genericProtT->at(i)->to<const AType*>();
+ AType* callArgT = argsT.at(i)->to<AType*>();
+ assert(genericArgT);
+ assert(callArgT);
+ if (callArgT->kind == AType::EXPR) {
+ assert(genericArgT->kind == AType::EXPR);
+ assert(callArgT->size() == genericArgT->size());
+ for (size_t i = 0; i < callArgT->size(); ++i) {
+ AType* gT = genericArgT->at(i)->to<AType*>();
+ AType* aT = callArgT->at(i)->to<AType*>();
+ if (gT && aT)
+ argsSubst.add(gT, aT);
+ }
+ } else {
+ argsSubst.add(genericArgT, callArgT);
+ }
+ }
+
+ // Apply substitution to get concrete type for this call
+ thisType = argsSubst.apply(genericType)->as<AType*>();
+ THROW_IF(!thisType->concrete(), loc,
+ string("unable to resolve concrete type for function :: ")
+ + thisType->str() + "\n" + this->str());
+ }
+
+ Object::pool.addRoot(thisType);
+ if (impls.find(thisType))
+ return;
+
+ ATuple* protT = thisType->at(1)->as<ATuple*>();
+
+ vector<string> argNames;
+ for (size_t i = 0; i < prot()->size(); ++i) {
+ argNames.push_back(prot()->at(i)->str());
+ }
+
+ // Write function declaration
+ const string name = (this->name == "") ? cenv.penv.gensymstr("_fn") : this->name;
+ Function* f = llFunc(cenv.engine()->startFunction(cenv, name,
+ thisType->at(thisType->size()-1)->to<AType*>(),
+ *protT, argNames));
+
+ cenv.push();
+ Subst oldSubst = cenv.tsubst;
+ cenv.tsubst = Subst::compose(cenv.tsubst, Subst::compose(argsSubst, subst));
+
+//#define EXPLICIT_STACK_FRAMES 1
+
+#ifdef EXPLICIT_STACK_FRAMES
+ vector<const Type*> types;
+ types.push_back(Type::Int8Ty);
+ types.push_back(Type::Int8Ty);
+ size_t s = 16; // stack frame size in bits
+#endif
+
+ // Bind argument values in CEnv
+ vector<Value*> args;
+ const_iterator p = prot()->begin();
+ size_t i = 0;
+ for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p, ++i) {
+ AType* t = protT->at(i)->as<AType*>();
+ const Type* lt = llType(t);
+ THROW_IF(!lt, loc, "untyped parameter\n");
+ cenv.def((*p)->as<ASymbol*>(), *p, t, &*a);
+#ifdef EXPLICIT_STACK_FRAMES
+ types.push_back(lt);
+ s += std::max(lt->getPrimitiveSizeInBits(), unsigned(8));
+#endif
+ }
+
+
+#ifdef EXPLICIT_STACK_FRAMES
+ IRBuilder<> builder = llEngine(cenv)->builder;
+
+ // Scan out definitions
+ for (size_t i = 0; i < size(); ++i) {
+ ADef* def = at(i)->to<ADef*>();
+ if (def) {
+ const Type* lt = llType(cenv.type(def->at(2)));
+ THROW_IF(!lt, loc, "untyped definition\n");
+ types.push_back(lt);
+ s += std::max(lt->getPrimitiveSizeInBits(), unsigned(8));
+ }
+ }
+
+ // Create stack frame
+ StructType* frameT = StructType::get(types, false);
+ Value* tag = ConstantInt::get(Type::Int8Ty, GC::TAG_FRAME);
+ Value* frameSize = ConstantInt::get(Type::Int32Ty, s / 8);
+ Value* frame = builder.CreateCall2(llVal(cenv.alloc), frameSize, tag, "frame");
+ Value* framePtr = builder.CreateBitCast(frame, PointerType::get(frameT, 0));
+
+ // Bind parameter values in stack frame
+ i = 2;
+ for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++i) {
+ Value* v = builder.CreateStructGEP(framePtr, i, "arg");
+ builder.CreateStore(&*a, v);
+ }
+#endif
+
+ // Write function body
+ try {
+ // Define value first for recursion
+ cenv.precompile(this, f);
+ impls.push_back(make_pair(thisType, f));
+ CValue retVal = NULL;
+ for (size_t i = 2; i < size(); ++i)
+ retVal = cenv.compile(at(i));
+ cenv.engine()->finishFunction(cenv, f, cenv.type(at(size()-1)), retVal);
+ } catch (Error& e) {
+ f->eraseFromParent(); // Error reading body, remove function
+ cenv.pop();
+ throw e;
+ }
+ cenv.tsubst = oldSubst;
+ cenv.pop();
+}
+
+CValue
+AFn::compile(CEnv& cenv)
+{
+ return NULL;
+}
+
+void
+ACall::lift(CEnv& cenv)
+{
+ AFn* c = cenv.tenv.resolve(at(0))->to<AFn*>();
+ AType argsT(loc, NULL);
+
+ // Lift arguments
+ for (size_t i = 1; i < size(); ++i) {
+ at(i)->lift(cenv);
+ argsT.push_back(cenv.type(at(i)));
+ }
+
+ if (!c) return; // Primitive
+
+ if (c->prot()->size() < size() - 1)
+ throw Error(loc, (format("too many arguments to function `%1%'") % at(0)->str()).str());
+ if (c->prot()->size() > size() - 1)
+ throw Error(loc, (format("too few arguments to function `%1%'") % at(0)->str()).str());
+
+ c->liftCall(cenv, argsT); // Lift called closure
+}
+
+CValue
+ACall::compile(CEnv& cenv)
+{
+ AFn* c = cenv.tenv.resolve(at(0))->to<AFn*>();
+
+ if (!c) return NULL; // Primitive
+
+ AType protT(loc, NULL);
+ vector<const Type*> types;
+ for (size_t i = 1; i < size(); ++i) {
+ protT.push_back(cenv.type(at(i)));
+ types.push_back(llType(cenv.type(at(i))));
+ }
+
+ TEnv::GenericTypes::const_iterator gt = cenv.tenv.genericTypes.find(c);
+ assert(gt != cenv.tenv.genericTypes.end());
+ AType fnT(loc, cenv.penv.sym("Fn"), &protT, cenv.type(this), 0);
+ Function* f = (Function*)c->impls.find(&fnT);
+ THROW_IF(!f, loc, (format("callee failed to compile for type %1%") % fnT.str()).str());
+
+ vector<Value*> params(size() - 1);
+ for (size_t i = 0; i < types.size(); ++i)
+ params[i] = llVal(cenv.compile(at(i+1)));
+
+ return llEngine(cenv)->builder.CreateCall(f, params.begin(), params.end());
+}
+
+void
+ADef::lift(CEnv& cenv)
+{
+ // Define stub first for recursion
+ cenv.def(sym(), at(2), cenv.type(at(2)), NULL);
+ AFn* c = at(2)->to<AFn*>();
+ if (c)
+ c->name = sym()->str();
+ at(2)->lift(cenv);
+}
+
+CValue
+ADef::compile(CEnv& cenv)
+{
+ // Define stub first for recursion
+ cenv.def(sym(), at(2), cenv.type(at(2)), NULL);
+ CValue val = cenv.compile(at(size() - 1));
+ cenv.vals.def(sym(), val);
+ return val;
+}
+
+CValue
+AIf::compile(CEnv& cenv)
+{
+ typedef vector< pair<Value*, BasicBlock*> > Branches;
+ Function* parent = llEngine(cenv)->builder.GetInsertBlock()->getParent();
+ BasicBlock* mergeBB = BasicBlock::Create("endif");
+ BasicBlock* nextBB = NULL;
+ Branches branches;
+ for (size_t i = 1; i < size() - 1; i += 2) {
+ Value* condV = llVal(cenv.compile(at(i)));
+ BasicBlock* thenBB = BasicBlock::Create((format("then%1%") % ((i+1)/2)).str());
+
+ nextBB = BasicBlock::Create((format("else%1%") % ((i+1)/2)).str());
+
+ llEngine(cenv)->builder.CreateCondBr(condV, thenBB, nextBB);
+
+ // Emit then block for this condition
+ parent->getBasicBlockList().push_back(thenBB);
+ llEngine(cenv)->builder.SetInsertPoint(thenBB);
+ Value* thenV = llVal(cenv.compile(at(i+1)));
+ llEngine(cenv)->builder.CreateBr(mergeBB);
+ branches.push_back(make_pair(thenV, llEngine(cenv)->builder.GetInsertBlock()));
+
+ parent->getBasicBlockList().push_back(nextBB);
+ llEngine(cenv)->builder.SetInsertPoint(nextBB);
+ }
+
+ // Emit final else block
+ llEngine(cenv)->builder.SetInsertPoint(nextBB);
+ Value* elseV = llVal(cenv.compile(at(size() - 1)));
+ llEngine(cenv)->builder.CreateBr(mergeBB);
+ branches.push_back(make_pair(elseV, llEngine(cenv)->builder.GetInsertBlock()));
+
+ // Emit merge block (Phi node)
+ parent->getBasicBlockList().push_back(mergeBB);
+ llEngine(cenv)->builder.SetInsertPoint(mergeBB);
+ PHINode* pn = llEngine(cenv)->builder.CreatePHI(llType(cenv.type(this)), "ifval");
+
+ FOREACH(Branches::iterator, i, branches)
+ pn->addIncoming(i->first, i->second);
+
+ return pn;
+}
+
+CValue
+APrimitive::compile(CEnv& cenv)
+{
+ Value* a = llVal(cenv.compile(at(1)));
+ Value* b = llVal(cenv.compile(at(2)));
+ bool isFloat = cenv.type(at(1))->str() == "Float";
+ const string n = at(0)->to<ASymbol*>()->str();
+
+ // Binary arithmetic operations
+ Instruction::BinaryOps op = (Instruction::BinaryOps)0;
+ if (n == "+") op = Instruction::Add;
+ if (n == "-") op = Instruction::Sub;
+ if (n == "*") op = Instruction::Mul;
+ if (n == "and") op = Instruction::And;
+ if (n == "or") op = Instruction::Or;
+ if (n == "xor") op = Instruction::Xor;
+ if (n == "/") op = isFloat ? Instruction::FDiv : Instruction::SDiv;
+ if (n == "%") op = isFloat ? Instruction::FRem : Instruction::SRem;
+ if (op != 0) {
+ Value* val = llEngine(cenv)->builder.CreateBinOp(op, a, b);
+ for (size_t i = 3; i < size(); ++i)
+ val = llEngine(cenv)->builder.CreateBinOp(op, val, llVal(cenv.compile(at(i))));
+ return val;
+ }
+
+ // Comparison operations
+ CmpInst::Predicate pred = (CmpInst::Predicate)0;
+ if (n == "=") pred = isFloat ? CmpInst::FCMP_OEQ : CmpInst::ICMP_EQ ;
+ if (n == "!=") pred = isFloat ? CmpInst::FCMP_ONE : CmpInst::ICMP_NE ;
+ if (n == ">") pred = isFloat ? CmpInst::FCMP_OGT : CmpInst::ICMP_SGT;
+ if (n == ">=") pred = isFloat ? CmpInst::FCMP_OGE : CmpInst::ICMP_SGE;
+ if (n == "<") pred = isFloat ? CmpInst::FCMP_OLT : CmpInst::ICMP_SLT;
+ if (n == "<=") pred = isFloat ? CmpInst::FCMP_OLE : CmpInst::ICMP_SLE;
+ if (pred != 0) {
+ if (isFloat)
+ return llEngine(cenv)->builder.CreateFCmp(pred, a, b);
+ else
+ return llEngine(cenv)->builder.CreateICmp(pred, a, b);
+ }
+
+ throw Error(loc, "unknown primitive");
+}
+