From e59c3164288378f89131699eee19884e5f87711e Mon Sep 17 00:00:00 2001 From: David Robillard Date: Tue, 6 Oct 2009 20:32:29 +0000 Subject: Move llvm.cpp back to where it came from... git-svn-id: http://svn.drobilla.net/resp/tuplr@198 ad02d1e2-f140-0410-9f75-f8b11f17cedd --- Makefile | 6 +- src/llvm.cpp | 424 +++++++++++++++++++++++++++++++++++++++++++++++++++++ src/tuplr_llvm.cpp | 424 ----------------------------------------------------- 3 files changed, 427 insertions(+), 427 deletions(-) create mode 100644 src/llvm.cpp delete mode 100644 src/tuplr_llvm.cpp diff --git a/Makefile b/Makefile index f25b489..42f034d 100644 --- a/Makefile +++ b/Makefile @@ -22,12 +22,12 @@ OBJECTS = \ build/cps.o \ build/gc.o \ build/lex.o \ + build/llvm.o \ build/parse.o \ build/pprint.o \ build/repl.o \ build/tlsf.o \ build/tuplr.o \ - build/tuplr_llvm.o \ build/unify.o LIBS = \ @@ -42,8 +42,8 @@ build/%.o: src/%.cpp src/tuplr.hpp build/tlsf.o: src/tlsf.c src/tlsf.h gcc $(CFLAGS) -o $@ -c $< -build/tuplr_llvm.o: src/tuplr_llvm.cpp src/tuplr.hpp - g++ -c $(CXXFLAGS) $(LLVM_CXXFLAGS) -o $@ src/tuplr_llvm.cpp +build/llvm.o: src/llvm.cpp src/tuplr.hpp + g++ -c $(CXXFLAGS) $(LLVM_CXXFLAGS) -o $@ src/llvm.cpp build/%.so: src/%.cpp src/tuplr.hpp g++ -shared $(CXXFLAGS) -o $@ $^ diff --git a/src/llvm.cpp b/src/llvm.cpp new file mode 100644 index 0000000..e2f7f1a --- /dev/null +++ b/src/llvm.cpp @@ -0,0 +1,424 @@ +/* Tuplr: A programming language + * Copyright (C) 2008-2009 David Robillard + * + * Tuplr is free software: you can redistribute it and/or modify it under + * the terms of the GNU Affero General Public License as published by the + * Free Software Foundation, either version 3 of the License, or (at your + * option) any later version. + * + * Tuplr is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY + * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General + * Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with Tuplr. If not, see . + */ + +/** @file + * @brief Compile AST to LLVM IR + * + * Compilation pass functions (lift/compile) that require direct use of LLVM + * specific things are implemented here. Generic compilation pass functions + * are implemented in compile.cpp. + */ + +#include +#include +#include +#include "llvm/Analysis/Verifier.h" +#include "llvm/Assembly/AsmAnnotationWriter.h" +#include "llvm/DerivedTypes.h" +#include "llvm/ExecutionEngine/ExecutionEngine.h" +#include "llvm/Instructions.h" +#include "llvm/Module.h" +#include "llvm/ModuleProvider.h" +#include "llvm/PassManager.h" +#include "llvm/Support/IRBuilder.h" +#include "llvm/Target/TargetData.h" +#include "llvm/Transforms/Scalar.h" +#include "tuplr.hpp" + +using namespace llvm; +using namespace std; +using boost::format; + +static inline Value* llVal(CValue v) { return static_cast(v); } +static inline Function* llFunc(CFunction f) { return static_cast(f); } + +static const Type* +llType(const AType* t) +{ + if (t->kind == AType::PRIM) { + if (t->at(0)->str() == "Nothing") return Type::VoidTy; + if (t->at(0)->str() == "Bool") return Type::Int1Ty; + if (t->at(0)->str() == "Int") return Type::Int32Ty; + if (t->at(0)->str() == "Float") return Type::FloatTy; + throw Error(t->loc, string("Unknown primitive type `") + t->str() + "'"); + } else if (t->kind == AType::EXPR && t->at(0)->str() == "Fn") { + const AType* retT = t->at(2)->as(); + if (!llType(retT)) + return NULL; + + vector cprot; + const ATuple* prot = t->at(1)->to(); + for (size_t i = 0; i < prot->size(); ++i) { + const AType* at = prot->at(i)->to(); + const Type* lt = llType(at); + if (lt) + cprot.push_back(lt); + else + return NULL; + } + + FunctionType* fT = FunctionType::get(llType(retT), cprot, false); + return PointerType::get(fT, 0); + } + return NULL; // non-primitive type +} + + +/*************************************************************************** + * LLVM Engine * + ***************************************************************************/ + +struct LLVMEngine : public Engine { + LLVMEngine() + : module(new Module("tuplr")) + , engine(ExecutionEngine::create(module)) + , emp(module) + , opt(&emp) + { + // Set up optimiser pipeline + const TargetData* target = engine->getTargetData(); + opt.add(new TargetData(*target)); // Register target arch + opt.add(createInstructionCombiningPass()); // Simple optimizations + opt.add(createReassociatePass()); // Reassociate expressions + opt.add(createGVNPass()); // Eliminate Common Subexpressions + opt.add(createCFGSimplificationPass()); // Simplify control flow + + // Declare host provided allocation primitive + std::vector argsT(1, Type::Int32Ty); // unsigned size + argsT.push_back(Type::Int8Ty); // char tag + FunctionType* funcT = FunctionType::get(PointerType::get(Type::Int8Ty, 0), argsT, false); + alloc = Function::Create(funcT, Function::ExternalLinkage, + "tuplr_gc_allocate", module); + } + + CFunction startFunction(CEnv& cenv, + const std::string& name, const AType* retT, const ATuple& argsT, + const vector argNames) + { + Function::LinkageTypes linkage = Function::ExternalLinkage; + + vector cprot; + FOREACH(ATuple::const_iterator, i, argsT) { + AType* at = (*i)->as(); + THROW_IF(!llType(at), Cursor(), string("non-concrete parameter :: ") + + at->str()) + cprot.push_back(llType(at)); + } + + THROW_IF(!llType(retT), Cursor(), "return has non-concrete type"); + FunctionType* fT = FunctionType::get(llType(retT), cprot, false); + Function* f = Function::Create(fT, linkage, name, module); + + // Note f->getName() may be different from name + // however LLVM chooses to mangle is fine, we keep a pointer + + // Set argument names in generated code + Function::arg_iterator a = f->arg_begin(); + if (!argNames.empty()) + for (size_t i = 0; i != argsT.size(); ++a, ++i) + a->setName(argNames.at(i)); + + BasicBlock* bb = BasicBlock::Create("entry", f); + builder.SetInsertPoint(bb); + + return f; + } + + void finishFunction(CEnv& cenv, CFunction f, const AType* retT, CValue ret) { + if (retT->concrete()) + builder.CreateRet(llVal(ret)); + else + builder.CreateRetVoid(); + + verifyFunction(*static_cast(f)); + if (cenv.args.find("-g") == cenv.args.end()) + opt.run(*static_cast(f)); + } + + void eraseFunction(CEnv& cenv, CFunction f) { + if (f) + llFunc(f)->eraseFromParent(); + } + + CValue compileCall(CEnv& cenv, CFunction f, const vector& args) { + const vector& llArgs = *reinterpret_cast*>(&args); + return builder.CreateCall(llFunc(f), llArgs.begin(), llArgs.end()); + } + + void liftCall(CEnv& cenv, AFn* fn, const AType& argsT); + + CValue compileLiteral(CEnv& cenv, AST* lit); + CValue compilePrimitive(CEnv& cenv, APrimitive* prim); + CValue compileIf(CEnv& cenv, AIf* aif); + + void writeModule(CEnv& cenv, std::ostream& os) { + AssemblyAnnotationWriter writer; + module->print(os, &writer); + } + + const string call(CEnv& cenv, CFunction f, AType* retT) { + void* fp = engine->getPointerToFunction(llFunc(f)); + const Type* t = llType(retT); + THROW_IF(!fp, Cursor(), "unable to get function pointer"); + THROW_IF(!t, Cursor(), "function with non-concrete return type called"); + + std::stringstream ss; + if (t == Type::Int32Ty) + ss << ((int32_t (*)())fp)(); + else if (t == Type::FloatTy) + ss << showpoint << ((float (*)())fp)(); + else if (t == Type::Int1Ty) + ss << (((bool (*)())fp)() ? "#t" : "#f"); + else if (t != Type::VoidTy) + ss << ((void* (*)())fp)(); + return ss.str(); + } + + Module* module; + ExecutionEngine* engine; + IRBuilder<> builder; + Function* alloc; + ExistingModuleProvider emp; + FunctionPassManager opt; +}; + +Engine* +tuplr_new_llvm_engine() +{ + return new LLVMEngine(); +} + +/*************************************************************************** + * Code Generation * + ***************************************************************************/ + +CValue +LLVMEngine::compileLiteral(CEnv& cenv, AST* lit) +{ + ALiteral* ilit = dynamic_cast*>(lit); + if (ilit) + return ConstantInt::get(Type::Int32Ty, ilit->val, true); + + ALiteral* flit = dynamic_cast*>(lit); + if (flit) + return ConstantFP::get(Type::FloatTy, flit->val); + + ALiteral* blit = dynamic_cast*>(lit); + if (blit) + return ConstantFP::get(Type::FloatTy, blit->val); + + throw Error(lit->loc, "Unknown literal type"); +} + +void +LLVMEngine::liftCall(CEnv& cenv, AFn* fn, const AType& argsT) +{ + TEnv::GenericTypes::const_iterator gt = cenv.tenv.genericTypes.find(fn); + assert(gt != cenv.tenv.genericTypes.end()); + LLVMEngine* engine = reinterpret_cast(cenv.engine()); + AType* genericType = new AType(*gt->second); + AType* thisType = genericType; + Subst argsSubst; + + // Build and apply substitution to get concrete type for this call + if (!genericType->concrete()) { + argsSubst = cenv.tenv.buildSubst(genericType, argsT); + thisType = argsSubst.apply(genericType)->as(); + } + + THROW_IF(!thisType->concrete(), fn->loc, + string("call has non-concrete type %1%\n") + thisType->str()); + + Object::pool.addRoot(thisType); + if (fn->impls.find(thisType)) + return; + + ATuple* protT = thisType->at(1)->as(); + + vector argNames; + for (size_t i = 0; i < fn->prot()->size(); ++i) + argNames.push_back(fn->prot()->at(i)->str()); + + // Write function declaration + const string name = (fn->name == "") ? cenv.penv.gensymstr("_fn") : fn->name; + Function* f = llFunc(cenv.engine()->startFunction(cenv, name, + thisType->at(thisType->size()-1)->to(), + *protT, argNames)); + + cenv.push(); + Subst oldSubst = cenv.tsubst; + cenv.tsubst = Subst::compose(cenv.tsubst, Subst::compose(argsSubst, fn->subst)); + +//#define EXPLICIT_STACK_FRAMES 1 + +#ifdef EXPLICIT_STACK_FRAMES + vector types; + types.push_back(Type::Int8Ty); + types.push_back(Type::Int8Ty); + size_t s = 16; // stack frame size in bits +#endif + + // Bind argument values in CEnv + vector args; + AFn::const_iterator p = fn->prot()->begin(); + size_t i = 0; + for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p, ++i) { + AType* t = protT->at(i)->as(); + const Type* lt = llType(t); + THROW_IF(!lt, fn->loc, "untyped parameter\n"); + cenv.def((*p)->as(), *p, t, &*a); +#ifdef EXPLICIT_STACK_FRAMES + types.push_back(lt); + s += std::max(lt->getPrimitiveSizeInBits(), unsigned(8)); +#endif + } + +#ifdef EXPLICIT_STACK_FRAMES + IRBuilder<> builder = engine->builder; + + // Scan out definitions + for (size_t i = 0; i < size(); ++i) { + ADef* def = at(i)->to(); + if (def) { + const Type* lt = llType(cenv.type(def->at(2))); + THROW_IF(!lt, loc, "untyped definition\n"); + types.push_back(lt); + s += std::max(lt->getPrimitiveSizeInBits(), unsigned(8)); + } + } + + // Create stack frame + StructType* frameT = StructType::get(types, false); + Value* tag = ConstantInt::get(Type::Int8Ty, GC::TAG_FRAME); + Value* frameSize = ConstantInt::get(Type::Int32Ty, s / 8); + Value* frame = builder.CreateCall2(engine->alloc, frameSize, tag, "frame"); + Value* framePtr = builder.CreateBitCast(frame, PointerType::get(frameT, 0)); + + // Bind parameter values in stack frame + i = 2; + for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++i) { + Value* v = builder.CreateStructGEP(framePtr, i, "arg"); + builder.CreateStore(&*a, v); + } +#endif + + // Write function body + try { + // Define value first for recursion + cenv.precompile(fn, f); + fn->impls.push_back(make_pair(thisType, f)); + CValue retVal = NULL; + for (size_t i = 2; i < fn->size(); ++i) + retVal = cenv.compile(fn->at(i)); + cenv.engine()->finishFunction(cenv, f, cenv.type(fn->at(fn->size() - 1)), retVal); + } catch (Error& e) { + f->eraseFromParent(); // Error reading body, remove function + cenv.pop(); + throw e; + } + cenv.tsubst = oldSubst; + cenv.pop(); +} + +CValue +LLVMEngine::compileIf(CEnv& cenv, AIf* aif) +{ + typedef vector< pair > Branches; + LLVMEngine* engine = reinterpret_cast(cenv.engine()); + Function* parent = engine->builder.GetInsertBlock()->getParent(); + BasicBlock* mergeBB = BasicBlock::Create("endif"); + BasicBlock* nextBB = NULL; + Branches branches; + for (size_t i = 1; i < aif->size() - 1; i += 2) { + Value* condV = llVal(cenv.compile(aif->at(i))); + BasicBlock* thenBB = BasicBlock::Create((format("then%1%") % ((i+1)/2)).str()); + + nextBB = BasicBlock::Create((format("else%1%") % ((i+1)/2)).str()); + + engine->builder.CreateCondBr(condV, thenBB, nextBB); + + // Emit then block for this condition + parent->getBasicBlockList().push_back(thenBB); + engine->builder.SetInsertPoint(thenBB); + Value* thenV = llVal(cenv.compile(aif->at(i + 1))); + engine->builder.CreateBr(mergeBB); + branches.push_back(make_pair(thenV, engine->builder.GetInsertBlock())); + + parent->getBasicBlockList().push_back(nextBB); + engine->builder.SetInsertPoint(nextBB); + } + + // Emit final else block + engine->builder.SetInsertPoint(nextBB); + Value* elseV = llVal(cenv.compile(aif->at(aif->size() - 1))); + engine->builder.CreateBr(mergeBB); + branches.push_back(make_pair(elseV, engine->builder.GetInsertBlock())); + + // Emit merge block (Phi node) + parent->getBasicBlockList().push_back(mergeBB); + engine->builder.SetInsertPoint(mergeBB); + PHINode* pn = engine->builder.CreatePHI(llType(cenv.type(aif)), "ifval"); + + FOREACH(Branches::iterator, i, branches) + pn->addIncoming(i->first, i->second); + + return pn; +} + +CValue +LLVMEngine::compilePrimitive(CEnv& cenv, APrimitive* prim) +{ + LLVMEngine* engine = reinterpret_cast(cenv.engine()); + Value* a = llVal(cenv.compile(prim->at(1))); + Value* b = llVal(cenv.compile(prim->at(2))); + bool isFloat = cenv.type(prim->at(1))->str() == "Float"; + const string n = prim->at(0)->to()->str(); + + // Binary arithmetic operations + Instruction::BinaryOps op = (Instruction::BinaryOps)0; + if (n == "+") op = Instruction::Add; + if (n == "-") op = Instruction::Sub; + if (n == "*") op = Instruction::Mul; + if (n == "and") op = Instruction::And; + if (n == "or") op = Instruction::Or; + if (n == "xor") op = Instruction::Xor; + if (n == "/") op = isFloat ? Instruction::FDiv : Instruction::SDiv; + if (n == "%") op = isFloat ? Instruction::FRem : Instruction::SRem; + if (op != 0) { + Value* val = engine->builder.CreateBinOp(op, a, b); + for (size_t i = 3; i < prim->size(); ++i) + val = engine->builder.CreateBinOp(op, val, llVal(cenv.compile(prim->at(i)))); + return val; + } + + // Comparison operations + CmpInst::Predicate pred = (CmpInst::Predicate)0; + if (n == "=") pred = isFloat ? CmpInst::FCMP_OEQ : CmpInst::ICMP_EQ ; + if (n == "!=") pred = isFloat ? CmpInst::FCMP_ONE : CmpInst::ICMP_NE ; + if (n == ">") pred = isFloat ? CmpInst::FCMP_OGT : CmpInst::ICMP_SGT; + if (n == ">=") pred = isFloat ? CmpInst::FCMP_OGE : CmpInst::ICMP_SGE; + if (n == "<") pred = isFloat ? CmpInst::FCMP_OLT : CmpInst::ICMP_SLT; + if (n == "<=") pred = isFloat ? CmpInst::FCMP_OLE : CmpInst::ICMP_SLE; + if (pred != 0) { + if (isFloat) + return engine->builder.CreateFCmp(pred, a, b); + else + return engine->builder.CreateICmp(pred, a, b); + } + + throw Error(prim->loc, "unknown primitive"); +} diff --git a/src/tuplr_llvm.cpp b/src/tuplr_llvm.cpp deleted file mode 100644 index e2f7f1a..0000000 --- a/src/tuplr_llvm.cpp +++ /dev/null @@ -1,424 +0,0 @@ -/* Tuplr: A programming language - * Copyright (C) 2008-2009 David Robillard - * - * Tuplr is free software: you can redistribute it and/or modify it under - * the terms of the GNU Affero General Public License as published by the - * Free Software Foundation, either version 3 of the License, or (at your - * option) any later version. - * - * Tuplr is distributed in the hope that it will be useful, but WITHOUT - * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY - * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General - * Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with Tuplr. If not, see . - */ - -/** @file - * @brief Compile AST to LLVM IR - * - * Compilation pass functions (lift/compile) that require direct use of LLVM - * specific things are implemented here. Generic compilation pass functions - * are implemented in compile.cpp. - */ - -#include -#include -#include -#include "llvm/Analysis/Verifier.h" -#include "llvm/Assembly/AsmAnnotationWriter.h" -#include "llvm/DerivedTypes.h" -#include "llvm/ExecutionEngine/ExecutionEngine.h" -#include "llvm/Instructions.h" -#include "llvm/Module.h" -#include "llvm/ModuleProvider.h" -#include "llvm/PassManager.h" -#include "llvm/Support/IRBuilder.h" -#include "llvm/Target/TargetData.h" -#include "llvm/Transforms/Scalar.h" -#include "tuplr.hpp" - -using namespace llvm; -using namespace std; -using boost::format; - -static inline Value* llVal(CValue v) { return static_cast(v); } -static inline Function* llFunc(CFunction f) { return static_cast(f); } - -static const Type* -llType(const AType* t) -{ - if (t->kind == AType::PRIM) { - if (t->at(0)->str() == "Nothing") return Type::VoidTy; - if (t->at(0)->str() == "Bool") return Type::Int1Ty; - if (t->at(0)->str() == "Int") return Type::Int32Ty; - if (t->at(0)->str() == "Float") return Type::FloatTy; - throw Error(t->loc, string("Unknown primitive type `") + t->str() + "'"); - } else if (t->kind == AType::EXPR && t->at(0)->str() == "Fn") { - const AType* retT = t->at(2)->as(); - if (!llType(retT)) - return NULL; - - vector cprot; - const ATuple* prot = t->at(1)->to(); - for (size_t i = 0; i < prot->size(); ++i) { - const AType* at = prot->at(i)->to(); - const Type* lt = llType(at); - if (lt) - cprot.push_back(lt); - else - return NULL; - } - - FunctionType* fT = FunctionType::get(llType(retT), cprot, false); - return PointerType::get(fT, 0); - } - return NULL; // non-primitive type -} - - -/*************************************************************************** - * LLVM Engine * - ***************************************************************************/ - -struct LLVMEngine : public Engine { - LLVMEngine() - : module(new Module("tuplr")) - , engine(ExecutionEngine::create(module)) - , emp(module) - , opt(&emp) - { - // Set up optimiser pipeline - const TargetData* target = engine->getTargetData(); - opt.add(new TargetData(*target)); // Register target arch - opt.add(createInstructionCombiningPass()); // Simple optimizations - opt.add(createReassociatePass()); // Reassociate expressions - opt.add(createGVNPass()); // Eliminate Common Subexpressions - opt.add(createCFGSimplificationPass()); // Simplify control flow - - // Declare host provided allocation primitive - std::vector argsT(1, Type::Int32Ty); // unsigned size - argsT.push_back(Type::Int8Ty); // char tag - FunctionType* funcT = FunctionType::get(PointerType::get(Type::Int8Ty, 0), argsT, false); - alloc = Function::Create(funcT, Function::ExternalLinkage, - "tuplr_gc_allocate", module); - } - - CFunction startFunction(CEnv& cenv, - const std::string& name, const AType* retT, const ATuple& argsT, - const vector argNames) - { - Function::LinkageTypes linkage = Function::ExternalLinkage; - - vector cprot; - FOREACH(ATuple::const_iterator, i, argsT) { - AType* at = (*i)->as(); - THROW_IF(!llType(at), Cursor(), string("non-concrete parameter :: ") - + at->str()) - cprot.push_back(llType(at)); - } - - THROW_IF(!llType(retT), Cursor(), "return has non-concrete type"); - FunctionType* fT = FunctionType::get(llType(retT), cprot, false); - Function* f = Function::Create(fT, linkage, name, module); - - // Note f->getName() may be different from name - // however LLVM chooses to mangle is fine, we keep a pointer - - // Set argument names in generated code - Function::arg_iterator a = f->arg_begin(); - if (!argNames.empty()) - for (size_t i = 0; i != argsT.size(); ++a, ++i) - a->setName(argNames.at(i)); - - BasicBlock* bb = BasicBlock::Create("entry", f); - builder.SetInsertPoint(bb); - - return f; - } - - void finishFunction(CEnv& cenv, CFunction f, const AType* retT, CValue ret) { - if (retT->concrete()) - builder.CreateRet(llVal(ret)); - else - builder.CreateRetVoid(); - - verifyFunction(*static_cast(f)); - if (cenv.args.find("-g") == cenv.args.end()) - opt.run(*static_cast(f)); - } - - void eraseFunction(CEnv& cenv, CFunction f) { - if (f) - llFunc(f)->eraseFromParent(); - } - - CValue compileCall(CEnv& cenv, CFunction f, const vector& args) { - const vector& llArgs = *reinterpret_cast*>(&args); - return builder.CreateCall(llFunc(f), llArgs.begin(), llArgs.end()); - } - - void liftCall(CEnv& cenv, AFn* fn, const AType& argsT); - - CValue compileLiteral(CEnv& cenv, AST* lit); - CValue compilePrimitive(CEnv& cenv, APrimitive* prim); - CValue compileIf(CEnv& cenv, AIf* aif); - - void writeModule(CEnv& cenv, std::ostream& os) { - AssemblyAnnotationWriter writer; - module->print(os, &writer); - } - - const string call(CEnv& cenv, CFunction f, AType* retT) { - void* fp = engine->getPointerToFunction(llFunc(f)); - const Type* t = llType(retT); - THROW_IF(!fp, Cursor(), "unable to get function pointer"); - THROW_IF(!t, Cursor(), "function with non-concrete return type called"); - - std::stringstream ss; - if (t == Type::Int32Ty) - ss << ((int32_t (*)())fp)(); - else if (t == Type::FloatTy) - ss << showpoint << ((float (*)())fp)(); - else if (t == Type::Int1Ty) - ss << (((bool (*)())fp)() ? "#t" : "#f"); - else if (t != Type::VoidTy) - ss << ((void* (*)())fp)(); - return ss.str(); - } - - Module* module; - ExecutionEngine* engine; - IRBuilder<> builder; - Function* alloc; - ExistingModuleProvider emp; - FunctionPassManager opt; -}; - -Engine* -tuplr_new_llvm_engine() -{ - return new LLVMEngine(); -} - -/*************************************************************************** - * Code Generation * - ***************************************************************************/ - -CValue -LLVMEngine::compileLiteral(CEnv& cenv, AST* lit) -{ - ALiteral* ilit = dynamic_cast*>(lit); - if (ilit) - return ConstantInt::get(Type::Int32Ty, ilit->val, true); - - ALiteral* flit = dynamic_cast*>(lit); - if (flit) - return ConstantFP::get(Type::FloatTy, flit->val); - - ALiteral* blit = dynamic_cast*>(lit); - if (blit) - return ConstantFP::get(Type::FloatTy, blit->val); - - throw Error(lit->loc, "Unknown literal type"); -} - -void -LLVMEngine::liftCall(CEnv& cenv, AFn* fn, const AType& argsT) -{ - TEnv::GenericTypes::const_iterator gt = cenv.tenv.genericTypes.find(fn); - assert(gt != cenv.tenv.genericTypes.end()); - LLVMEngine* engine = reinterpret_cast(cenv.engine()); - AType* genericType = new AType(*gt->second); - AType* thisType = genericType; - Subst argsSubst; - - // Build and apply substitution to get concrete type for this call - if (!genericType->concrete()) { - argsSubst = cenv.tenv.buildSubst(genericType, argsT); - thisType = argsSubst.apply(genericType)->as(); - } - - THROW_IF(!thisType->concrete(), fn->loc, - string("call has non-concrete type %1%\n") + thisType->str()); - - Object::pool.addRoot(thisType); - if (fn->impls.find(thisType)) - return; - - ATuple* protT = thisType->at(1)->as(); - - vector argNames; - for (size_t i = 0; i < fn->prot()->size(); ++i) - argNames.push_back(fn->prot()->at(i)->str()); - - // Write function declaration - const string name = (fn->name == "") ? cenv.penv.gensymstr("_fn") : fn->name; - Function* f = llFunc(cenv.engine()->startFunction(cenv, name, - thisType->at(thisType->size()-1)->to(), - *protT, argNames)); - - cenv.push(); - Subst oldSubst = cenv.tsubst; - cenv.tsubst = Subst::compose(cenv.tsubst, Subst::compose(argsSubst, fn->subst)); - -//#define EXPLICIT_STACK_FRAMES 1 - -#ifdef EXPLICIT_STACK_FRAMES - vector types; - types.push_back(Type::Int8Ty); - types.push_back(Type::Int8Ty); - size_t s = 16; // stack frame size in bits -#endif - - // Bind argument values in CEnv - vector args; - AFn::const_iterator p = fn->prot()->begin(); - size_t i = 0; - for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p, ++i) { - AType* t = protT->at(i)->as(); - const Type* lt = llType(t); - THROW_IF(!lt, fn->loc, "untyped parameter\n"); - cenv.def((*p)->as(), *p, t, &*a); -#ifdef EXPLICIT_STACK_FRAMES - types.push_back(lt); - s += std::max(lt->getPrimitiveSizeInBits(), unsigned(8)); -#endif - } - -#ifdef EXPLICIT_STACK_FRAMES - IRBuilder<> builder = engine->builder; - - // Scan out definitions - for (size_t i = 0; i < size(); ++i) { - ADef* def = at(i)->to(); - if (def) { - const Type* lt = llType(cenv.type(def->at(2))); - THROW_IF(!lt, loc, "untyped definition\n"); - types.push_back(lt); - s += std::max(lt->getPrimitiveSizeInBits(), unsigned(8)); - } - } - - // Create stack frame - StructType* frameT = StructType::get(types, false); - Value* tag = ConstantInt::get(Type::Int8Ty, GC::TAG_FRAME); - Value* frameSize = ConstantInt::get(Type::Int32Ty, s / 8); - Value* frame = builder.CreateCall2(engine->alloc, frameSize, tag, "frame"); - Value* framePtr = builder.CreateBitCast(frame, PointerType::get(frameT, 0)); - - // Bind parameter values in stack frame - i = 2; - for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++i) { - Value* v = builder.CreateStructGEP(framePtr, i, "arg"); - builder.CreateStore(&*a, v); - } -#endif - - // Write function body - try { - // Define value first for recursion - cenv.precompile(fn, f); - fn->impls.push_back(make_pair(thisType, f)); - CValue retVal = NULL; - for (size_t i = 2; i < fn->size(); ++i) - retVal = cenv.compile(fn->at(i)); - cenv.engine()->finishFunction(cenv, f, cenv.type(fn->at(fn->size() - 1)), retVal); - } catch (Error& e) { - f->eraseFromParent(); // Error reading body, remove function - cenv.pop(); - throw e; - } - cenv.tsubst = oldSubst; - cenv.pop(); -} - -CValue -LLVMEngine::compileIf(CEnv& cenv, AIf* aif) -{ - typedef vector< pair > Branches; - LLVMEngine* engine = reinterpret_cast(cenv.engine()); - Function* parent = engine->builder.GetInsertBlock()->getParent(); - BasicBlock* mergeBB = BasicBlock::Create("endif"); - BasicBlock* nextBB = NULL; - Branches branches; - for (size_t i = 1; i < aif->size() - 1; i += 2) { - Value* condV = llVal(cenv.compile(aif->at(i))); - BasicBlock* thenBB = BasicBlock::Create((format("then%1%") % ((i+1)/2)).str()); - - nextBB = BasicBlock::Create((format("else%1%") % ((i+1)/2)).str()); - - engine->builder.CreateCondBr(condV, thenBB, nextBB); - - // Emit then block for this condition - parent->getBasicBlockList().push_back(thenBB); - engine->builder.SetInsertPoint(thenBB); - Value* thenV = llVal(cenv.compile(aif->at(i + 1))); - engine->builder.CreateBr(mergeBB); - branches.push_back(make_pair(thenV, engine->builder.GetInsertBlock())); - - parent->getBasicBlockList().push_back(nextBB); - engine->builder.SetInsertPoint(nextBB); - } - - // Emit final else block - engine->builder.SetInsertPoint(nextBB); - Value* elseV = llVal(cenv.compile(aif->at(aif->size() - 1))); - engine->builder.CreateBr(mergeBB); - branches.push_back(make_pair(elseV, engine->builder.GetInsertBlock())); - - // Emit merge block (Phi node) - parent->getBasicBlockList().push_back(mergeBB); - engine->builder.SetInsertPoint(mergeBB); - PHINode* pn = engine->builder.CreatePHI(llType(cenv.type(aif)), "ifval"); - - FOREACH(Branches::iterator, i, branches) - pn->addIncoming(i->first, i->second); - - return pn; -} - -CValue -LLVMEngine::compilePrimitive(CEnv& cenv, APrimitive* prim) -{ - LLVMEngine* engine = reinterpret_cast(cenv.engine()); - Value* a = llVal(cenv.compile(prim->at(1))); - Value* b = llVal(cenv.compile(prim->at(2))); - bool isFloat = cenv.type(prim->at(1))->str() == "Float"; - const string n = prim->at(0)->to()->str(); - - // Binary arithmetic operations - Instruction::BinaryOps op = (Instruction::BinaryOps)0; - if (n == "+") op = Instruction::Add; - if (n == "-") op = Instruction::Sub; - if (n == "*") op = Instruction::Mul; - if (n == "and") op = Instruction::And; - if (n == "or") op = Instruction::Or; - if (n == "xor") op = Instruction::Xor; - if (n == "/") op = isFloat ? Instruction::FDiv : Instruction::SDiv; - if (n == "%") op = isFloat ? Instruction::FRem : Instruction::SRem; - if (op != 0) { - Value* val = engine->builder.CreateBinOp(op, a, b); - for (size_t i = 3; i < prim->size(); ++i) - val = engine->builder.CreateBinOp(op, val, llVal(cenv.compile(prim->at(i)))); - return val; - } - - // Comparison operations - CmpInst::Predicate pred = (CmpInst::Predicate)0; - if (n == "=") pred = isFloat ? CmpInst::FCMP_OEQ : CmpInst::ICMP_EQ ; - if (n == "!=") pred = isFloat ? CmpInst::FCMP_ONE : CmpInst::ICMP_NE ; - if (n == ">") pred = isFloat ? CmpInst::FCMP_OGT : CmpInst::ICMP_SGT; - if (n == ">=") pred = isFloat ? CmpInst::FCMP_OGE : CmpInst::ICMP_SGE; - if (n == "<") pred = isFloat ? CmpInst::FCMP_OLT : CmpInst::ICMP_SLT; - if (n == "<=") pred = isFloat ? CmpInst::FCMP_OLE : CmpInst::ICMP_SLE; - if (pred != 0) { - if (isFloat) - return engine->builder.CreateFCmp(pred, a, b); - else - return engine->builder.CreateICmp(pred, a, b); - } - - throw Error(prim->loc, "unknown primitive"); -} -- cgit v1.2.1