/* Resp: A programming language * Copyright (C) 2008-2009 David Robillard * * Resp is free software: you can redistribute it and/or modify it under * the terms of the GNU Affero General Public License as published by the * Free Software Foundation, either version 3 of the License, or (at your * option) any later version. * * Resp is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General * Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with Resp. If not, see . */ /** @file * @brief Compile to LLVM IR and/or execute via JIT */ #define __STDC_LIMIT_MACROS 1 #define __STDC_CONSTANT_MACROS 1 #include #include #include #include #include #include #include "llvm/Analysis/Verifier.h" #include "llvm/Assembly/AssemblyAnnotationWriter.h" #include "llvm/DefaultPasses.h" #include "llvm/DerivedTypes.h" #include "llvm/ExecutionEngine/ExecutionEngine.h" #include "llvm/ExecutionEngine/JIT.h" #include "llvm/ExecutionEngine/JITMemoryManager.h" #include "llvm/Instructions.h" #include "llvm/LLVMContext.h" #include "llvm/Module.h" #include "llvm/PassManager.h" #include "llvm/Support/IRBuilder.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Support/raw_os_ostream.h" #include "llvm/Target/TargetData.h" #include "llvm/Transforms/IPO/PassManagerBuilder.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Value.h" #include "resp.hpp" using namespace llvm; using namespace std; using boost::format; struct IfRecord { IfRecord(LLVMContext& context, unsigned labelIndex, Type* t) : type(t) , mergeBB(BasicBlock::Create(context, "endif")) , thenBB(BasicBlock::Create(context, (format("then%1%") % labelIndex).str())) , elseBB(BasicBlock::Create(context, (format("else%1%") % labelIndex).str())) , thenV(NULL) , elseV(NULL) , nMergeBranches(0) {} Type* type; BasicBlock* mergeBB; BasicBlock* thenBB; BasicBlock* elseBB; Value* thenV; Value* elseV; unsigned nMergeBranches; }; /** LLVM Engine (Compiler and JIT) */ struct LLVMEngine : public Engine { LLVMEngine(); virtual ~LLVMEngine(); CFunc compileProt(CEnv& cenv, const string& name, const ATuple* args, const ATuple* type); CFunc startFn(CEnv& cenv, const string& name, const ATuple* args, const ATuple* type); void finishFn(CEnv& cenv, CVal ret, const AST* retT); CFunc getFn(CEnv& cenv, const std::string& name); void eraseFn(CEnv& cenv, CFunc f); CVal compileCall(CEnv& cenv, CFunc f, const ATuple* funcT, const vector& args); CVal compileCast(CEnv& cenv, CVal v, const AST* t); CVal compileCons(CEnv& cenv, const char* tname, const ATuple* type, CVal rtti, const vector& fields); CVal compileDot(CEnv& cenv, CVal tup, int32_t index); CVal compileGlobalSet(CEnv& cenv, const string& s, CVal v, const AST* t); CVal compileGlobalGet(CEnv& cenv, const string& s, CVal v); CVal compileIfStart(CEnv& cenv, const AST* cond, const AST* type); CVal compileIfThen(CEnv& cenv, CVal thenV); CVal compileIfElse(CEnv& cenv, CVal elseV); CVal compileIfEnd(CEnv& cenv); CVal compileLiteral(CEnv& cenv, const AST* lit); CVal compilePrimitive(CEnv& cenv, const ATuple* prim); CVal compileString(CEnv& cenv, const char* str); CType compileType(CEnv& cenv, const std::string& name, const AST* exp); void writeModule(CEnv& cenv, std::ostream& os); const string call(CEnv& cenv, CFunc f, const AST* retT); private: void pushFnArgs(CEnv& cenv, const ATuple* prot, const ATuple* type, CFunc f); void appendBlock(LLVMEngine* engine, Function* function, BasicBlock* block) { function->getBasicBlockList().push_back(block); engine->builder.SetInsertPoint(block); } inline Value* llVal(CVal v) { return static_cast(v); } inline Function* llFunc(CFunc f) { return static_cast(f); } Type* llType(const AST* t, const char* name=NULL); LLVMContext context; Module* module; ExecutionEngine* engine; IRBuilder<> builder; Function* alloc; FunctionPassManager* fnOpt; PassManager* modOpt; CType objectT; StructType* opaqueT; std::string opaqueName; typedef std::map CTypes; CTypes compiledTypes; typedef std::stack IfStack; IfStack if_stack; CFunc currentFn; unsigned labelIndex; }; LLVMEngine::LLVMEngine() : builder(context) , opaqueT(NULL) , currentFn(NULL) , labelIndex(1) { InitializeNativeTarget(); module = new Module("resp", context); engine = EngineBuilder(module).create(); fnOpt = new FunctionPassManager(module); modOpt = new PassManager(); // Set up optimisers PassManagerBuilder pmb; pmb.OptLevel = 3; pmb.populateFunctionPassManager(*fnOpt); pmb.populateModulePassManager(*modOpt); // Declare host provided allocation primitive std::vector argsT(1, Type::getInt32Ty(context)); // unsigned size FunctionType* funcT = FunctionType::get( PointerType::get(Type::getInt8Ty(context), 0), argsT, false); alloc = Function::Create( funcT, Function::ExternalLinkage, "__resp_alloc", module); // Build Object type (tag only, binary compatible with any constructed thing) vector ctypes; ctypes.push_back(PointerType::get(Type::getInt8Ty(context), 0)); // RTTI StructType* cObjectT = StructType::create(context, ctypes, "Object", false); objectT = cObjectT; } LLVMEngine::~LLVMEngine() { delete engine; delete fnOpt; delete modOpt; } Type* LLVMEngine::llType(const AST* t, const char* name) { if (t == NULL) { return NULL; } else if (AType::is_name(t)) { const std::string sym(t->as_symbol()->sym()); if (sym == "Nothing") return Type::getVoidTy(context); if (sym == "Bool") return Type::getInt1Ty(context); if (sym == "Int") return Type::getInt32Ty(context); if (sym == "Float") return Type::getFloatTy(context); if (sym == "String") return PointerType::get(Type::getInt8Ty(context), 0); if (sym == "Symbol") return PointerType::get(Type::getInt8Ty(context), 0); if (sym == opaqueName) { THROW_IF(!opaqueT, t->loc, "Broken recursive type"); return PointerType::getUnqual(opaqueT); } CTypes::const_iterator i = compiledTypes.find(sym); if (i != compiledTypes.end()) return i->second; cerr << "WARNING: No low-level type for " << t << endl; return NULL; } THROW_IF(!isupper(t->as_tuple()->fst()->str()[0]), t->loc, "Lower-case type expression"); // Function type if (is_form(t, "Fn")) { ATuple::const_iterator i = t->as_tuple()->begin(); const ATuple* protT = (*++i)->to_tuple(); const AST* retT = (*++i); if (!llType(retT)) return NULL; vector cprot; FOREACHP(ATuple::const_iterator, i, protT) { Type* lt = llType(*i); if (!lt) return NULL; cprot.push_back(lt); } return PointerType::getUnqual(FunctionType::get(llType(retT), cprot, false)); } // Struct type StructType* ret = NULL; vector ctypes; if (!name) { const ASymbol* tag = t->as_tuple()->fst()->as_symbol(); if (tag->str() != "Tup" && tag->str() != "Closure") { name = tag->str().c_str(); } } if (name) { CTypes::const_iterator i = compiledTypes.find(name); if (i != compiledTypes.end()) { ret = (StructType*)((PointerType*)i->second)->getContainedType(0); } } if (ret && !ret->isOpaque()) { return PointerType::getUnqual(ret); } // Define opaque type to stand for name in recursive type body if (name) { THROW_IF(opaqueT, t->loc, "Nested recursive types"); opaqueT = (ret) ? ret : StructType::create(context, name); opaqueName = name; } // Get member types ctypes.push_back(PointerType::get(Type::getInt8Ty(context), 0)); // RTTI for (ATuple::const_iterator i = t->as_tuple()->iter_at(1); i != t->as_tuple()->end(); ++i) { Type* lt = llType(*i); if (!lt) return NULL; ctypes.push_back(lt); } if (name) { // Resolve recursive type opaqueT->setBody(ctypes); ret = opaqueT; opaqueT = NULL; opaqueName = ""; } else { ret = StructType::get(context, ctypes, false); } return PointerType::getUnqual(ret); } /** Convert a size in bits to bytes, rounding up as necessary */ static inline size_t bitsToBytes(size_t bits) { return ((bits % 8 == 0) ? bits : (((bits / 8) + 1) * 8)) / 8; } CVal LLVMEngine::compileCall(CEnv& cenv, CFunc f, const ATuple* funcT, const vector& args) { vector llArgs(*reinterpret_cast*>(&args)); return builder.CreateCall(llFunc(f), llArgs); } CVal LLVMEngine::compileCast(CEnv& cenv, CVal v, const AST* t) { const ATuple* tup = t->to_tuple(); string name; if (tup) { const ASymbol* head = tup->fst()->to_symbol(); if (head && head->str() != "Fn" && isupper(head->str()[0])) { name = head->str(); } } else if (t->to_symbol()) { name = t->to_symbol()->str(); } return builder.CreateBitCast(llVal(v), (Type*)compileType(cenv, name, cenv.resolveType(t)), "cast"); } CVal LLVMEngine::compileCons(CEnv& cenv, const char* tname, const ATuple* type, CVal rtti, const vector& fields) { // Find size of memory required size_t s = engine->getTargetData()->getTypeSizeInBits( PointerType::get(Type::getInt8Ty(context), 0)); assert(type->begin() != type->end()); for (ATuple::const_iterator i = type->iter_at(1); i != type->end(); ++i) s += engine->getTargetData()->getTypeSizeInBits( (Type*)compileType(cenv, (*i)->str(), *i)); // Allocate struct const std::string name = type->fst()->str(); Value* structSize = ConstantInt::get(Type::getInt32Ty(context), bitsToBytes(s)); Value* mem = builder.CreateCall(alloc, structSize, name + "_mem"); StructType* consT = module->getTypeByName(tname); Value* structPtr = NULL; if (consT) { // Named type structPtr = builder.CreateBitCast(mem, PointerType::getUnqual(consT), name); } else { // Anonymous type structPtr = builder.CreateBitCast(mem, llType(type), name); } // Set struct fields if (rtti) builder.CreateStore((Value*)rtti, builder.CreateStructGEP(structPtr, 0, "rtti")); size_t i = 1; ATuple::const_iterator t = type->iter_at(1); for (vector::const_iterator f = fields.begin(); f != fields.end(); ++f, ++i, ++t) { Value* val = llVal(*f); Value* field = builder.CreateStructGEP( structPtr, i, (format("tup%1%") % i).str().c_str()); if ((*t)->to_tuple()) val = builder.CreateBitCast(val, llType(*t), "objPtr"); builder.CreateStore(val, field); } return structPtr; } CVal LLVMEngine::compileDot(CEnv& cenv, CVal tup, int32_t index) { Value* ptr = builder.CreateStructGEP(llVal(tup), index, "dotPtr"); return builder.CreateLoad(ptr, 0, "dotVal"); } CVal LLVMEngine::compileLiteral(CEnv& cenv, const AST* lit) { switch (lit->tag()) { case T_BOOL: return ConstantInt::get(Type::getInt1Ty(context), ((const ALiteral*)lit)->val); case T_FLOAT: return ConstantFP::get(Type::getFloatTy(context), ((const ALiteral*)lit)->val); case T_INT32: return ConstantInt::get(Type::getInt32Ty(context), ((const ALiteral*)lit)->val, true); default: throw Error(lit->loc, "Unknown literal type"); } } CVal LLVMEngine::compileString(CEnv& cenv, const char* str) { return builder.CreateGlobalStringPtr(str); } CType LLVMEngine::compileType(CEnv& cenv, const std::string& name, const AST* expr) { if (!name.empty()) { CTypes::const_iterator i = compiledTypes.find(name); if (i != compiledTypes.end()) { if (!i->second->isPointerTy() || !((StructType*)(((PointerType*)i->second)->getContainedType(0)))->isOpaque()) { return i->second; } } } Type* const type = (expr) ? llType(expr, name.c_str()) // Definition : PointerType::getUnqual(StructType::create(context, name)); // Forward declaration if (!name.empty()) compiledTypes.insert(make_pair(name, type)); return type; } CFunc LLVMEngine::compileProt( CEnv& cenv, const std::string& name, const ATuple* args, const ATuple* type) { const ATuple* argsT = type->prot(); const AST* retT = type->list_last(); Function::LinkageTypes linkage = Function::ExternalLinkage; vector cprot; FOREACHP(ATuple::const_iterator, i, argsT) { const CType iT = ((*i)->to_symbol()) ? compileType(cenv, (*i)->str(), cenv.resolveType(*i)) : compileType(cenv, (*i)->as_tuple()->fst()->str(), *i); THROW_IF(!iT, Cursor(), string("non-concrete parameter :: ") + (*i)->str()); cprot.push_back((Type*)iT); } THROW_IF(!llType(retT), Cursor(), (format("return has non-concrete type `%1%'") % retT->str()).str()); const string llName = (name == "") ? cenv.penv.gensymstr("_fn") : name; FunctionType* fT = FunctionType::get(llType(retT), cprot, false); Function* f = Function::Create(fT, linkage, llName, module); // Note f->getName() may be different from llName // however LLVM chooses to mangle is fine, we keep a pointer // Set argument names in generated code if (args) { Function::arg_iterator a = f->arg_begin(); for (ATuple::const_iterator i = args->begin(); i != args->end(); ++a, ++i) a->setName((*i)->as_symbol()->sym()); } // Define function in the environment so any calls that get compiled before // the definition will resolve correctly (e.g. parent calls from a child fn) cenv.def(cenv.penv.sym(name), NULL, type, f); return f; } CFunc LLVMEngine::startFn( CEnv& cenv, const std::string& name, const ATuple* args, const ATuple* type) { // Use forward declaration if it exists Function* f = module->getFunction(name); if (!f) { f = (Function*)compileProt(cenv, name, args, type); } // Start the function body BasicBlock* bb = BasicBlock::Create(context, "entry", f); builder.SetInsertPoint(bb); currentFn = f; cenv.def(cenv.penv.sym(name), NULL, type, f); pushFnArgs(cenv, args, type, f); return f; } void LLVMEngine::pushFnArgs(CEnv& cenv, const ATuple* prot, const ATuple* type, CFunc cfunc) { cenv.push(); if (!prot) return; const ATuple* argsT = type->prot(); Function* f = llFunc(cfunc); // Bind argument values in CEnv ATuple::const_iterator p = prot->begin(); ATuple::const_iterator pT = argsT->begin(); assert(prot->size() == argsT->size()); assert(prot->size() == f->num_args()); for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p, ++pT) { const AST* t = *pT;//cenv.resolveType(*pT); // THROW_IF(!llType(t), (*p)->loc, "untyped parameter\n"); cenv.def((*p)->as_symbol(), *p, t, &*a); } } void LLVMEngine::finishFn(CEnv& cenv, CVal ret, const AST* retT) { CFunc f = currentFn; if (retT->str() == "Nothing") builder.CreateRetVoid(); else builder.CreateRet(builder.CreateBitCast(llVal(ret), llType(retT), "ret")); if (verifyFunction(*static_cast(f), llvm::PrintMessageAction)) { module->dump(); throw Error(Cursor(), "Broken module"); } if (cenv.args.find("-g") == cenv.args.end()) fnOpt->run(*static_cast(f)); currentFn = NULL; } CFunc LLVMEngine::getFn(CEnv& cenv, const std::string& name) { return module->getFunction(name); } void LLVMEngine::eraseFn(CEnv& cenv, CFunc f) { if (f) llFunc(f)->eraseFromParent(); } CVal LLVMEngine::compileIfStart(CEnv& cenv, const AST* cond, const AST* type) { LLVMEngine* engine = reinterpret_cast(cenv.engine()); Function* parent = engine->builder.GetInsertBlock()->getParent(); IfRecord* rec = new IfRecord(context, ++labelIndex, llType(type)); if_stack.push(rec); engine->builder.CreateCondBr(llVal(resp_compile(cenv, cond)), rec->thenBB, rec->elseBB); // Start then block appendBlock(engine, parent, rec->thenBB); return NULL; } CVal LLVMEngine::compileIfThen(CEnv& cenv, CVal thenV) { LLVMEngine* engine = reinterpret_cast(cenv.engine()); Function* parent = engine->builder.GetInsertBlock()->getParent(); IfRecord* rec = if_stack.top(); rec->thenV = llVal(thenV); // Finish then block engine->builder.CreateBr(rec->mergeBB); rec->thenBB = engine->builder.GetInsertBlock(); // Start else block appendBlock(engine, parent, rec->elseBB); return NULL; } CVal LLVMEngine::compileIfElse(CEnv& cenv, CVal elseV) { LLVMEngine* engine = reinterpret_cast(cenv.engine()); IfRecord* rec = if_stack.top(); rec->elseV = llVal(elseV); if (elseV) { engine->builder.CreateBr(rec->mergeBB); ++rec->nMergeBranches; } else { engine->builder.CreateUnreachable(); } rec->elseBB = engine->builder.GetInsertBlock(); return NULL; } CVal LLVMEngine::compileIfEnd(CEnv& cenv) { LLVMEngine* engine = reinterpret_cast(cenv.engine()); Function* parent = engine->builder.GetInsertBlock()->getParent(); IfRecord* rec = if_stack.top(); // Emit merge block (Phi node) appendBlock(engine, parent, rec->mergeBB); ++rec->nMergeBranches; PHINode* pn = engine->builder.CreatePHI(rec->type, rec->nMergeBranches, "ifval"); pn->addIncoming(rec->thenV, rec->thenBB); if (rec->elseV) pn->addIncoming(rec->elseV, rec->elseBB); if_stack.pop(); delete rec; return pn; } CVal LLVMEngine::compilePrimitive(CEnv& cenv, const ATuple* prim) { ATuple::const_iterator i = prim->iter_at(1); LLVMEngine* engine = reinterpret_cast(cenv.engine()); bool isFloat = (cenv.type(prim)->str() == "Float"); Value* a = llVal(resp_compile(cenv, *i++)); Value* b = llVal(resp_compile(cenv, *i++)); const string n = prim->fst()->to_symbol()->str(); // Binary arithmetic operations Instruction::BinaryOps op = (Instruction::BinaryOps)0; if (n == "+") op = Instruction::Add; if (n == "-") op = Instruction::Sub; if (n == "*") op = isFloat ? Instruction::FMul : Instruction::Mul; if (n == "and") op = Instruction::And; if (n == "or") op = Instruction::Or; if (n == "xor") op = Instruction::Xor; if (n == "/") op = isFloat ? Instruction::FDiv : Instruction::SDiv; if (n == "%") op = isFloat ? Instruction::FRem : Instruction::SRem; if (op != 0) { Value* val = engine->builder.CreateBinOp(op, a, b); while (i != prim->end()) val = engine->builder.CreateBinOp(op, val, llVal(resp_compile(cenv, *i++))); return val; } // Comparison operations CmpInst::Predicate pred = (CmpInst::Predicate)0; if (n == "=") pred = isFloat ? CmpInst::FCMP_OEQ : CmpInst::ICMP_EQ ; if (n == "!=") pred = isFloat ? CmpInst::FCMP_ONE : CmpInst::ICMP_NE ; if (n == ">") pred = isFloat ? CmpInst::FCMP_OGT : CmpInst::ICMP_SGT; if (n == ">=") pred = isFloat ? CmpInst::FCMP_OGE : CmpInst::ICMP_SGE; if (n == "<") pred = isFloat ? CmpInst::FCMP_OLT : CmpInst::ICMP_SLT; if (n == "<=") pred = isFloat ? CmpInst::FCMP_OLE : CmpInst::ICMP_SLE; if (pred != 0) { if (isFloat) return engine->builder.CreateFCmp(pred, a, b); else return engine->builder.CreateICmp(pred, a, b); } throw Error(prim->loc, "unknown primitive"); } CVal LLVMEngine::compileGlobalSet(CEnv& cenv, const string& sym, CVal val, const AST* type) { LLVMEngine* engine = reinterpret_cast(cenv.engine()); GlobalVariable* global = new GlobalVariable(*module, llType(type), false, GlobalValue::ExternalLinkage, Constant::getNullValue(llType(type)), sym); Value* valPtr = builder.CreateBitCast(llVal(val), llType(type), "globalPtr"); engine->builder.CreateStore(valPtr, global); return global; } CVal LLVMEngine::compileGlobalGet(CEnv& cenv, const string& sym, CVal val) { LLVMEngine* engine = reinterpret_cast(cenv.engine()); return engine->builder.CreateLoad(llVal(val), sym + "Ptr"); } void LLVMEngine::writeModule(CEnv& cenv, std::ostream& os) { if (cenv.args.find("-g") == cenv.args.end()) modOpt->run(*module); AssemblyAnnotationWriter writer; llvm::raw_os_ostream raw_stream(os); module->print(raw_stream, &writer); } const string LLVMEngine::call(CEnv& cenv, CFunc f, const AST* retT) { void* fp = engine->getPointerToFunction(llFunc(f)); const Type* t = llType(retT); THROW_IF(!fp, Cursor(), "unable to get function pointer"); THROW_IF(!t, Cursor(), "function with non-concrete return type called"); std::stringstream ss; if (t == Type::getInt32Ty(context)) { ss << ((int32_t (*)())fp)(); } else if (t == Type::getFloatTy(context)) { ss << showpoint << ((float (*)())fp)(); } else if (t == Type::getInt1Ty(context)) { ss << (((bool (*)())fp)() ? "#t" : "#f"); } else if (retT->str() == "String") { const std::string s(((char* (*)())fp)()); ss << "\""; for (std::string::const_iterator i = s.begin(); i != s.end(); ++i) { switch (*i) { case '\"': case '\\': ss << '\\'; default: ss << *i; break; } } ss << "\""; } else if (retT->str() == "Symbol") { const std::string s(((char* (*)())fp)()); ss << s; } else if (t != Type::getVoidTy(context)) { ss << ((void* (*)())fp)(); } else { ((void (*)())fp)(); } return ss.str(); } Engine* resp_new_llvm_engine() { return new LLVMEngine(); }