/* Tuplr: A programming language * Copyright (C) 2008-2009 David Robillard * * Tuplr is free software: you can redistribute it and/or modify it under * the terms of the GNU Affero General Public License as published by the * Free Software Foundation, either version 3 of the License, or (at your * option) any later version. * * Tuplr is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General * Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with Tuplr. If not, see . */ /** @file * @brief Compile AST to LLVM IR * * Compilation pass functions (lift/compile) that require direct use of LLVM * specific things are implemented here. Generic compilation pass functions * are implemented in compile.cpp. */ #include #include #include #include "llvm/Analysis/Verifier.h" #include "llvm/Assembly/AsmAnnotationWriter.h" #include "llvm/DerivedTypes.h" #include "llvm/ExecutionEngine/ExecutionEngine.h" #include "llvm/Instructions.h" #include "llvm/Module.h" #include "llvm/ModuleProvider.h" #include "llvm/PassManager.h" #include "llvm/Support/IRBuilder.h" #include "llvm/Target/TargetData.h" #include "llvm/Transforms/Scalar.h" #include "tuplr.hpp" using namespace llvm; using namespace std; using boost::format; static inline Value* llVal(CValue v) { return static_cast(v); } static inline Function* llFunc(CFunction f) { return static_cast(f); } static const Type* llType(const AType* t) { if (t->kind == AType::PRIM) { if (t->at(0)->str() == "Nothing") return Type::VoidTy; if (t->at(0)->str() == "Bool") return Type::Int1Ty; if (t->at(0)->str() == "Int") return Type::Int32Ty; if (t->at(0)->str() == "Float") return Type::FloatTy; throw Error(t->loc, string("Unknown primitive type `") + t->str() + "'"); } else if (t->kind == AType::EXPR && t->at(0)->str() == "Fn") { const AType* retT = t->at(2)->as(); if (!llType(retT)) return NULL; vector cprot; const ATuple* prot = t->at(1)->to(); for (size_t i = 0; i < prot->size(); ++i) { const AType* at = prot->at(i)->to(); const Type* lt = llType(at); if (lt) cprot.push_back(lt); else return NULL; } FunctionType* fT = FunctionType::get(llType(retT), cprot, false); return PointerType::get(fT, 0); } return NULL; // non-primitive type } /*************************************************************************** * LLVM Engine * ***************************************************************************/ struct LLVMEngine : public Engine { LLVMEngine() : module(new Module("tuplr")) , engine(ExecutionEngine::create(module)) , emp(module) , opt(&emp) { // Set up optimiser pipeline const TargetData* target = engine->getTargetData(); opt.add(new TargetData(*target)); // Register target arch opt.add(createInstructionCombiningPass()); // Simple optimizations opt.add(createReassociatePass()); // Reassociate expressions opt.add(createGVNPass()); // Eliminate Common Subexpressions opt.add(createCFGSimplificationPass()); // Simplify control flow // Declare host provided allocation primitive std::vector argsT(1, Type::Int32Ty); // unsigned size argsT.push_back(Type::Int8Ty); // char tag FunctionType* funcT = FunctionType::get(PointerType::get(Type::Int8Ty, 0), argsT, false); alloc = Function::Create(funcT, Function::ExternalLinkage, "tuplr_gc_allocate", module); } CFunction startFunction(CEnv& cenv, const std::string& name, const AType* retT, const ATuple& argsT, const vector argNames) { Function::LinkageTypes linkage = Function::ExternalLinkage; vector cprot; FOREACH(ATuple::const_iterator, i, argsT) { AType* at = (*i)->as(); THROW_IF(!llType(at), Cursor(), string("non-concrete parameter :: ") + at->str()) cprot.push_back(llType(at)); } THROW_IF(!llType(retT), Cursor(), "return has non-concrete type"); FunctionType* fT = FunctionType::get(llType(retT), cprot, false); Function* f = Function::Create(fT, linkage, name, module); // Note f->getName() may be different from name // however LLVM chooses to mangle is fine, we keep a pointer // Set argument names in generated code Function::arg_iterator a = f->arg_begin(); if (!argNames.empty()) for (size_t i = 0; i != argsT.size(); ++a, ++i) a->setName(argNames.at(i)); BasicBlock* bb = BasicBlock::Create("entry", f); builder.SetInsertPoint(bb); return f; } void finishFunction(CEnv& cenv, CFunction f, const AType* retT, CValue ret) { if (retT->concrete()) builder.CreateRet(llVal(ret)); else builder.CreateRetVoid(); verifyFunction(*static_cast(f)); if (cenv.args.find("-g") == cenv.args.end()) opt.run(*static_cast(f)); } void eraseFunction(CEnv& cenv, CFunction f) { if (f) llFunc(f)->eraseFromParent(); } CValue compileCall(CEnv& cenv, CFunction f, const vector& args) { const vector& llArgs = *reinterpret_cast*>(&args); return builder.CreateCall(llFunc(f), llArgs.begin(), llArgs.end()); } void liftCall(CEnv& cenv, AFn* fn, const AType& argsT); CValue compileLiteral(CEnv& cenv, AST* lit); CValue compilePrimitive(CEnv& cenv, APrimitive* prim); CValue compileIf(CEnv& cenv, AIf* aif); void writeModule(CEnv& cenv, std::ostream& os) { AssemblyAnnotationWriter writer; module->print(os, &writer); } const string call(CEnv& cenv, CFunction f, AType* retT) { void* fp = engine->getPointerToFunction(llFunc(f)); const Type* t = llType(retT); THROW_IF(!fp, Cursor(), "unable to get function pointer"); THROW_IF(!t, Cursor(), "function with non-concrete return type called"); std::stringstream ss; if (t == Type::Int32Ty) ss << ((int32_t (*)())fp)(); else if (t == Type::FloatTy) ss << showpoint << ((float (*)())fp)(); else if (t == Type::Int1Ty) ss << (((bool (*)())fp)() ? "#t" : "#f"); else if (t != Type::VoidTy) ss << ((void* (*)())fp)(); return ss.str(); } Module* module; ExecutionEngine* engine; IRBuilder<> builder; Function* alloc; ExistingModuleProvider emp; FunctionPassManager opt; }; Engine* tuplr_new_llvm_engine() { return new LLVMEngine(); } /*************************************************************************** * Code Generation * ***************************************************************************/ CValue LLVMEngine::compileLiteral(CEnv& cenv, AST* lit) { ALiteral* ilit = dynamic_cast*>(lit); if (ilit) return ConstantInt::get(Type::Int32Ty, ilit->val, true); ALiteral* flit = dynamic_cast*>(lit); if (flit) return ConstantFP::get(Type::FloatTy, flit->val); ALiteral* blit = dynamic_cast*>(lit); if (blit) return ConstantFP::get(Type::FloatTy, blit->val); throw Error(lit->loc, "Unknown literal type"); } void LLVMEngine::liftCall(CEnv& cenv, AFn* fn, const AType& argsT) { TEnv::GenericTypes::const_iterator gt = cenv.tenv.genericTypes.find(fn); assert(gt != cenv.tenv.genericTypes.end()); LLVMEngine* engine = reinterpret_cast(cenv.engine()); AType* genericType = new AType(*gt->second); AType* thisType = genericType; Subst argsSubst; // Build and apply substitution to get concrete type for this call if (!genericType->concrete()) { argsSubst = cenv.tenv.buildSubst(genericType, argsT); thisType = argsSubst.apply(genericType)->as(); } THROW_IF(!thisType->concrete(), fn->loc, string("call has non-concrete type %1%\n") + thisType->str()); Object::pool.addRoot(thisType); if (fn->impls.find(thisType)) return; ATuple* protT = thisType->at(1)->as(); vector argNames; for (size_t i = 0; i < fn->prot()->size(); ++i) argNames.push_back(fn->prot()->at(i)->str()); // Write function declaration const string name = (fn->name == "") ? cenv.penv.gensymstr("_fn") : fn->name; Function* f = llFunc(cenv.engine()->startFunction(cenv, name, thisType->at(thisType->size()-1)->to(), *protT, argNames)); cenv.push(); Subst oldSubst = cenv.tsubst; cenv.tsubst = Subst::compose(cenv.tsubst, Subst::compose(argsSubst, fn->subst)); //#define EXPLICIT_STACK_FRAMES 1 #ifdef EXPLICIT_STACK_FRAMES vector types; types.push_back(Type::Int8Ty); types.push_back(Type::Int8Ty); size_t s = 16; // stack frame size in bits #endif // Bind argument values in CEnv vector args; AFn::const_iterator p = fn->prot()->begin(); size_t i = 0; for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p, ++i) { AType* t = protT->at(i)->as(); const Type* lt = llType(t); THROW_IF(!lt, fn->loc, "untyped parameter\n"); cenv.def((*p)->as(), *p, t, &*a); #ifdef EXPLICIT_STACK_FRAMES types.push_back(lt); s += std::max(lt->getPrimitiveSizeInBits(), unsigned(8)); #endif } #ifdef EXPLICIT_STACK_FRAMES IRBuilder<> builder = engine->builder; // Scan out definitions for (size_t i = 0; i < size(); ++i) { ADef* def = at(i)->to(); if (def) { const Type* lt = llType(cenv.type(def->at(2))); THROW_IF(!lt, loc, "untyped definition\n"); types.push_back(lt); s += std::max(lt->getPrimitiveSizeInBits(), unsigned(8)); } } // Create stack frame StructType* frameT = StructType::get(types, false); Value* tag = ConstantInt::get(Type::Int8Ty, GC::TAG_FRAME); Value* frameSize = ConstantInt::get(Type::Int32Ty, s / 8); Value* frame = builder.CreateCall2(engine->alloc, frameSize, tag, "frame"); Value* framePtr = builder.CreateBitCast(frame, PointerType::get(frameT, 0)); // Bind parameter values in stack frame i = 2; for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++i) { Value* v = builder.CreateStructGEP(framePtr, i, "arg"); builder.CreateStore(&*a, v); } #endif // Write function body try { // Define value first for recursion cenv.precompile(fn, f); fn->impls.push_back(make_pair(thisType, f)); CValue retVal = NULL; for (size_t i = 2; i < fn->size(); ++i) retVal = cenv.compile(fn->at(i)); cenv.engine()->finishFunction(cenv, f, cenv.type(fn->at(fn->size() - 1)), retVal); } catch (Error& e) { f->eraseFromParent(); // Error reading body, remove function cenv.pop(); throw e; } cenv.tsubst = oldSubst; cenv.pop(); } CValue LLVMEngine::compileIf(CEnv& cenv, AIf* aif) { typedef vector< pair > Branches; LLVMEngine* engine = reinterpret_cast(cenv.engine()); Function* parent = engine->builder.GetInsertBlock()->getParent(); BasicBlock* mergeBB = BasicBlock::Create("endif"); BasicBlock* nextBB = NULL; Branches branches; for (size_t i = 1; i < aif->size() - 1; i += 2) { Value* condV = llVal(cenv.compile(aif->at(i))); BasicBlock* thenBB = BasicBlock::Create((format("then%1%") % ((i+1)/2)).str()); nextBB = BasicBlock::Create((format("else%1%") % ((i+1)/2)).str()); engine->builder.CreateCondBr(condV, thenBB, nextBB); // Emit then block for this condition parent->getBasicBlockList().push_back(thenBB); engine->builder.SetInsertPoint(thenBB); Value* thenV = llVal(cenv.compile(aif->at(i + 1))); engine->builder.CreateBr(mergeBB); branches.push_back(make_pair(thenV, engine->builder.GetInsertBlock())); parent->getBasicBlockList().push_back(nextBB); engine->builder.SetInsertPoint(nextBB); } // Emit final else block engine->builder.SetInsertPoint(nextBB); Value* elseV = llVal(cenv.compile(aif->at(aif->size() - 1))); engine->builder.CreateBr(mergeBB); branches.push_back(make_pair(elseV, engine->builder.GetInsertBlock())); // Emit merge block (Phi node) parent->getBasicBlockList().push_back(mergeBB); engine->builder.SetInsertPoint(mergeBB); PHINode* pn = engine->builder.CreatePHI(llType(cenv.type(aif)), "ifval"); FOREACH(Branches::iterator, i, branches) pn->addIncoming(i->first, i->second); return pn; } CValue LLVMEngine::compilePrimitive(CEnv& cenv, APrimitive* prim) { LLVMEngine* engine = reinterpret_cast(cenv.engine()); Value* a = llVal(cenv.compile(prim->at(1))); Value* b = llVal(cenv.compile(prim->at(2))); bool isFloat = cenv.type(prim->at(1))->str() == "Float"; const string n = prim->at(0)->to()->str(); // Binary arithmetic operations Instruction::BinaryOps op = (Instruction::BinaryOps)0; if (n == "+") op = Instruction::Add; if (n == "-") op = Instruction::Sub; if (n == "*") op = Instruction::Mul; if (n == "and") op = Instruction::And; if (n == "or") op = Instruction::Or; if (n == "xor") op = Instruction::Xor; if (n == "/") op = isFloat ? Instruction::FDiv : Instruction::SDiv; if (n == "%") op = isFloat ? Instruction::FRem : Instruction::SRem; if (op != 0) { Value* val = engine->builder.CreateBinOp(op, a, b); for (size_t i = 3; i < prim->size(); ++i) val = engine->builder.CreateBinOp(op, val, llVal(cenv.compile(prim->at(i)))); return val; } // Comparison operations CmpInst::Predicate pred = (CmpInst::Predicate)0; if (n == "=") pred = isFloat ? CmpInst::FCMP_OEQ : CmpInst::ICMP_EQ ; if (n == "!=") pred = isFloat ? CmpInst::FCMP_ONE : CmpInst::ICMP_NE ; if (n == ">") pred = isFloat ? CmpInst::FCMP_OGT : CmpInst::ICMP_SGT; if (n == ">=") pred = isFloat ? CmpInst::FCMP_OGE : CmpInst::ICMP_SGE; if (n == "<") pred = isFloat ? CmpInst::FCMP_OLT : CmpInst::ICMP_SLT; if (n == "<=") pred = isFloat ? CmpInst::FCMP_OLE : CmpInst::ICMP_SLE; if (pred != 0) { if (isFloat) return engine->builder.CreateFCmp(pred, a, b); else return engine->builder.CreateICmp(pred, a, b); } throw Error(prim->loc, "unknown primitive"); }