/* Tuplr: A programming language * Copyright (C) 2008-2009 David Robillard * * Tuplr is free software: you can redistribute it and/or modify it under * the terms of the GNU Affero General Public License as published by the * Free Software Foundation, either version 3 of the License, or (at your * option) any later version. * * Tuplr is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General * Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with Tuplr. If not, see . */ #include #include #include #include "llvm/Analysis/Verifier.h" #include "llvm/Assembly/AsmAnnotationWriter.h" #include "llvm/DerivedTypes.h" #include "llvm/ExecutionEngine/ExecutionEngine.h" #include "llvm/Instructions.h" #include "llvm/Module.h" #include "llvm/ModuleProvider.h" #include "llvm/PassManager.h" #include "llvm/Support/IRBuilder.h" #include "llvm/Target/TargetData.h" #include "llvm/Transforms/Scalar.h" #include "tuplr.hpp" using namespace llvm; using namespace std; using boost::format; inline Value* LLVal(CValue v) { return static_cast(v); } inline const Type* LLType(CType t) { return static_cast(t); } inline Function* LLFunc(CFunction f) { return static_cast(f); } struct CEngine { CEngine(); Module* module; ExecutionEngine* engine; IRBuilder<> builder; }; struct CArg { CArg(int o=0, int a=0) : op(o), arg(a) {} int op; int arg; }; /*************************************************************************** * Abstract Syntax Tree * ***************************************************************************/ CType AType::type() { if (at(0)->str() == "Pair") { vector types; for (size_t i = 1; i < size(); ++i) { assert(dynamic_cast(at(i))); types.push_back(LLType(((AType*)at(i))->type())); } return PointerType::get(StructType::get(types, false), 0); } else { return ctype; } } /*************************************************************************** * Typing * ***************************************************************************/ #define OP_IS_A(o, t) ((o) >= t ## Begin && (o) < t ## End) void ASTPrimitive::constrain(TEnv& tenv) const { FOREACH(const_iterator, p, *this) (*p)->constrain(tenv); if (OP_IS_A(arg->op, Instruction::BinaryOps)) { if (size() <= 2) throw Error((format("`%1%' requires at least 2 arguments") % at(0)->str()).str(), exp.loc); AType* tvar = tenv.type(this); for (size_t i = 1; i < size(); ++i) tenv.constrain(at(i), tvar); } else if (arg->op == Instruction::ICmp) { if (size() != 3) throw Error((format("`%1%' requires exactly 2 arguments") % at(0)->str()).str(), exp.loc); tenv.constrain(at(1), tenv.type(at(2))); tenv.constrain(this, tenv.named("Bool")); } else { throw Error((format("unknown primitive `%1%'") % at(0)->str()).str(), exp.loc); } } /*************************************************************************** * Code Generation * ***************************************************************************/ // Compile-Time Environment CEngine::CEngine() : module(new Module("tuplr")) , engine(ExecutionEngine::create(module)) { } struct CEnvPimpl { CEnvPimpl(CEngine& engine) : module(engine.module), emp(module), opt(&emp) { // Set up the optimizer pipeline: const TargetData* target = engine.engine->getTargetData(); opt.add(new TargetData(*target)); // Register target arch opt.add(createInstructionCombiningPass()); // Simple optimizations opt.add(createReassociatePass()); // Reassociate expressions opt.add(createGVNPass()); // Eliminate Common Subexpressions opt.add(createCFGSimplificationPass()); // Simplify control flow } Module* module; ExistingModuleProvider emp; FunctionPassManager opt; Function* alloc; }; CEnv::CEnv(PEnv& p, TEnv& t, CEngine& eng) : engine(eng), penv(p), tenv(t), symID(0), _pimpl(new CEnvPimpl(eng)) { } CEnv::~CEnv() { delete _pimpl; } CValue CEnv::compile(AST* obj) { CValue* v = vals.ref(obj); return (v) ? *v : vals.def(obj, obj->compile(*this)); } void CEnv::optimise(CFunction f) { verifyFunction(*static_cast(f)); _pimpl->opt.run(*static_cast(f)); } void CEnv::write(std::ostream& os) { AssemblyAnnotationWriter writer; engine.module->print(os, &writer); } #define LITERAL(CT, NAME, COMPILED) \ template<> CValue \ ASTLiteral::compile(CEnv& cenv) { return (COMPILED); } \ template<> void \ ASTLiteral::constrain(TEnv& tenv) const { tenv.constrain(this, tenv.named(NAME)); } /// Literal template instantiations LITERAL(int32_t, "Int", ConstantInt::get(Type::Int32Ty, val, true)) LITERAL(float, "Float", ConstantFP::get(Type::FloatTy, val)) LITERAL(bool, "Bool", ConstantInt::get(Type::Int1Ty, val, false)) static Function* compileFunction(CEnv& cenv, const std::string& name, CType retT, const ASTTuple& prot, const vector argNames=vector()) { Function::LinkageTypes linkage = Function::ExternalLinkage; vector cprot; for (size_t i = 0; i < prot.size(); ++i) { AType* at = cenv.tenv.type(prot.at(i)); if (!at->type() || at->var()) throw Error("function parameter is untyped"); cprot.push_back(LLType(at->type())); } if (!retT) throw Error("function return is untyped"); FunctionType* fT = FunctionType::get(static_cast(retT), cprot, false); Function* f = Function::Create(fT, linkage, name, cenv.engine.module); if (f->getName() != name) { f->eraseFromParent(); throw Error("function redefined"); } // Set argument names in generated code Function::arg_iterator a = f->arg_begin(); if (!argNames.empty()) for (size_t i = 0; i != prot.size(); ++a, ++i) a->setName(argNames.at(i)); else for (size_t i = 0; i != prot.size(); ++a, ++i) a->setName(prot.at(i)->str()); BasicBlock* bb = BasicBlock::Create("entry", f); cenv.engine.builder.SetInsertPoint(bb); return f; } /*************************************************************************** * AST Code Generation * ***************************************************************************/ CValue ASTSymbol::compile(CEnv& cenv) { AST** c = cenv.code.ref(this); if (!c) throw Error((string("undefined symbol `") + cppstr + "'").c_str(), loc); return cenv.compile(*c); } void ASTClosure::lift(CEnv& cenv) { AType* type = cenv.tenv.type(this); if (!type->concrete() || funcs.find(type)) return; cenv.push(); // Write function declaration string name = this->name == "" ? cenv.gensym("_fn") : this->name; Function* f = compileFunction(cenv, name, cenv.tenv.type(at(2))->type(), *prot()); // Bind argument values in CEnv vector args; const_iterator p = prot()->begin(); for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p) cenv.vals.def(dynamic_cast(*p), &*a); // Write function body try { cenv.precompile(this, f); // Define our value first for recursion CValue retVal = cenv.compile(at(2)); cenv.engine.builder.CreateRet(LLVal(retVal)); // Finish function cenv.optimise(LLFunc(f)); funcs.push_back(make_pair(type, f)); } catch (Error& e) { f->eraseFromParent(); // Error reading body, remove function throw e; } cenv.pop(); } CValue ASTClosure::compile(CEnv& cenv) { return funcs.find(cenv.tenv.type(this)); } void ASTCall::lift(CEnv& cenv) { ASTClosure* c = dynamic_cast(at(0)); if (!c) { AST** val = cenv.code.ref(at(0)); c = (val) ? dynamic_cast(*val) : c; } // Lift arguments for (size_t i = 1; i < size(); ++i) at(i)->lift(cenv); if (!c) return; // Extend environment with bound and typed parameters cenv.push(); if (c->prot()->size() < size() - 1) throw Error((format("too many arguments to function `%1%'") % at(0)->str()).str(), exp.loc); if (c->prot()->size() > size() - 1) throw Error((format("too few arguments to function `%1%'") % at(0)->str()).str(), exp.loc); for (size_t i = 1; i < size(); ++i) cenv.code.def(c->prot()->at(i-1), at(i)); c->lift(cenv); // Lift called closure cenv.pop(); // Restore environment } CValue ASTCall::compile(CEnv& cenv) { ASTClosure* c = dynamic_cast(at(0)); if (!c) { AST** val = cenv.code.ref(at(0)); c = (val) ? dynamic_cast(*val) : c; } assert(c); Function* f = dynamic_cast(LLVal(cenv.compile(c))); if (!f) throw Error("callee failed to compile", exp.loc); vector params(size() - 1); for (size_t i = 1; i < size(); ++i) params[i-1] = LLVal(cenv.compile(at(i))); return cenv.engine.builder.CreateCall(f, params.begin(), params.end(), "calltmp"); } void ASTDefinition::lift(CEnv& cenv) { if (cenv.code.ref((ASTSymbol*)at(1))) throw Error(string("`") + at(1)->str() + "' redefined", exp.loc); cenv.code.def((ASTSymbol*)at(1), at(2)); // Define first for recursion at(2)->lift(cenv); } CValue ASTDefinition::compile(CEnv& cenv) { return cenv.compile(at(2)); } CValue ASTIf::compile(CEnv& cenv) { typedef vector< pair > Branches; Function* parent = cenv.engine.builder.GetInsertBlock()->getParent(); BasicBlock* mergeBB = BasicBlock::Create("endif"); BasicBlock* nextBB = NULL; Branches branches; for (size_t i = 1; i < size() - 1; i += 2) { Value* condV = LLVal(cenv.compile(at(i))); BasicBlock* thenBB = BasicBlock::Create((format("then%1%") % ((i+1)/2)).str()); nextBB = BasicBlock::Create((format("else%1%") % ((i+1)/2)).str()); cenv.engine.builder.CreateCondBr(condV, thenBB, nextBB); // Emit then block for this condition parent->getBasicBlockList().push_back(thenBB); cenv.engine.builder.SetInsertPoint(thenBB); Value* thenV = LLVal(cenv.compile(at(i+1))); cenv.engine.builder.CreateBr(mergeBB); branches.push_back(make_pair(thenV, cenv.engine.builder.GetInsertBlock())); parent->getBasicBlockList().push_back(nextBB); cenv.engine.builder.SetInsertPoint(nextBB); } // Emit final else block cenv.engine.builder.SetInsertPoint(nextBB); Value* elseV = LLVal(cenv.compile(at(size() - 1))); cenv.engine.builder.CreateBr(mergeBB); branches.push_back(make_pair(elseV, cenv.engine.builder.GetInsertBlock())); // Emit merge block (Phi node) parent->getBasicBlockList().push_back(mergeBB); cenv.engine.builder.SetInsertPoint(mergeBB); PHINode* pn = cenv.engine.builder.CreatePHI(LLType(cenv.tenv.type(this)->type()), "ifval"); FOREACH(Branches::iterator, i, branches) pn->addIncoming(i->first, i->second); return pn; } CValue ASTPrimitive::compile(CEnv& cenv) { Value* a = LLVal(cenv.compile(at(1))); Value* b = LLVal(cenv.compile(at(2))); if (OP_IS_A(arg->op, Instruction::BinaryOps)) { const Instruction::BinaryOps bo = (Instruction::BinaryOps)arg->op; if (size() == 2) return cenv.compile(at(1)); Value* val = cenv.engine.builder.CreateBinOp(bo, a, b); for (size_t i = 3; i < size(); ++i) val = cenv.engine.builder.CreateBinOp(bo, val, LLVal(cenv.compile(at(i)))); return val; } else if (arg->op == Instruction::ICmp) { bool isInt = cenv.tenv.type(at(1))->str() == "Int"; if (isInt) { return cenv.engine.builder.CreateICmp((CmpInst::Predicate)arg->arg, a, b); } else { // Translate to floating point operation switch (arg->arg) { case CmpInst::ICMP_EQ: arg->arg = CmpInst::FCMP_OEQ; break; case CmpInst::ICMP_NE: arg->arg = CmpInst::FCMP_ONE; break; case CmpInst::ICMP_SGT: arg->arg = CmpInst::FCMP_OGT; break; case CmpInst::ICMP_SGE: arg->arg = CmpInst::FCMP_OGE; break; case CmpInst::ICMP_SLT: arg->arg = CmpInst::FCMP_OLT; break; case CmpInst::ICMP_SLE: arg->arg = CmpInst::FCMP_OLE; break; default: throw Error("Unknown primitive", exp.loc); } return cenv.engine.builder.CreateFCmp((CmpInst::Predicate)arg->arg, a, b); } } throw Error("Unknown primitive", exp.loc); } AType* ASTConsCall::functionType(CEnv& cenv) { ASTTuple* protTypes = new ASTTuple(cenv.tenv.type(at(1)), cenv.tenv.type(at(2)), NULL); AType* cellType = new AType(ASTTuple(cenv.penv.sym("Pair"), cenv.tenv.type(at(1)), cenv.tenv.type(at(2)), NULL)); return new AType(ASTTuple(cenv.penv.sym("Fn"), protTypes, cellType, NULL)); } void ASTConsCall::lift(CEnv& cenv) { AType* funcType = functionType(cenv); if (funcs.find(functionType(cenv))) return; ASTCall::lift(cenv); ASTTuple* prot = new ASTTuple(at(1), at(2), NULL); vector types; size_t sz = 0; for (size_t i = 1; i < size(); ++i) { const Type* t = LLType(cenv.tenv.type(at(i))->type()); types.push_back(t); sz += t->getPrimitiveSizeInBits(); } sz = (sz % 8 == 0) ? sz / 8 : sz / 8 + 1; llvm::IRBuilder<>& builder = cenv.engine.builder; StructType* sT = StructType::get(types, false); Type* pT = PointerType::get(sT, 0); // Write function declaration vector argNames; argNames.push_back("car"); argNames.push_back("cdr"); Function* func = compileFunction(cenv, cenv.gensym("cons"), pT, *prot, argNames); Value* mem = builder.CreateCall(LLVal(cenv.alloc), ConstantInt::get(Type::Int32Ty, sz), "mem"); Value* cell = builder.CreateBitCast(mem, pT, "cell"); Value* s = builder.CreateGEP(cell, ConstantInt::get(Type::Int32Ty, 0), "pair"); Value* carP = builder.CreateStructGEP(s, 0, "car"); Value* cdrP = builder.CreateStructGEP(s, 1, "cdr"); Function::arg_iterator ai = func->arg_begin(); Value& carArg = *ai++; Value& cdrArg = *ai++; builder.CreateStore(&carArg, carP); builder.CreateStore(&cdrArg, cdrP); builder.CreateRet(cell); cenv.optimise(func); funcs.push_back(make_pair(funcType, func)); } CValue ASTConsCall::compile(CEnv& cenv) { vector params(size() - 1); for (size_t i = 1; i < size(); ++i) params[i-1] = LLVal(cenv.compile(at(i))); return cenv.engine.builder.CreateCall(LLFunc(funcs.find(functionType(cenv))), params.begin(), params.end()); } CValue ASTCarCall::compile(CEnv& cenv) { AST** arg = cenv.code.ref(at(1)); Value* sP = LLVal(arg ? (*arg)->compile(cenv) : at(1)->compile(cenv)); Value* s = cenv.engine.builder.CreateGEP(sP, ConstantInt::get(Type::Int32Ty, 0), "pair"); Value* carP = cenv.engine.builder.CreateStructGEP(s, 0, "car"); return cenv.engine.builder.CreateLoad(carP); } CValue ASTCdrCall::compile(CEnv& cenv) { AST** arg = cenv.code.ref(at(1)); Value* sP = LLVal(arg ? (*arg)->compile(cenv) : at(1)->compile(cenv)); Value* s = cenv.engine.builder.CreateGEP(sP, ConstantInt::get(Type::Int32Ty, 0), "pair"); Value* cdrP = cenv.engine.builder.CreateStructGEP(s, 1, "cdr"); return cenv.engine.builder.CreateLoad(cdrP); } /*************************************************************************** * EVAL/REPL/MAIN * ***************************************************************************/ const string call(AType* retT, void* fp) { std::stringstream ss; if (retT->type() == Type::Int32Ty) ss << ((int32_t (*)())fp)(); else if (retT->type() == Type::FloatTy) ss << ((float (*)())fp)(); else if (retT->type() == Type::Int1Ty) ss << ((bool (*)())fp)(); else ss << ((void* (*)())fp)(); return ss.str(); } int eval(CEnv& cenv, const string& name, istream& is) { AST* result = NULL; AType* resultType = NULL; list< pair > exprs; Cursor cursor(name); try { while (true) { SExp exp = readExpression(cursor, is); if (exp.type == SExp::LIST && exp.list.empty()) break; result = parseExpression(cenv.penv, exp); // Parse input result->constrain(cenv.tenv); // Constrain types cenv.tenv.solve(); // Solve and apply type constraints resultType = cenv.tenv.type(result); result->lift(cenv); // Lift functions exprs.push_back(make_pair(exp, result)); } if (!resultType || resultType->var()) throw Error("body is undefined/untyped", cursor); CType ctype = resultType->type(); if (!ctype) throw Error("body has no system type", cursor); // Create function for top-level of program Function* f = compileFunction(cenv, "main", ctype, ASTTuple()); // Compile all expressions into it Value* val = NULL; for (list< pair >::const_iterator i = exprs.begin(); i != exprs.end(); ++i) val = LLVal(cenv.compile(i->second)); // Finish function cenv.engine.builder.CreateRet(val); cenv.optimise(f); out << call(resultType, cenv.engine.engine->getPointerToFunction(f)) << " : " << resultType->str() << endl; } catch (Error& e) { err << e.what() << endl; return 1; } return 0; } int repl(CEnv& cenv) { while (1) { out << "() "; out.flush(); Cursor cursor("(stdin)"); try { SExp exp = readExpression(cursor, std::cin); if (exp.type == SExp::LIST && exp.list.empty()) break; AST* body = parseExpression(cenv.penv, exp); // Parse input body->constrain(cenv.tenv); // Constrain types cenv.tenv.solve(); // Solve and apply type constraints AType* bodyT = cenv.tenv.type(body); if (!bodyT) throw Error("call to untyped body", cursor); body->lift(cenv); if (bodyT->type()) { // Create anonymous function to insert code into Function* f = compileFunction(cenv, cenv.gensym("_repl"), bodyT->type(), ASTTuple()); try { Value* retVal = LLVal(cenv.compile(body)); cenv.engine.builder.CreateRet(retVal); // Finish function cenv.optimise(f); } catch (Error& e) { f->eraseFromParent(); // Error reading body, remove function throw e; } out << call(bodyT, cenv.engine.engine->getPointerToFunction(f)); } else { out << "; " << cenv.compile(body); } out << " : " << cenv.tenv.type(body)->str() << endl; } catch (Error& e) { err << e.what() << endl; } } return 0; } void initLang(PEnv& penv, TEnv& tenv) { penv.reg(true, "fn", PEnv::Handler(parseFn)); penv.reg(true, "if", PEnv::Handler(parseCall)); penv.reg(true, "def", PEnv::Handler(parseCall)); penv.reg(true, "cons", PEnv::Handler(parseCall)); penv.reg(true, "car", PEnv::Handler(parseCall)); penv.reg(true, "cdr", PEnv::Handler(parseCall)); bool trueVal = true; bool falseVal = false; penv.reg(false, "true", PEnv::Handler(parseLiteral, (CArg*)&trueVal)); penv.reg(false, "false", PEnv::Handler(parseLiteral, (CArg*)&falseVal)); map* prims = new map(); prims->insert(make_pair("+", CArg(Instruction::Add))); prims->insert(make_pair("-", CArg(Instruction::Sub))); prims->insert(make_pair("*", CArg(Instruction::Mul))); prims->insert(make_pair("/", CArg(Instruction::FDiv))); prims->insert(make_pair("%", CArg(Instruction::FRem))); prims->insert(make_pair("&", CArg(Instruction::And))); prims->insert(make_pair("|", CArg(Instruction::Or))); prims->insert(make_pair("^", CArg(Instruction::Xor))); prims->insert(make_pair("=", CArg(Instruction::ICmp, CmpInst::ICMP_EQ))); prims->insert(make_pair("!=", CArg(Instruction::ICmp, CmpInst::ICMP_NE))); prims->insert(make_pair(">", CArg(Instruction::ICmp, CmpInst::ICMP_SGT))); prims->insert(make_pair(">=", CArg(Instruction::ICmp, CmpInst::ICMP_SGE))); prims->insert(make_pair("<", CArg(Instruction::ICmp, CmpInst::ICMP_SLT))); prims->insert(make_pair("<=", CArg(Instruction::ICmp, CmpInst::ICMP_SLE))); for (map::iterator p = prims->begin(); p != prims->end(); ++p) penv.reg(true, p->first, PEnv::Handler(parseCall, &p->second)); tenv.def(penv.sym("Bool"), new AType(penv.sym("Bool"), Type::Int1Ty)); tenv.def(penv.sym("Int"), new AType(penv.sym("Int"), Type::Int32Ty)); tenv.def(penv.sym("Float"), new AType(penv.sym("Float"), Type::FloatTy)); } CEnv* newCenv(PEnv& penv, TEnv& tenv) { CEngine* engine = new CEngine(); CEnv* cenv = new CEnv(penv, tenv, *engine); // Host provided allocation primitive prototypes std::vector argsT(1, Type::Int32Ty); FunctionType* funcT = FunctionType::get(PointerType::get(Type::Int8Ty, 0), argsT, false); cenv->alloc = Function::Create(funcT, Function::ExternalLinkage, "malloc", engine->module); return cenv; }