/* Tuplr: A programming language * Copyright (C) 2008-2009 David Robillard * * Tuplr is free software: you can redistribute it and/or modify it under * the terms of the GNU Affero General Public License as published by the * Free Software Foundation, either version 3 of the License, or (at your * option) any later version. * * Tuplr is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General * Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with Tuplr. If not, see . */ #include #include #include #include "llvm/Analysis/Verifier.h" #include "llvm/Assembly/AsmAnnotationWriter.h" #include "llvm/DerivedTypes.h" #include "llvm/ExecutionEngine/ExecutionEngine.h" #include "llvm/Instructions.h" #include "llvm/Module.h" #include "llvm/ModuleProvider.h" #include "llvm/PassManager.h" #include "llvm/Support/IRBuilder.h" #include "llvm/Target/TargetData.h" #include "llvm/Transforms/Scalar.h" #include "tuplr.hpp" using namespace llvm; using namespace std; using boost::format; inline Value* LLVal(CValue v) { return static_cast(v); } inline Function* LLFunc(CFunction f) { return static_cast(f); } struct LLVMEngine { LLVMEngine(); Module* module; ExecutionEngine* engine; IRBuilder<> builder; }; static const Type* lltype(const AType* t) { switch (t->kind) { case AType::VAR: throw Error((format("non-compilable type `%1%'") % t->str()).str(), t->loc); return NULL; case AType::PRIM: if (t->at(0)->str() == "Bool") return Type::Int1Ty; if (t->at(0)->str() == "Int") return Type::Int32Ty; if (t->at(0)->str() == "Float") return Type::FloatTy; throw Error(string("Unknown primitive type `") + t->str() + "'"); case AType::EXPR: if (t->at(0)->str() == "Pair") { vector types; for (size_t i = 1; i < t->size(); ++i) types.push_back(lltype(t->at(i)->to())); return PointerType::get(StructType::get(types, false), 0); } } return NULL; // not reached } static LLVMEngine* llengine(CEnv& cenv) { return reinterpret_cast(cenv.engine()); } LLVMEngine::LLVMEngine() : module(new Module("tuplr")) , engine(ExecutionEngine::create(module)) { } struct CEnv::PImpl { PImpl(LLVMEngine* e) : engine(e), module(e->module), emp(module), opt(&emp) { // Set up the optimizer pipeline: const TargetData* target = engine->engine->getTargetData(); opt.add(new TargetData(*target)); // Register target arch opt.add(createInstructionCombiningPass()); // Simple optimizations opt.add(createReassociatePass()); // Reassociate expressions opt.add(createGVNPass()); // Eliminate Common Subexpressions opt.add(createCFGSimplificationPass()); // Simplify control flow } LLVMEngine* engine; Module* module; ExistingModuleProvider emp; FunctionPassManager opt; }; CEnv::CEnv(PEnv& p, TEnv& t, CEngine e, ostream& os, ostream& es) : out(os), err(es), penv(p), tenv(t), symID(0), alloc(0), _pimpl(new PImpl((LLVMEngine*)e)) { } CEnv::~CEnv() { delete _pimpl; } CEngine CEnv::engine() { return _pimpl->engine; } CValue CEnv::compile(AST* obj) { CValue* v = vals.ref(obj); return (v && *v) ? *v : vals.def(obj, obj->compile(*this)); } void CEnv::optimise(CFunction f) { verifyFunction(*static_cast(f)); _pimpl->opt.run(*static_cast(f)); } void CEnv::write(std::ostream& os) { AssemblyAnnotationWriter writer; _pimpl->engine->module->print(os, &writer); } #define LITERAL(CT, NAME, COMPILED) \ template<> CValue \ ALiteral::compile(CEnv& cenv) { return (COMPILED); } \ template<> void \ ALiteral::constrain(TEnv& tenv, Constraints& c) const { c.constrain(tenv, this, tenv.named(NAME)); } /// Literal template instantiations LITERAL(int32_t, "Int", ConstantInt::get(Type::Int32Ty, val, true)) LITERAL(float, "Float", ConstantFP::get(Type::FloatTy, val)) LITERAL(bool, "Bool", ConstantInt::get(Type::Int1Ty, val, false)) static Function* compileFunction(CEnv& cenv, const std::string& name, const Type* retT, const ATuple& protT, const vector argNames=vector()) { Function::LinkageTypes linkage = Function::ExternalLinkage; vector cprot; for (size_t i = 0; i < protT.size(); ++i) { AType* at = protT.at(i)->as(); if (!lltype(at)) throw Error("function parameter is untyped"); cprot.push_back(lltype(at)); } if (!retT) throw Error("function return is untyped"); FunctionType* fT = FunctionType::get(static_cast(retT), cprot, false); Function* f = Function::Create(fT, linkage, name, llengine(cenv)->module); if (f->getName() != name) { f->eraseFromParent(); throw Error("function redefined"); } // Set argument names in generated code Function::arg_iterator a = f->arg_begin(); if (!argNames.empty()) for (size_t i = 0; i != protT.size(); ++a, ++i) a->setName(argNames.at(i)); BasicBlock* bb = BasicBlock::Create("entry", f); llengine(cenv)->builder.SetInsertPoint(bb); return f; } /*************************************************************************** * Code Generation * ***************************************************************************/ CValue ASymbol::compile(CEnv& cenv) { return cenv.vals.ref(this); } void AClosure::lift(CEnv& cenv) { AType* type = cenv.type(this); if (funcs.find(type) || !type->concrete()) return; ATuple* protT = type->at(1)->as(); vector argsT; for (size_t i = 0; i < protT->size(); ++i) argsT.push_back(protT->at(i)->as()); liftCall(cenv, argsT); } void AClosure::liftCall(CEnv& cenv, const vector& argsT) { TEnv::GenericTypes::const_iterator gt = cenv.tenv.genericTypes.find(this); assert(gt != cenv.tenv.genericTypes.end()); AType* genericType = new AType(*gt->second); AType* thisType = genericType; Subst argsSubst; if (!thisType->concrete()) { // Find type and build substitution assert(argsT.size() == prot()->size()); ATuple* genericProtT = genericType->at(1)->as(); for (size_t i = 0; i < argsT.size(); ++i) argsSubst[genericProtT->at(i)->to()] = argsT.at(i)->to(); thisType = argsSubst.apply(genericType)->as(); if (!thisType->concrete()) throw Error("unable to resolve concrete type for function", loc); } else { thisType = genericType; } if (funcs.find(thisType)) return; ATuple* protT = thisType->at(1)->as(); // Write function declaration string name = this->name == "" ? cenv.gensym("_fn") : this->name; Function* f = compileFunction(cenv, name, lltype(thisType->at(thisType->size()-1)->to()), *protT); cenv.push(); Subst oldSubst = cenv.tsubst; cenv.tsubst = Subst::compose(cenv.tsubst, Subst::compose(argsSubst, *subst)); // Bind argument values in CEnv vector args; const_iterator p = prot()->begin(); size_t i = 0; for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p) cenv.def((*p)->as(), *p, protT->at(i++)->as(), &*a); // Write function body try { // Define value first for recursion cenv.precompile(this, f); funcs.push_back(make_pair(thisType, f)); CValue retVal = NULL; for (size_t i = 2; i < size(); ++i) retVal = cenv.compile(at(i)); llengine(cenv)->builder.CreateRet(LLVal(retVal)); // Finish function cenv.optimise(LLFunc(f)); } catch (Error& e) { f->eraseFromParent(); // Error reading body, remove function cenv.pop(); throw e; } cenv.tsubst = oldSubst; cenv.pop(); } CValue AClosure::compile(CEnv& cenv) { return NULL; } void ACall::lift(CEnv& cenv) { AClosure* c = cenv.tenv.resolve(at(0))->to(); vector argsT; // Lift arguments for (size_t i = 1; i < size(); ++i) { at(i)->lift(cenv); argsT.push_back(cenv.type(at(i))); } if (!c) return; // Primitive if (c->prot()->size() < size() - 1) throw Error((format("too many arguments to function `%1%'") % at(0)->str()).str(), loc); if (c->prot()->size() > size() - 1) throw Error((format("too few arguments to function `%1%'") % at(0)->str()).str(), loc); c->liftCall(cenv, argsT); // Lift called closure } CValue ACall::compile(CEnv& cenv) { AClosure* c = cenv.tenv.resolve(at(0))->to(); if (!c) return NULL; // Primitive AType* protT = new AType(loc, NULL); for (size_t i = 1; i < size(); ++i) protT->push_back(cenv.type(at(i))); TEnv::GenericTypes::const_iterator gt = cenv.tenv.genericTypes.find(c); assert(gt != cenv.tenv.genericTypes.end()); const AType* polyT = gt->second; AType* fnT = new AType(loc, cenv.penv.sym("Fn"), protT, polyT->at(polyT->size()-1), 0); Function* f = (Function*)c->funcs.find(fnT); if (!f) throw Error("callee failed to compile", loc); vector params(size() - 1); for (size_t i = 1; i < size(); ++i) params[i-1] = LLVal(cenv.compile(at(i))); return llengine(cenv)->builder.CreateCall(f, params.begin(), params.end()); } void ADefinition::lift(CEnv& cenv) { // Define stub first for recursion cenv.def(sym(), at(2), cenv.type(at(2)), NULL); AClosure* c = at(2)->to(); if (c) c->name = sym()->str(); at(2)->lift(cenv); } CValue ADefinition::compile(CEnv& cenv) { // Define stub first for recursion cenv.def(sym(), at(2), cenv.type(at(2)), NULL); CValue val = cenv.compile(at(size() - 1)); cenv.vals.def(sym(), val); return val; } CValue AIf::compile(CEnv& cenv) { typedef vector< pair > Branches; Function* parent = llengine(cenv)->builder.GetInsertBlock()->getParent(); BasicBlock* mergeBB = BasicBlock::Create("endif"); BasicBlock* nextBB = NULL; Branches branches; for (size_t i = 1; i < size() - 1; i += 2) { Value* condV = LLVal(cenv.compile(at(i))); BasicBlock* thenBB = BasicBlock::Create((format("then%1%") % ((i+1)/2)).str()); nextBB = BasicBlock::Create((format("else%1%") % ((i+1)/2)).str()); llengine(cenv)->builder.CreateCondBr(condV, thenBB, nextBB); // Emit then block for this condition parent->getBasicBlockList().push_back(thenBB); llengine(cenv)->builder.SetInsertPoint(thenBB); Value* thenV = LLVal(cenv.compile(at(i+1))); llengine(cenv)->builder.CreateBr(mergeBB); branches.push_back(make_pair(thenV, llengine(cenv)->builder.GetInsertBlock())); parent->getBasicBlockList().push_back(nextBB); llengine(cenv)->builder.SetInsertPoint(nextBB); } // Emit final else block llengine(cenv)->builder.SetInsertPoint(nextBB); Value* elseV = LLVal(cenv.compile(at(size() - 1))); llengine(cenv)->builder.CreateBr(mergeBB); branches.push_back(make_pair(elseV, llengine(cenv)->builder.GetInsertBlock())); // Emit merge block (Phi node) parent->getBasicBlockList().push_back(mergeBB); llengine(cenv)->builder.SetInsertPoint(mergeBB); PHINode* pn = llengine(cenv)->builder.CreatePHI(lltype(cenv.type(this)), "ifval"); FOREACH(Branches::iterator, i, branches) pn->addIncoming(i->first, i->second); return pn; } CValue APrimitive::compile(CEnv& cenv) { Value* a = LLVal(cenv.compile(at(1))); Value* b = LLVal(cenv.compile(at(2))); bool isFloat = cenv.type(at(1))->str() == "Float"; const string n = at(0)->to()->str(); // Binary arithmetic operations Instruction::BinaryOps op = (Instruction::BinaryOps)0; if (n == "+") op = Instruction::Add; if (n == "-") op = Instruction::Sub; if (n == "*") op = Instruction::Mul; if (n == "and") op = Instruction::And; if (n == "or") op = Instruction::Or; if (n == "xor") op = Instruction::Xor; if (n == "/") op = isFloat ? Instruction::FDiv : Instruction::SDiv; if (n == "%") op = isFloat ? Instruction::FRem : Instruction::SRem; if (op != 0) { Value* val = llengine(cenv)->builder.CreateBinOp(op, a, b); for (size_t i = 3; i < size(); ++i) val = llengine(cenv)->builder.CreateBinOp(op, val, LLVal(cenv.compile(at(i)))); return val; } // Comparison operations CmpInst::Predicate pred = (CmpInst::Predicate)0; if (n == "=") pred = isFloat ? CmpInst::FCMP_OEQ : CmpInst::ICMP_EQ ; if (n == "!=") pred = isFloat ? CmpInst::FCMP_ONE : CmpInst::ICMP_NE ; if (n == ">") pred = isFloat ? CmpInst::FCMP_OGT : CmpInst::ICMP_SGT; if (n == ">=") pred = isFloat ? CmpInst::FCMP_OGE : CmpInst::ICMP_SGE; if (n == "<") pred = isFloat ? CmpInst::FCMP_OLT : CmpInst::ICMP_SLT; if (n == "<=") pred = isFloat ? CmpInst::FCMP_OLE : CmpInst::ICMP_SLE; if (pred != 0) { if (isFloat) return llengine(cenv)->builder.CreateFCmp(pred, a, b); else return llengine(cenv)->builder.CreateICmp(pred, a, b); } throw Error("unknown primitive", loc); } AType* AConsCall::functionType(CEnv& cenv) { ATuple* protTypes = new ATuple(loc, cenv.type(at(1)), cenv.type(at(2)), 0); AType* cellType = new AType(loc, cenv.penv.sym("Pair"), cenv.type(at(1)), cenv.type(at(2)), 0); return new AType(at(0)->loc, cenv.penv.sym("Fn"), protTypes, cellType, 0); } void AConsCall::lift(CEnv& cenv) { AType* funcType = functionType(cenv); if (funcs.find(functionType(cenv))) return; ACall::lift(cenv); ATuple* protT = new ATuple(loc, cenv.type(at(1)), cenv.type(at(2)), 0); vector types; size_t sz = 0; for (size_t i = 1; i < size(); ++i) { const Type* t = lltype(cenv.type(at(i))); types.push_back(t); sz += t->getPrimitiveSizeInBits(); } sz = (sz % 8 == 0) ? sz / 8 : sz / 8 + 1; llvm::IRBuilder<>& builder = llengine(cenv)->builder; StructType* sT = StructType::get(types, false); Type* pT = PointerType::get(sT, 0); // Write function declaration vector argNames; argNames.push_back("car"); argNames.push_back("cdr"); Function* func = compileFunction(cenv, cenv.gensym("cons"), pT, *protT, argNames); Value* mem = builder.CreateCall(LLVal(cenv.alloc), ConstantInt::get(Type::Int32Ty, sz), "mem"); Value* cell = builder.CreateBitCast(mem, pT, "cell"); Value* s = builder.CreateGEP(cell, ConstantInt::get(Type::Int32Ty, 0), "pair"); Value* carP = builder.CreateStructGEP(s, 0, "car"); Value* cdrP = builder.CreateStructGEP(s, 1, "cdr"); Function::arg_iterator ai = func->arg_begin(); Value& carArg = *ai++; Value& cdrArg = *ai++; builder.CreateStore(&carArg, carP); builder.CreateStore(&cdrArg, cdrP); builder.CreateRet(cell); cenv.optimise(func); funcs.push_back(make_pair(funcType, func)); } CValue AConsCall::compile(CEnv& cenv) { vector params(size() - 1); for (size_t i = 1; i < size(); ++i) params[i-1] = LLVal(cenv.compile(at(i))); return llengine(cenv)->builder.CreateCall(LLFunc(funcs.find(functionType(cenv))), params.begin(), params.end()); } CValue ACarCall::compile(CEnv& cenv) { AST* arg = cenv.tenv.resolve(at(1)); Value* sP = LLVal(cenv.compile(arg)); Value* s = llengine(cenv)->builder.CreateGEP(sP, ConstantInt::get(Type::Int32Ty, 0), "pair"); Value* carP = llengine(cenv)->builder.CreateStructGEP(s, 0, "car"); return llengine(cenv)->builder.CreateLoad(carP); } CValue ACdrCall::compile(CEnv& cenv) { AST* arg = cenv.tenv.resolve(at(1)); Value* sP = LLVal(cenv.compile(arg)); Value* s = llengine(cenv)->builder.CreateGEP(sP, ConstantInt::get(Type::Int32Ty, 0), "pair"); Value* cdrP = llengine(cenv)->builder.CreateStructGEP(s, 1, "cdr"); return llengine(cenv)->builder.CreateLoad(cdrP); } /*************************************************************************** * EVAL/REPL * ***************************************************************************/ const string call(AType* retT, void* fp) { std::stringstream ss; if (lltype(retT) == Type::Int32Ty) ss << ((int32_t (*)())fp)(); else if (lltype(retT) == Type::FloatTy) ss << showpoint << ((float (*)())fp)(); else if (lltype(retT) == Type::Int1Ty) ss << (((bool (*)())fp)() ? "#t" : "#f"); else ss << ((void* (*)())fp)(); return ss.str(); } int eval(CEnv& cenv, const string& name, istream& is) { AST* result = NULL; AType* resultType = NULL; list< pair > exprs; Cursor cursor(name); try { while (true) { SExp exp = readExpression(cursor, is); if (exp.type == SExp::LIST && exp.empty()) break; result = cenv.penv.parse(exp); // Parse input Constraints c; result->constrain(cenv.tenv, c); // Constrain types cenv.tsubst = Subst::compose(cenv.tsubst, TEnv::unify(c)); // Solve type constraints resultType = cenv.type(result); result->lift(cenv); // Lift functions exprs.push_back(make_pair(exp, result)); } const Type* ctype = lltype(resultType); if (!ctype) throw Error("body has non-compilable type", cursor); // Create function for top-level of program Function* f = compileFunction(cenv, "main", ctype, ATuple(cursor)); // Compile all expressions into it Value* val = NULL; for (list< pair >::const_iterator i = exprs.begin(); i != exprs.end(); ++i) val = LLVal(cenv.compile(i->second)); // Finish function llengine(cenv)->builder.CreateRet(val); cenv.optimise(f); cenv.out << call(resultType, llengine(cenv)->engine->getPointerToFunction(f)) << " : " << resultType << endl; } catch (Error& e) { cenv.err << e.what() << endl; return 1; } return 0; } int repl(CEnv& cenv) { while (1) { cenv.out << "() "; cenv.out.flush(); Cursor cursor("(stdin)"); try { SExp exp = readExpression(cursor, std::cin); if (exp.type == SExp::LIST && exp.empty()) break; AST* body = cenv.penv.parse(exp); // Parse input Constraints c; body->constrain(cenv.tenv, c); // Constrain types Subst oldSubst = cenv.tsubst; cenv.tsubst = Subst::compose(cenv.tsubst, TEnv::unify(c)); // Solve type constraints AType* bodyT = cenv.type(body); if (!bodyT) throw Error("call to untyped body", cursor); body->lift(cenv); if (lltype(bodyT)) { // Create anonymous function to insert code into Function* f = compileFunction(cenv, cenv.gensym("_repl"), lltype(bodyT), ATuple(cursor)); try { Value* retVal = LLVal(cenv.compile(body)); llengine(cenv)->builder.CreateRet(retVal); // Finish function cenv.optimise(f); } catch (Error& e) { f->eraseFromParent(); // Error reading body, remove function throw e; } cenv.out << call(bodyT, llengine(cenv)->engine->getPointerToFunction(f)); } else { cenv.out << "; " << cenv.compile(body); } cenv.out << " : " << cenv.type(body) << endl; cenv.tsubst = oldSubst; } catch (Error& e) { cenv.err << e.what() << endl; } } return 0; } CEnv* newCenv(PEnv& penv, TEnv& tenv) { LLVMEngine* engine = new LLVMEngine(); CEnv* cenv = new CEnv(penv, tenv, engine); // Host provided allocation primitive prototypes std::vector argsT(1, Type::Int32Ty); FunctionType* funcT = FunctionType::get(PointerType::get(Type::Int8Ty, 0), argsT, false); cenv->alloc = Function::Create(funcT, Function::ExternalLinkage, "malloc", engine->module); return cenv; }