diff options
Diffstat (limited to 'tuplr_llvm.cpp')
-rw-r--r-- | tuplr_llvm.cpp | 635 |
1 files changed, 635 insertions, 0 deletions
diff --git a/tuplr_llvm.cpp b/tuplr_llvm.cpp new file mode 100644 index 0000000..a807f75 --- /dev/null +++ b/tuplr_llvm.cpp @@ -0,0 +1,635 @@ +/* Tuplr: A programming language + * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net> + * + * Tuplr is free software: you can redistribute it and/or modify it under + * the terms of the GNU Affero General Public License as published by the + * Free Software Foundation, either version 3 of the License, or (at your + * option) any later version. + * + * Tuplr is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY + * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General + * Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with Tuplr. If not, see <http://www.gnu.org/licenses/>. + */ + +#include <sstream> +#include <fstream> +#include "tuplr.hpp" +#include "tuplr_llvm.hpp" + +using namespace llvm; +using namespace std; +using boost::format; + + +/*************************************************************************** + * Abstract Syntax Tree * + ***************************************************************************/ + +const CType* +AType::type() +{ + if (at(0)->str() == "Pair") { + vector<const CType*> types; + for (size_t i = 1; i < size(); ++i) { + assert(dynamic_cast<AType*>(at(i))); + types.push_back(((AType*)at(i))->type()); + } + return PointerType::get(StructType::get(types, false), 0); + } else { + return ctype; + } +} + + +/*************************************************************************** + * Typing * + ***************************************************************************/ + +#define OP_IS_A(o, t) ((o) >= t ## Begin && (o) < t ## End) + +void +ASTPrimitive::constrain(TEnv& tenv) const +{ + FOREACH(const_iterator, p, *this) + (*p)->constrain(tenv); + if (OP_IS_A(arg.op, Instruction::BinaryOps)) { + if (size() <= 2) throw Error((format("`%1%' requires at least 2 arguments") + % at(0)->str()).str(), exp.loc); + AType* tvar = tenv.type(this); + for (size_t i = 1; i < size(); ++i) + tenv.constrain(at(i), tvar); + } else if (arg.op == Instruction::ICmp) { + if (size() != 3) throw Error((format("`%1%' requires exactly 2 arguments") + % at(0)->str()).str(), exp.loc); + tenv.constrain(at(1), tenv.type(at(2))); + tenv.constrain(this, tenv.named("Bool")); + } else { + throw Error((format("unknown primitive `%1%'") % at(0)->str()).str(), exp.loc); + } +} + + +/*************************************************************************** + * Code Generation * + ***************************************************************************/ + +// Compile-Time Environment + +CEngine::CEngine() + : module(new Module("tuplr")) + , engine(ExecutionEngine::create(module)) +{ +} + +struct CEnvPimpl { + CEnvPimpl(CEngine& engine) + : module(engine.module), emp(module), opt(&emp) + { + // Set up the optimizer pipeline: + const TargetData* target = engine.engine->getTargetData(); + opt.add(new TargetData(*target)); // Register target arch + opt.add(createInstructionCombiningPass()); // Simple optimizations + opt.add(createReassociatePass()); // Reassociate expressions + opt.add(createGVNPass()); // Eliminate Common Subexpressions + opt.add(createCFGSimplificationPass()); // Simplify control flow + } + + Module* module; + ExistingModuleProvider emp; + FunctionPassManager opt; + Function* alloc; +}; + +CEnv::CEnv(PEnv& p, CEngine& eng) + : engine(eng), penv(p), tenv(p), symID(0), _pimpl(new CEnvPimpl(eng)) +{ +} + +CEnv::~CEnv() +{ + delete _pimpl; +} + +CValue* +CEnv::compile(AST* obj) +{ + CValue** v = vals.ref(obj); + return (v) ? *v : vals.def(obj, obj->compile(*this)); +} + +void +CEnv::optimise(Function& f) +{ + verifyFunction(f); + _pimpl->opt.run(f); +} + +#define LITERAL(CT, NAME, COMPILED) \ +template<> CValue* \ +ASTLiteral<CT>::compile(CEnv& cenv) { return (COMPILED); } \ +template<> void \ +ASTLiteral<CT>::constrain(TEnv& tenv) const { tenv.constrain(this, tenv.named(NAME)); } + +/// Literal template instantiations +LITERAL(int32_t, "Int", ConstantInt::get(Type::Int32Ty, val, true)) +LITERAL(float, "Float", ConstantFP::get(Type::FloatTy, val)) +LITERAL(bool, "Bool", ConstantInt::get(Type::Int1Ty, val, false)) + +static Function* +compileFunction(CEnv& cenv, const std::string& name, const CType* retT, const ASTTuple& prot, + const vector<string> argNames=vector<string>()) +{ + Function::LinkageTypes linkage = Function::ExternalLinkage; + + vector<const CType*> cprot; + for (size_t i = 0; i < prot.size(); ++i) { + AType* at = cenv.tenv.type(prot.at(i)); + if (!at->type() || at->var()) throw Error("function parameter is untyped"); + cprot.push_back(at->type()); + } + + if (!retT) throw Error("function return is untyped"); + FunctionType* fT = FunctionType::get(retT, cprot, false); + Function* f = Function::Create(fT, linkage, name, cenv.engine.module); + + if (f->getName() != name) { + f->eraseFromParent(); + throw Error("function redefined"); + } + + // Set argument names in generated code + Function::arg_iterator a = f->arg_begin(); + if (!argNames.empty()) + for (size_t i = 0; i != prot.size(); ++a, ++i) + a->setName(argNames.at(i)); + else + for (size_t i = 0; i != prot.size(); ++a, ++i) + a->setName(prot.at(i)->str()); + + BasicBlock* bb = BasicBlock::Create("entry", f); + cenv.engine.builder.SetInsertPoint(bb); + + return f; +} + +CValue* +ASTSymbol::compile(CEnv& cenv) +{ + AST** c = cenv.code.ref(this); + if (!c) throw Error((string("undefined symbol `") + cppstr + "'").c_str(), loc); + return cenv.compile(*c); +} + +void +ASTClosure::lift(CEnv& cenv) +{ + AType* type = cenv.tenv.type(this); + if (!type->concrete()) { + err << "closure is untyped, not lifting" << endl; + return; + } + + if (funcs.find(type)) + return; + + cenv.push(); + + // Write function declaration + string name = this->name == "" ? cenv.gensym("_fn") : this->name; + Function* f = compileFunction(cenv, name, cenv.tenv.type(at(2))->type(), *prot()); + + // Bind argument values in CEnv + vector<CValue*> args; + const_iterator p = prot()->begin(); + for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p) + cenv.vals.def(dynamic_cast<ASTSymbol*>(*p), &*a); + + // Write function body + try { + cenv.precompile(this, f); // Define our value first for recursion + CValue* retVal = cenv.compile(at(2)); + cenv.engine.builder.CreateRet(retVal); // Finish function + cenv.optimise(*f); + funcs.insert(type, f); + } catch (Error& e) { + f->eraseFromParent(); // Error reading body, remove function + throw e; + } + + cenv.pop(); +} + +CValue* +ASTClosure::compile(CEnv& cenv) +{ + return funcs.find(cenv.tenv.type(this)); +} + +void +ASTCall::lift(CEnv& cenv) +{ + ASTClosure* c = dynamic_cast<ASTClosure*>(at(0)); + if (!c) { + AST** val = cenv.code.ref(at(0)); + c = (val) ? dynamic_cast<ASTClosure*>(*val) : c; + } + + // Lift arguments + for (size_t i = 1; i < size(); ++i) + at(i)->lift(cenv); + + if (!c) return; + + // Extend environment with bound and typed parameters + cenv.push(); + if (c->prot()->size() < size() - 1) + throw Error((format("too many arguments to function `%1%'") % at(0)->str()).str(), exp.loc); + if (c->prot()->size() > size() - 1) + throw Error((format("too few arguments to function `%1%'") % at(0)->str()).str(), exp.loc); + + for (size_t i = 1; i < size(); ++i) + cenv.code.def(c->prot()->at(i-1), at(i)); + + c->lift(cenv); // Lift called closure + cenv.pop(); // Restore environment +} + +CValue* +ASTCall::compile(CEnv& cenv) +{ + ASTClosure* c = dynamic_cast<ASTClosure*>(at(0)); + if (!c) { + AST** val = cenv.code.ref(at(0)); + c = (val) ? dynamic_cast<ASTClosure*>(*val) : c; + } + + assert(c); + Function* f = dynamic_cast<Function*>(cenv.compile(c)); + if (!f) throw Error("callee failed to compile", exp.loc); + + vector<CValue*> params(size() - 1); + for (size_t i = 1; i < size(); ++i) + params[i-1] = cenv.compile(at(i)); + + return cenv.engine.builder.CreateCall(f, params.begin(), params.end(), "calltmp"); +} + +void +ASTDefinition::lift(CEnv& cenv) +{ + if (cenv.code.ref((ASTSymbol*)at(1))) + throw Error(string("`") + at(1)->str() + "' redefined", exp.loc); + cenv.code.def((ASTSymbol*)at(1), at(2)); // Define first for recursion + at(2)->lift(cenv); +} + +CValue* +ASTDefinition::compile(CEnv& cenv) +{ + return cenv.compile(at(2)); +} + +CValue* +ASTIf::compile(CEnv& cenv) +{ + typedef vector< pair<CValue*, BasicBlock*> > Branches; + Function* parent = cenv.engine.builder.GetInsertBlock()->getParent(); + BasicBlock* mergeBB = BasicBlock::Create("endif"); + BasicBlock* nextBB = NULL; + Branches branches; + ostringstream ss; + for (size_t i = 1; i < size() - 1; i += 2) { + CValue* condV = cenv.compile(at(i)); + + ss.str(""); ss << "then" << ((i + 1) / 2); + BasicBlock* thenBB = BasicBlock::Create(ss.str()); + + ss.str(""); ss << "else" << ((i + 1) / 2); + nextBB = BasicBlock::Create(ss.str()); + + cenv.engine.builder.CreateCondBr(condV, thenBB, nextBB); + + // Emit then block for this condition + parent->getBasicBlockList().push_back(thenBB); + cenv.engine.builder.SetInsertPoint(thenBB); + CValue* thenV = cenv.compile(at(i + 1)); + cenv.engine.builder.CreateBr(mergeBB); + branches.push_back(make_pair(thenV, cenv.engine.builder.GetInsertBlock())); + + parent->getBasicBlockList().push_back(nextBB); + cenv.engine.builder.SetInsertPoint(nextBB); + } + + // Emit else block + cenv.engine.builder.SetInsertPoint(nextBB); + CValue* elseV = cenv.compile(at(size() - 1)); + cenv.engine.builder.CreateBr(mergeBB); + branches.push_back(make_pair(elseV, cenv.engine.builder.GetInsertBlock())); + + // Emit merge block (Phi node) + parent->getBasicBlockList().push_back(mergeBB); + cenv.engine.builder.SetInsertPoint(mergeBB); + PHINode* pn = cenv.engine.builder.CreatePHI(cenv.tenv.type(this)->type(), "ifval"); + + for (Branches::iterator i = branches.begin(); i != branches.end(); ++i) + pn->addIncoming(i->first, i->second); + + return pn; +} + +CValue* +ASTPrimitive::compile(CEnv& cenv) +{ + CValue* a = cenv.compile(at(1)); + CValue* b = cenv.compile(at(2)); + + if (OP_IS_A(arg.op, Instruction::BinaryOps)) { + const Instruction::BinaryOps bo = (Instruction::BinaryOps)arg.op; + if (size() == 2) + return cenv.compile(at(1)); + CValue* val = cenv.engine.builder.CreateBinOp(bo, a, b); + for (size_t i = 3; i < size(); ++i) + val = cenv.engine.builder.CreateBinOp(bo, val, cenv.compile(at(i))); + return val; + } else if (arg.op == Instruction::ICmp) { + bool isInt = cenv.tenv.type(at(1))->str() == "Int"; + if (isInt) { + return cenv.engine.builder.CreateICmp((CmpInst::Predicate)arg.arg, a, b); + } else { + // Translate to floating point operation + switch (arg.arg) { + case CmpInst::ICMP_EQ: arg.arg = CmpInst::FCMP_OEQ; break; + case CmpInst::ICMP_NE: arg.arg = CmpInst::FCMP_ONE; break; + case CmpInst::ICMP_SGT: arg.arg = CmpInst::FCMP_OGT; break; + case CmpInst::ICMP_SGE: arg.arg = CmpInst::FCMP_OGE; break; + case CmpInst::ICMP_SLT: arg.arg = CmpInst::FCMP_OLT; break; + case CmpInst::ICMP_SLE: arg.arg = CmpInst::FCMP_OLE; break; + default: throw Error("Unknown primitive", exp.loc); + } + return cenv.engine.builder.CreateFCmp((CmpInst::Predicate)arg.arg, a, b); + } + } + throw Error("Unknown primitive", exp.loc); +} + +AType* +ASTConsCall::functionType(CEnv& cenv) +{ + ASTTuple* protTypes = new ASTTuple(cenv.tenv.type(at(1)), cenv.tenv.type(at(2)), NULL); + AType* cellType = new AType(ASTTuple(cenv.penv.sym("Pair"), + cenv.tenv.type(at(1)), cenv.tenv.type(at(2)), NULL)); + return new AType(ASTTuple(cenv.penv.sym("Fn"), protTypes, cellType, NULL)); +} + +void +ASTConsCall::lift(CEnv& cenv) +{ + AType* funcType = functionType(cenv); + if (funcs.find(functionType(cenv))) + return; + + ASTCall::lift(cenv); + + ASTTuple* prot = new ASTTuple(at(1), at(2), NULL); + + vector<const CType*> types; + size_t sz = 0; + for (size_t i = 1; i < size(); ++i) { + const CType* t = cenv.tenv.type(at(i))->type(); + types.push_back(t); + sz += t->getPrimitiveSizeInBits(); + } + sz = (sz % 8 == 0) ? sz / 8 : sz / 8 + 1; + + StructType* sT = StructType::get(types, false); + CType* pT = PointerType::get(sT, 0); + + // Write function declaration + vector<string> argNames; + argNames.push_back("car"); + argNames.push_back("cdr"); + Function* func = compileFunction(cenv, cenv.gensym("cons"), pT, *prot, argNames); + + CValue* mem = cenv.engine.builder.CreateCall(cenv.alloc, ConstantInt::get(Type::Int32Ty, sz), "mem"); + CValue* cell = cenv.engine.builder.CreateBitCast(mem, pT, "cell"); + CValue* s = cenv.engine.builder.CreateGEP(cell, ConstantInt::get(Type::Int32Ty, 0), "pair"); + CValue* carP = cenv.engine.builder.CreateStructGEP(s, 0, "car"); + CValue* cdrP = cenv.engine.builder.CreateStructGEP(s, 1, "cdr"); + Function::arg_iterator ai = func->arg_begin(); + Value& carArg = *ai++; + Value& cdrArg = *ai++; + cenv.engine.builder.CreateStore(&carArg, carP); + cenv.engine.builder.CreateStore(&cdrArg, cdrP); + cenv.engine.builder.CreateRet(cell); + cenv.optimise(*func); + + funcs.insert(funcType, func); +} + +CValue* +ASTConsCall::compile(CEnv& cenv) +{ + vector<CValue*> params(size() - 1); + for (size_t i = 1; i < size(); ++i) + params[i-1] = cenv.compile(at(i)); + + return cenv.engine.builder.CreateCall(funcs.find(functionType(cenv)), params.begin(), params.end()); +} + +CValue* +ASTCarCall::compile(CEnv& cenv) +{ + AST** arg = cenv.code.ref(at(1)); + CValue* sP = arg ? (*arg)->compile(cenv) : at(1)->compile(cenv); + CValue* s = cenv.engine.builder.CreateGEP(sP, ConstantInt::get(Type::Int32Ty, 0), "pair"); + CValue* carP = cenv.engine.builder.CreateStructGEP(s, 0, "car"); + return cenv.engine.builder.CreateLoad(carP); +} + +CValue* +ASTCdrCall::compile(CEnv& cenv) +{ + AST** arg = cenv.code.ref(at(1)); + CValue* sP = arg ? (*arg)->compile(cenv) : at(1)->compile(cenv); + CValue* s = cenv.engine.builder.CreateGEP(sP, ConstantInt::get(Type::Int32Ty, 0), "pair"); + CValue* cdrP = cenv.engine.builder.CreateStructGEP(s, 1, "cdr"); + return cenv.engine.builder.CreateLoad(cdrP); +} + + +/*************************************************************************** + * EVAL/REPL/MAIN * + ***************************************************************************/ + +std::string +call(AType* retT, void* fp) +{ + std::stringstream ss; + if (retT->type() == Type::Int32Ty) + ss << ((int32_t (*)())fp)(); + else if (retT->type() == Type::FloatTy) + ss << ((float (*)())fp)(); + else if (retT->type() == Type::Int1Ty) + ss << ((bool (*)())fp)(); + else + ss << ((void* (*)())fp)(); + return ss.str(); +} + +int +eval(CEnv& cenv, const string& name, istream& is) +{ + AST* result = NULL; + AType* resultType = NULL; + list< pair<SExp, AST*> > exprs; + Cursor cursor(name); + try { + while (true) { + SExp exp = readExpression(cursor, is); + if (exp.type == SExp::LIST && exp.list.empty()) + break; + + result = parseExpression(cenv.penv, exp); // Parse input + result->constrain(cenv.tenv); // Constrain types + cenv.tenv.solve(); // Solve and apply type constraints + resultType = cenv.tenv.type(result); + result->lift(cenv); // Lift functions + exprs.push_back(make_pair(exp, result)); + } + + if (!resultType || resultType->var()) throw Error("body is undefined/untyped", cursor); + + const CType* ctype = resultType->type(); + if (!ctype) throw Error("body has no system type", cursor); + + // Create function for top-level of program + Function* f = compileFunction(cenv, cenv.gensym("input"), ctype, ASTTuple()); + + // Compile all expressions into it + CValue* val = NULL; + for (list< pair<SExp, AST*> >::const_iterator i = exprs.begin(); i != exprs.end(); ++i) + val = cenv.compile(i->second); + + // Finish function + cenv.engine.builder.CreateRet(val); + cenv.optimise(*f); + + string resultStr = call(resultType, cenv.engine.engine->getPointerToFunction(f)); + out << resultStr << " : " << resultType->str() << endl; + + } catch (Error& e) { + err << e.what() << endl; + return 1; + } + + return 0; +} + +int +repl(CEnv& cenv) +{ + while (1) { + out << "() "; + out.flush(); + Cursor cursor("(stdin)"); + SExp exp = readExpression(cursor, std::cin); + if (exp.type == SExp::LIST && exp.list.empty()) + break; + + try { + AST* body = parseExpression(cenv.penv, exp); // Parse input + body->constrain(cenv.tenv); // Constrain types + cenv.tenv.solve(); // Solve and apply type constraints + + AType* bodyT = cenv.tenv.type(body); + if (!bodyT) throw Error("call to untyped body", cursor); + if (!bodyT->concrete()) throw Error("call to variable typed body", cursor); + + body->lift(cenv); + + if (bodyT->type()) { + // Create anonymous function to insert code into + Function* f = compileFunction(cenv, cenv.gensym("_repl"), bodyT->type(), ASTTuple()); + try { + CValue* retVal = cenv.compile(body); + cenv.engine.builder.CreateRet(retVal); // Finish function + cenv.optimise(*f); + } catch (Error& e) { + f->eraseFromParent(); // Error reading body, remove function + throw e; + } + out << call(bodyT, cenv.engine.engine->getPointerToFunction(f)); + } else { + CValue* val = cenv.compile(body); + out << "; " << val; + } + out << " : " << cenv.tenv.type(body)->str() << endl; + + } catch (Error& e) { + err << e.what() << endl; + } + } + + return 0; +} + +int +main(int argc, char** argv) +{ +#define PRIM(O, A) PEnv::Parser(parseAST<ASTPrimitive>, CArg(Instruction:: O, A)) + PEnv penv; + penv.reg("fn", PEnv::Parser(parseFn)); + penv.reg("if", PEnv::Parser(parseAST<ASTIf>)); + penv.reg("def", PEnv::Parser(parseAST<ASTDefinition>)); + penv.reg("cons", PEnv::Parser(parseAST<ASTConsCall>)); + penv.reg("car", PEnv::Parser(parseAST<ASTCarCall>)); + penv.reg("cdr", PEnv::Parser(parseAST<ASTCdrCall>)); + penv.reg("+", PRIM(Add, 0)); + penv.reg("-", PRIM(Sub, 0)); + penv.reg("*", PRIM(Mul, 0)); + penv.reg("/", PRIM(FDiv, 0)); + penv.reg("%", PRIM(FRem, 0)); + penv.reg("&", PRIM(And, 0)); + penv.reg("|", PRIM(Or, 0)); + penv.reg("^", PRIM(Xor, 0)); + penv.reg("=", PRIM(ICmp, CmpInst::ICMP_EQ)); + penv.reg("!=", PRIM(ICmp, CmpInst::ICMP_NE)); + penv.reg(">", PRIM(ICmp, CmpInst::ICMP_SGT)); + penv.reg(">=", PRIM(ICmp, CmpInst::ICMP_SGE)); + penv.reg("<", PRIM(ICmp, CmpInst::ICMP_SLT)); + penv.reg("<=", PRIM(ICmp, CmpInst::ICMP_SLE)); + + CEngine engine; + CEnv cenv(penv, engine); + + cenv.tenv.def(penv.sym("Bool"), new AType(penv.sym("Bool"), Type::Int1Ty)); + cenv.tenv.def(penv.sym("Int"), new AType(penv.sym("Int"), Type::Int32Ty)); + cenv.tenv.def(penv.sym("Float"), new AType(penv.sym("Float"), Type::FloatTy)); + + // Host provided allocation primitive prototypes + std::vector<const CType*> argsT(1, Type::Int32Ty); + FunctionType* funcT = FunctionType::get(PointerType::get(Type::Int8Ty, 0), argsT, false); + cenv.alloc = Function::Create(funcT, Function::ExternalLinkage, "malloc", engine.module); + + int ret; + if (argc > 2 && !strncmp(argv[1], "-e", 3)) { + std::istringstream is(argv[2]); + ret = eval(cenv, "(command line)", is); + } else if (argc > 2 && !strncmp(argv[1], "-f", 3)) { + std::ifstream is(argv[2]); + ret = eval(cenv, argv[2], is); + is.close(); + } else { + ret = repl(cenv); + } + + //out << endl << "*** Generated Code ***" << endl; + //cenv.module->dump(); + + return ret; +} + |