diff options
-rw-r--r-- | tuplr.cpp | 288 |
1 files changed, 224 insertions, 64 deletions
@@ -16,6 +16,7 @@ */ #include <stdarg.h> +#include <fstream> #include <iostream> #include <list> #include <map> @@ -248,9 +249,23 @@ private: const Type* ctype; }; +/// Possibly several lifted LLVM functions for a single Tuplr function +struct Funcs : public list< pair<AType*, Function*> > { + Function* find(AType* type) const { + for (const_iterator f = begin(); f != end(); ++f) + if (*f->first == *type) + return f->second; + return NULL; + } + void insert(AType* type, Function* func) { + push_back(make_pair(type, func)); + } +}; + /// Closure (first-class function with captured lexical bindings) struct ASTClosure : public ASTTuple { - ASTClosure(ASTTuple* p, AST* b) : ASTTuple(0, p, b), prot(p), func(0) {} + ASTClosure(ASTTuple* p, AST* b, const string& n="") + : ASTTuple(0, p, b), prot(p), func(0), name(n) {} bool operator==(const AST& rhs) const { return this == &rhs; } string str() const { ostringstream s; s << this; return s.str(); } void constrain(TEnv& tenv) const; @@ -259,6 +274,7 @@ struct ASTClosure : public ASTTuple { ASTTuple* const prot; private: Function* func; + string name; }; /// Function call/application, e.g. "(func arg1 arg2)" @@ -295,11 +311,17 @@ struct ASTPrimitive : public ASTCall { /// Cons special form, e.g. "(cons 1 2)" struct ASTConsCall : public ASTCall { - ASTConsCall(const ASTTuple& t) : ASTCall(t) {} + ASTConsCall(const ASTTuple& t) : ASTCall(t), val(NULL) {} + AType* functionType(CEnv& cenv); void constrain(TEnv& tenv) const; + void lift(CEnv& cenv); Value* compile(CEnv& cenv); + Value* val; + static Funcs funcs; }; +Funcs ASTConsCall::funcs; + /// Car special form, e.g. "(car p)" struct ASTCarCall : public ASTCall { ASTCarCall(const ASTTuple& t) : ASTCall(t) {} @@ -712,7 +734,7 @@ struct CEnv { assert(!vals.ref(obj)); vals.def(obj, value); } - void optimise(Function& f) { verifyFunction(f); opt.run(f); } + void optimise(Function& f) { return; verifyFunction(f); opt.run(f); } typedef Env<const AST*, AST*> Code; typedef Env<const AST*, Value*> Vals; PEnv& penv; @@ -724,6 +746,7 @@ struct CEnv { unsigned symID; Code code; Vals vals; + Function* alloc; }; #define LITERAL(CT, NAME, COMPILED) \ @@ -738,7 +761,8 @@ LITERAL(float, "Float", ConstantFP::get(Type::FloatTy, val)); LITERAL(bool, "Bool", ConstantInt::get(Type::Int1Ty, val, false)); static Function* -compileFunction(CEnv& cenv, const std::string& name, ASTTuple& prot, const Type* retT) +compileFunction(CEnv& cenv, const std::string& name, const Type* retT, ASTTuple& prot, + const vector<string> argNames=vector<string>()) { Function::LinkageTypes linkage = Function::ExternalLinkage; @@ -760,8 +784,12 @@ compileFunction(CEnv& cenv, const std::string& name, ASTTuple& prot, const Type* // Set argument names in generated code Function::arg_iterator a = f->arg_begin(); - for (size_t i = 0; i != prot.size(); ++a, ++i) - a->setName(prot.at(i)->str()); + if (!argNames.empty()) + for (size_t i = 0; i != prot.size(); ++a, ++i) + a->setName(argNames.at(i)); + else + for (size_t i = 0; i != prot.size(); ++a, ++i) + a->setName(prot.at(i)->str()); return f; } @@ -787,7 +815,8 @@ ASTClosure::lift(CEnv& cenv) cenv.push(); // Write function declaration - Function* f = compileFunction(cenv, cenv.gensym("_fn"), *prot, cenv.tenv.type(at(2))->type()); + string name = this->name == "" ? cenv.gensym("_fn") : this->name; + Function* f = compileFunction(cenv, name, cenv.tenv.type(at(2))->type(), *prot); BasicBlock* bb = BasicBlock::Create("entry", f); cenv.builder.SetInsertPoint(bb); @@ -964,21 +993,74 @@ ASTPrimitive::compile(CEnv& cenv) throw CompileError("Unknown primitive"); } -Value* -ASTConsCall::compile(CEnv& cenv) +AType* +ASTConsCall::functionType(CEnv& cenv) { + ASTTuple* protTypes = new ASTTuple(cenv.tenv.type(at(1)), cenv.tenv.type(at(2)), NULL); + AType* cellType = new AType(ASTTuple(cenv.penv.sym("Pair"), + cenv.tenv.type(at(1)), cenv.tenv.type(at(2)), NULL)); + return new AType(ASTTuple(cenv.penv.sym("Fn"), protTypes, cellType, NULL)); +} + +void +ASTConsCall::lift(CEnv& cenv) +{ + AType* funcType = functionType(cenv); + if (funcs.find(functionType(cenv))) + return; + + ASTCall::lift(cenv); + + ASTTuple* prot = new ASTTuple(at(1), at(2), NULL); + vector<const Type*> types; - for (size_t i = 1; i < size(); ++i) - types.push_back(cenv.tenv.type(at(i))->type()); + size_t sz = 0; + for (size_t i = 1; i < size(); ++i) { + const Type* t = cenv.tenv.type(at(i))->type(); + types.push_back(t); + sz += t->getPrimitiveSizeInBits(); + } + sz = (sz % 8 == 0) ? sz / 8 : sz / 8 + 1; + + StructType* sT = StructType::get(types, false); + Type* pT = PointerType::get(sT, 0); + + // Write function declaration + vector<string> argNames; + argNames.push_back("car"); + argNames.push_back("cdr"); + Function* func = compileFunction(cenv, cenv.gensym("cons"), pT, *prot, argNames); + BasicBlock* bb = BasicBlock::Create("entry", func); + cenv.builder.SetInsertPoint(bb); - StructType* t = StructType::get(types, false); - Value* sP = cenv.builder.CreateMalloc(t); - Value* s = cenv.builder.CreateGEP(sP, ConstantInt::get(Type::Int32Ty, 0), "pair"); + Value* mem = cenv.builder.CreateCall(cenv.alloc, ConstantInt::get(Type::Int32Ty, sz), "mem"); + Value* cell = cenv.builder.CreateBitCast(mem, pT, "cell"); + Value* s = cenv.builder.CreateGEP(cell, ConstantInt::get(Type::Int32Ty, 0), "pair"); Value* carP = cenv.builder.CreateStructGEP(s, 0, "car"); Value* cdrP = cenv.builder.CreateStructGEP(s, 1, "cdr"); - cenv.builder.CreateStore(cenv.compile(at(1)), carP); - cenv.builder.CreateStore(cenv.compile(at(2)), cdrP); - return sP; + Function::arg_iterator ai = func->arg_begin(); + Value& carArg = *ai++; + Value& cdrArg = *ai++; + cenv.builder.CreateStore(&carArg, carP); + cenv.builder.CreateStore(&cdrArg, cdrP); + cenv.builder.CreateRet(cell); + cenv.optimise(*func); + + funcs.insert(funcType, func); +} + +Value* +ASTConsCall::compile(CEnv& cenv) +{ + if (val != NULL) + return val; + + vector<Value*> params(size() - 1); + for (size_t i = 1; i < size(); ++i) + params[i-1] = cenv.compile(at(i)); + + val = cenv.builder.CreateCall(funcs.find(functionType(cenv)), params.begin(), params.end()); + return val; } Value* @@ -1003,45 +1085,75 @@ ASTCdrCall::compile(CEnv& cenv) /*************************************************************************** - * REPL * + * EVAL/REPL/MAIN * ***************************************************************************/ +std::string +call(AType* retT, void* fp) +{ + std::stringstream ss; + if (retT->type() == Type::Int32Ty) + ss << ((int32_t (*)())fp)(); + else if (retT->type() == Type::FloatTy) + ss << ((float (*)())fp)(); + else if (retT->type() == Type::Int1Ty) + ss << ((bool (*)())fp)(); + else + ss << ((void* (*)())fp)(); + return ss.str(); +} + int -main() +eval(CEnv& cenv, ExecutionEngine* engine, istream& is) { -#define PRIM(O, A) PEnv::Parser(parsePrim, Op(Instruction:: O, A)) - PEnv penv; - penv.reg("fn", PEnv::Parser(parseFn, Op())); - penv.reg("if", PEnv::Parser(parseIf, Op())); - penv.reg("def", PEnv::Parser(parseDef, Op())); - penv.reg("cons", PEnv::Parser(parseCons, Op())); - penv.reg("car", PEnv::Parser(parseCar, Op())); - penv.reg("cdr", PEnv::Parser(parseCdr, Op())); - penv.reg("+", PRIM(Add, 0)); - penv.reg("-", PRIM(Sub, 0)); - penv.reg("*", PRIM(Mul, 0)); - penv.reg("/", PRIM(FDiv, 0)); - penv.reg("%", PRIM(FRem, 0)); - penv.reg("&", PRIM(And, 0)); - penv.reg("|", PRIM(Or, 0)); - penv.reg("^", PRIM(Xor, 0)); - penv.reg("=", PRIM(ICmp, CmpInst::ICMP_EQ)); - penv.reg("!=", PRIM(ICmp, CmpInst::ICMP_NE)); - penv.reg(">", PRIM(ICmp, CmpInst::ICMP_SGT)); - penv.reg(">=", PRIM(ICmp, CmpInst::ICMP_SGE)); - penv.reg("<", PRIM(ICmp, CmpInst::ICMP_SLT)); - penv.reg("<=", PRIM(ICmp, CmpInst::ICMP_SLE)); + AST* result = NULL; + AType* resultType = NULL; + list< pair<SExp, AST*> > exprs; + while (true) { + SExp exp = readExpression(is); + if (exp.type == SExp::LIST && exp.list.empty()) + break; - Module* module = new Module("repl"); - ExecutionEngine* engine = ExecutionEngine::create(module); - CEnv cenv(penv, module, engine->getTargetData()); + try { + result = parseExpression(cenv.penv, exp); // Parse input + result->constrain(cenv.tenv); // Constrain types + cenv.tenv.solve(); // Solve and apply type constraints + resultType = cenv.tenv.type(result); + if (!resultType) throw TypeError("Call to untyped body"); + result->lift(cenv); // Lift functions + exprs.push_back(make_pair(exp, result)); + } catch (Error e) { + std::cerr << "Error: " << e.what() << endl; + } + } + + if (resultType->var) throw TypeError("Call to variable typed body"); + + // Create function for top-level of program + ASTTuple prot; + const Type* ctype = resultType->type(); + Function* f = compileFunction(cenv, cenv.gensym("input"), ctype, prot); + BasicBlock* bb = BasicBlock::Create("entry", f); + cenv.builder.SetInsertPoint(bb); - cenv.tenv.name("Bool", Type::Int1Ty); - cenv.tenv.name("Int", Type::Int32Ty); - cenv.tenv.name("Float", Type::FloatTy); - cenv.code.def(penv.sym("true"), new ASTLiteral<bool>(true)); - cenv.code.def(penv.sym("false"), new ASTLiteral<bool>(false)); + // Compile all expressions into it + Value* val = NULL; + for (list< pair<SExp, AST*> >::const_iterator i = exprs.begin(); i != exprs.end(); ++i) + val = cenv.compile(i->second); + + // Finish function + cenv.builder.CreateRet(val); + cenv.optimise(*f); + + string resultStr = call(resultType, engine->getPointerToFunction(f)); + std::cout << resultStr << " : " << resultType->str() << endl; + + return 0; +} +int +repl(CEnv& cenv, ExecutionEngine* engine) +{ while (1) { std::cout << "() "; std::cout.flush(); @@ -1050,9 +1162,8 @@ main() break; try { - AST* body = parseExpression(penv, exp); // Parse input + AST* body = parseExpression(cenv.penv, exp); // Parse input body->constrain(cenv.tenv); // Constrain types - cenv.tenv.solve(); // Solve and apply type constraints AType* bodyT = cenv.tenv.type(body); @@ -1060,11 +1171,11 @@ main() if (bodyT->var) throw TypeError("REPL call to variable typed body"); body->lift(cenv); - + if (bodyT->type()) { // Create anonymous function to insert code into. ASTTuple* prot = new ASTTuple(); - Function* f = compileFunction(cenv, cenv.gensym("_repl"), *prot, bodyT->type()); + Function* f = compileFunction(cenv, cenv.gensym("_repl"), bodyT->type(), *prot); BasicBlock* bb = BasicBlock::Create("entry", f); cenv.builder.SetInsertPoint(bb); try { @@ -1075,15 +1186,7 @@ main() f->eraseFromParent(); // Error reading body, remove function throw e; } - void* fp = engine->getPointerToFunction(f); - if (bodyT->type() == Type::Int32Ty) - std::cout << "; " << ((int32_t (*)())fp)(); - else if (bodyT->type() == Type::FloatTy) - std::cout << "; " << ((float (*)())fp)(); - else if (bodyT->type() == Type::Int1Ty) - std::cout << "; " << ((bool (*)())fp)(); - else - std::cout << "; " << ((void* (*)())fp)(); + std::cout << call(bodyT, engine->getPointerToFunction(f)); } else { Value* val = cenv.compile(body); std::cout << "; " << val; @@ -1095,8 +1198,65 @@ main() } } - std::cout << endl << "Generated code:" << endl; - module->dump(); return 0; } +int +main(int argc, char** argv) +{ +#define PRIM(O, A) PEnv::Parser(parsePrim, Op(Instruction:: O, A)) + PEnv penv; + penv.reg("fn", PEnv::Parser(parseFn, Op())); + penv.reg("if", PEnv::Parser(parseIf, Op())); + penv.reg("def", PEnv::Parser(parseDef, Op())); + penv.reg("cons", PEnv::Parser(parseCons, Op())); + penv.reg("car", PEnv::Parser(parseCar, Op())); + penv.reg("cdr", PEnv::Parser(parseCdr, Op())); + penv.reg("+", PRIM(Add, 0)); + penv.reg("-", PRIM(Sub, 0)); + penv.reg("*", PRIM(Mul, 0)); + penv.reg("/", PRIM(FDiv, 0)); + penv.reg("%", PRIM(FRem, 0)); + penv.reg("&", PRIM(And, 0)); + penv.reg("|", PRIM(Or, 0)); + penv.reg("^", PRIM(Xor, 0)); + penv.reg("=", PRIM(ICmp, CmpInst::ICMP_EQ)); + penv.reg("!=", PRIM(ICmp, CmpInst::ICMP_NE)); + penv.reg(">", PRIM(ICmp, CmpInst::ICMP_SGT)); + penv.reg(">=", PRIM(ICmp, CmpInst::ICMP_SGE)); + penv.reg("<", PRIM(ICmp, CmpInst::ICMP_SLT)); + penv.reg("<=", PRIM(ICmp, CmpInst::ICMP_SLE)); + + Module* module = new Module("interactive"); + ExecutionEngine* engine = ExecutionEngine::create(module); + CEnv cenv(penv, module, engine->getTargetData()); + + cenv.tenv.name("Bool", Type::Int1Ty); + cenv.tenv.name("Int", Type::Int32Ty); + cenv.tenv.name("Float", Type::FloatTy); + cenv.code.def(penv.sym("true"), new ASTLiteral<bool>(true)); + cenv.code.def(penv.sym("false"), new ASTLiteral<bool>(false)); + + // Host provided allocation primitive prototypes + std::vector<const Type*> argsT(1, Type::Int32Ty); + FunctionType* funcT = FunctionType::get(PointerType::get(Type::VoidTy, 0), argsT, false); + cenv.alloc = Function::Create(funcT, Function::ExternalLinkage, "malloc", module); + + int ret; + if (argc > 2 && !strncmp(argv[1], "-e", 3)) { + std::istringstream is(argv[2]); + ret = eval(cenv, engine, is); + } else if (argc > 2 && !strncmp(argv[1], "-f", 3)) { + std::ifstream is(argv[2]); + ret = eval(cenv, engine, is); + is.close(); + } else { + ret = repl(cenv, engine); + } + + //std::cout << endl << "*** Generated Code ***" << endl; + //cenv.module->dump(); + + return ret; +} + |