diff options
-rw-r--r-- | ll.cpp | 665 |
1 files changed, 357 insertions, 308 deletions
@@ -123,9 +123,82 @@ readExpression(std::istream& in) /*************************************************************************** + * Environment * + ***************************************************************************/ + +class AST; +class ASTSymbol; + +/// Generic Recursive Environment (stack of key:value dictionaries) +template<typename K, typename V> +struct Env : public list< map<K,V> > { + Env() : list< map<K, V> >(1) {} + void push() { this->push_front(map<K,V>()); } + void push(const map<K,V>& frame) { this->push_front(frame); } + map<K,V>& pop() { + map<K,V>& front = this->front(); + this->pop_front(); + return front; + } + void def(const K& k, const V& v) { + if (this->front().find(k) != this->front().end()) + throw SyntaxError("Redefinition"); + this->front()[k] = v; + } + V* ref(const K& name) { + typename Env::iterator i = this->begin(); + for (; i != this->end(); ++i) { + typename map<K,V>::iterator s = i->find(name); + if (s != i->end()) + return &s->second; + } + return 0; + } +}; + +class PEnv; + +/// Compile-time environment +struct CEnv { + CEnv(PEnv& p, Module* m, const TargetData* target) + : penv(p), module(m), emp(module), fpm(&emp), symID(0), tID(0) + { + // Set up the optimizer pipeline. + // Register info about how the target lays out data structures. + fpm.add(new TargetData(*target)); + // Do simple "peephole" and bit-twiddling optimizations. + fpm.add(createInstructionCombiningPass()); + // Reassociate expressions. + fpm.add(createReassociatePass()); + // Eliminate Common SubExpressions. + fpm.add(createGVNPass()); + // Simplify control flow graph (delete unreachable blocks, etc). + fpm.add(createCFGSimplificationPass()); + } + string gensym(const char* base="_") { + ostringstream s; s << base << symID++; return s.str(); + } + typedef Env<const AST*, AST*> Code; + typedef Env<const ASTSymbol*, Value*> Vals; + + PEnv& penv; + IRBuilder<> builder; + Module* module; + ExistingModuleProvider emp; + FunctionPassManager fpm; + unsigned symID; + unsigned tID; + Code code; + Vals vals; +}; + + + +/*************************************************************************** * Abstract Syntax Tree * ***************************************************************************/ +struct AType; struct TypeError : public Error { TypeError (const char* m) : Error(m) {} }; struct CEnv; ///< Compile Time Environment @@ -133,148 +206,187 @@ struct CEnv; ///< Compile Time Environment /// Base class for all AST nodes struct AST { virtual ~AST() {} - virtual const Type* type(CEnv& cenv) const = 0; - virtual Value* compile(CEnv& cenv) = 0; + virtual string str(CEnv& cenv) const = 0; + virtual AType* type(CEnv& cenv) = 0; + virtual Value* compile(CEnv& cenv) = 0; + virtual void lift(CEnv& cenv) {} +}; + +/// Symbol, e.g. "a" +struct ASTSymbol : public AST { + ASTSymbol(const string& s) : cppstr(s) {} + std::string str(CEnv&) const { return cppstr; } + AType* type(CEnv& cenv); + Value* compile(CEnv& cenv); +private: + const string cppstr; +}; + +/// Tuple (heterogeneous sequence of known length), e.g. "(a b c)" +struct ASTTuple : public AST { + ASTTuple(vector<AST*> t=vector<AST*>()) : tup(t) {} + string str(CEnv& cenv) const { + string ret = "("; + for (size_t i = 0; i != tup.size(); ++i) + ret += tup[i]->str(cenv) + ((i != tup.size() - 1) ? " " : ""); + ret.append(")"); + return ret; + } + void lift(CEnv& cenv) { + FOREACH(vector<AST*>::iterator, t, tup) + (*t)->lift(cenv); + } + AType* type(CEnv& cenv); + Value* compile(CEnv& cenv) { return NULL; } + vector<AST*> tup; +}; + +/// TExpr ::= (TName TExpr*) | ?Num +struct AType : public ASTTuple { + AType(unsigned i) : var(true), ctype(0), id(id) {} + AType(const string& n, const Type* t) : var(false), ctype(t) { + tup.push_back(new ASTSymbol(n)); + } + AType(const vector<AST*>& t) : ASTTuple(t), var(false), ctype(0) {} + inline bool operator==(const AType& t) const { return tup[0] == t.tup[0]; } + inline bool operator!=(const AType& t) const { return tup[0] != t.tup[0]; } + string str(CEnv& cenv) const { return var ? "?" : ASTTuple::str(cenv); } + AType* type(CEnv& cenv) { return this; } + Value* compile(CEnv& cenv) { return NULL; } + bool var; + const Type* ctype; + unsigned id; }; /// Literal template<typename VT> struct ASTLiteral : public AST { ASTLiteral(VT v) : val(v) {} - const Type* type(CEnv& cenv) const; - Value* compile(CEnv& cenv); + string str(CEnv& env) const { return "(Literal)"; } + AType* type(CEnv& cenv); + Value* compile(CEnv& cenv); const VT val; }; - -#define LITERAL(CT, VT, COMPILED) \ -template<> const Type* \ -ASTLiteral<CT>::type(CEnv& cenv) const { return VT; } \ - \ +#define LITERAL(CT, VT, NAME, COMPILED) \ +template<> string \ +ASTLiteral<CT>::str(CEnv& cenv) const { return NAME; } \ +template<> AType* \ +ASTLiteral<CT>::type(CEnv& cenv) { return new AType(NAME, VT); } \ template<> Value* \ ASTLiteral<CT>::compile(CEnv& cenv) { return (COMPILED); } -/// Literal template specialisations -LITERAL(int32_t, Type::Int32Ty, ConstantInt::get(type(cenv), val, true)); -LITERAL(float, Type::FloatTy, ConstantFP::get(type(cenv), val)); -LITERAL(bool, Type::Int1Ty, ConstantInt::get(type(cenv), val, false)); +/// Literal template instantiations +LITERAL(int32_t, Type::Int32Ty, "Int", ConstantInt::get(Type::Int32Ty, val, true)); +LITERAL(float, Type::FloatTy, "Float", ConstantFP::get(Type::FloatTy, val)); +LITERAL(bool, Type::Int1Ty, "Bool", ConstantInt::get(Type::Int1Ty, val, false)); -/// Symbol, e.g. "a" -struct ASTSymbol : public AST { - ASTSymbol(const string& n) : name(n) {} - virtual const Type* type(CEnv& cenv) const; - virtual Value* compile(CEnv& cenv); - const string name; +typedef unsigned UD; // User Data passed to registered parse functions + +// Parse Time Environment (symbol table) +struct PEnv : private map<const string, ASTSymbol*> { + typedef AST* (*PF)(PEnv&, const list<SExp>&, UD); // Parse Function + struct Parser { Parser(PF f, UD d) : pf(f), ud(d) {} PF pf; UD ud; }; + map<const ASTSymbol*, Parser> parsers; + void reg(const ASTSymbol* s, const Parser& p) { + parsers.insert(make_pair(s, p)); + } + const Parser* parser(const ASTSymbol* s) const { + map<const ASTSymbol*, Parser>::const_iterator i = parsers.find(s); + return (i != parsers.end()) ? &i->second : NULL; + } + ASTSymbol* sym(const string& s) { + const const_iterator i = find(s); + return ((i != end()) + ? i->second + : insert(make_pair(s, new ASTSymbol(s))).first->second); + } +}; + +/// Closure (first-class function with captured lexical bindings) +struct ASTClosure : public AST { + ASTClosure(ASTTuple* p, AST* b) : prot(p), body(b), func(0) {} + string str(CEnv& env) const { return "(fn)"; } + AType* type(CEnv& cenv) { + vector<AST*> texp(3); + texp[0] = cenv.penv.sym("Fn"); + texp[1] = prot; + texp[2] = body; + return new AType(texp); + } + Value* compile(CEnv& cenv); + void lift(CEnv& cenv); + ASTTuple* const prot; + AST* const body; + vector<const ASTSymbol*> bindings; +private: + Function* func; }; /// Function call/application, e.g. "(func arg1 arg2)" -struct ASTCall : public AST { - ASTCall(const vector<AST*>& c) : code(c) {} - virtual const Type* type(CEnv& cenv) const { - AST* func = code[0]; - const FunctionType* ftype = dynamic_cast<const FunctionType*>(func->type(cenv)); - if (!ftype) throw TypeError(string("Call to non-function type :: ") - .append(func->type(cenv)->getDescription()).c_str()); - return ftype->getReturnType(); +struct ASTCall : public ASTTuple { + ASTCall(const vector<AST*>& t) : ASTTuple(t) {} + AType* type(CEnv& cenv) { + AST* callee = tup[0]; + ASTSymbol* sym = dynamic_cast<ASTSymbol*>(tup[0]); + if (sym) { + AST** val = cenv.code.ref(sym); + if (val) + callee = *val; + } + ASTClosure* c = dynamic_cast<ASTClosure*>(callee); + if (!c) throw TypeError("Call to non-closure"); + return c->body->type(cenv); } - virtual Value* compile(CEnv& cenv); - const vector<AST*> code; + void lift(CEnv& cenv); + Value* compile(CEnv& cenv); }; /// Definition special form, e.g. "(def x 2)" or "(def (next y) (+ y 1))" struct ASTDefinition : public ASTCall { ASTDefinition(const vector<AST*>& c) : ASTCall(c) {} - virtual const Type* type(CEnv& cenv) const { return code[2]->type(cenv); } - virtual Value* compile(CEnv& cenv); + AType* type(CEnv& cenv) { return tup[2]->type(cenv); } + Value* compile(CEnv& cenv); }; /// Conditional special form, e.g. "(if cond thenexp elseexp)" struct ASTIf : public ASTCall { ASTIf(const vector<AST*>& c) : ASTCall(c) {} - virtual const Type* type(CEnv& cenv) const { - const Type* cT = code[1]->type(cenv); - const Type* tT = code[2]->type(cenv); - const Type* eT = code[3]->type(cenv); - if (cT != Type::Int1Ty) throw TypeError("If condition is not a boolean"); - if (tT != eT) throw TypeError("If branches have different types"); + AType* type(CEnv& cenv) { + AType* cT = tup[1]->type(cenv); + AType* tT = tup[2]->type(cenv); + AType* eT = tup[3]->type(cenv); + if (cT->ctype != Type::Int1Ty) throw TypeError("If condition is not a boolean"); + if (*tT != *eT) throw TypeError("If branches have different types"); return tT; } - virtual Value* compile(CEnv& cenv); + Value* compile(CEnv& cenv); }; /// Primitive (builtin arithmetic function), e.g. "(+ 2 3)" struct ASTPrimitive : public ASTCall { ASTPrimitive(const vector<AST*>& c, Instruction::BinaryOps o) : ASTCall(c), op(o) {} - virtual const Type* type(CEnv& cenv) const { return Type::FloatTy; } - virtual Value* compile(CEnv& cenv); - Instruction::BinaryOps op; -}; - -/// Function prototype (actual LLVM IR function prototype) -struct ASTPrototype { - ASTPrototype(vector<AST*> p=vector<AST*>()) : params(p) {} - vector<const Type*> argsType(CEnv& cenv) { - vector<const Type*> types; - FOREACH(vector<AST*>::const_iterator, p, params) - types.push_back((*p)->type(cenv)); - return types; - } - virtual const Type* type(CEnv& cenv) const { return NULL; } - Function* compile(CEnv& cenv, FunctionType* type, const string& name); - string name; - vector<AST*> params; -}; - -/// Closure (first-class function with captured lexical bindings) -struct ASTClosure : public AST { - ASTClosure(ASTPrototype* p, AST* b) : prot(p), body(b), func(0) {} - virtual const Type* type(CEnv& cenv) const { - return FunctionType::get(body->type(cenv), prot->argsType(cenv), false); + AType* type(CEnv& cenv) { + if (tup.size() <= 1) throw SyntaxError("Primitive call with no arguments"); + return tup[1]->type(cenv); // FIXME: Ensure argument types are equivalent } - virtual Value* compile(CEnv& cenv); - virtual void lift(CEnv& cenv); - ASTPrototype* const prot; - AST* const body; - vector<const ASTSymbol*> bindings; -private: - Function* func; -}; - -/// Function definition (actual LLVM IR function) -struct ASTFunction { - ASTFunction(ASTPrototype* p, AST* b) : prot(p), body(b) {} - Function* compile(CEnv& cenv, const string& name); - ASTPrototype* const prot; - AST* const body; + Value* compile(CEnv& cenv); + Instruction::BinaryOps op; }; +AType* +ASTTuple::type(CEnv& cenv) +{ + vector<AST*> texp; + FOREACH(vector<AST*>::const_iterator, p, tup) + texp.push_back((*p)->type(cenv)); + return new AType(texp); +} /*************************************************************************** * Parser - S-Expressions (SExp) -> AST Nodes (AST) * ***************************************************************************/ -typedef unsigned UD; // User Data passed to registered parse functions - -// Parse Time Environment (just a symbol table) -struct PEnv : private map<const string, ASTSymbol*> { - typedef AST* (*PF)(PEnv&, const list<SExp>&, UD); // Parse Function - struct Parser { Parser(PF f, UD d) : pf(f), ud(d) {} PF pf; UD ud; }; - map<const ASTSymbol*, Parser> parsers; - void reg(const ASTSymbol* s, const Parser& p) { - parsers.insert(make_pair(s, p)); - } - const Parser* parser(const ASTSymbol* s) const { - map<const ASTSymbol*, Parser>::const_iterator i = parsers.find(s); - return (i != parsers.end()) ? &i->second : NULL; - } - ASTSymbol* sym(const string& s) { - const const_iterator i = find(s); - return ((i != end()) - ? i->second - : insert(make_pair(s, new ASTSymbol(s))).first->second); - } -}; - /// The fundamental parser method static AST* parseExpression(PEnv& penv, const SExp& exp) @@ -330,9 +442,9 @@ static AST* parsePrim(PEnv& penv, const list<SExp>& c, UD data) { return new ASTPrimitive(pmap(penv, c), (Instruction::BinaryOps)data); } -static ASTPrototype* +static ASTTuple* parsePrototype(PEnv& penv, const SExp& e, UD) - { return new ASTPrototype(pmap(penv, e.list)); } + { return new ASTTuple(pmap(penv, e.list)); } static AST* parseFn(PEnv& penv, const list<SExp>& c, UD) @@ -348,143 +460,126 @@ parseFn(PEnv& penv, const list<SExp>& c, UD) * Code Generation * ***************************************************************************/ -/// Generic Recursive Environment (stack of key:value dictionaries) -template<typename K, typename V> -struct Env : public list< map<K,V> > { - Env() : list< map<K, V> >(1) {} - void push() { this->push_front(map<K,V>()); } - void push(const map<K,V>& frame) { this->push_front(frame); } - map<K,V>& pop() { - map<K,V>& front = this->front(); - this->pop_front(); - return front; - } - void def(const K& k, const V& v) { - if (this->front().find(k) != this->front().end()) - std::cerr << "WARNING: Redefinition: " << k << endl; - this->front()[k] = v; - } - V* ref(const K& name) { - typename Env::iterator i = this->begin(); - for (; i != this->end(); ++i) { - typename map<K,V>::iterator s = i->find(name); - if (s != i->end()) - return &s->second; - } - return 0; - } -}; +struct CompileError : public Error { CompileError(const char* m) : Error(m) {} }; -/// Compile-time environment -struct CEnv { - CEnv(Module* m, const TargetData* target) - : module(m), provider(module), fpm(&provider), symID(0) - { - // Set up the optimizer pipeline. - // Register info about how the target lays out data structures. - fpm.add(new TargetData(*target)); - // Do simple "peephole" and bit-twiddling optimizations. - fpm.add(createInstructionCombiningPass()); - // Reassociate expressions. - fpm.add(createReassociatePass()); - // Eliminate Common SubExpressions. - fpm.add(createGVNPass()); - // Simplify control flow graph (delete unreachable blocks, etc). - fpm.add(createCFGSimplificationPass()); - } - string gensym(const char* base="_") { - ostringstream s; s << base << symID++; return s.str(); - } - void def(const ASTSymbol* sym, AST* expr) { - types.def(sym, expr->type(*this)); - code.def(sym, expr); +static Function* +compileFunction(CEnv& cenv, const std::string& name, ASTTuple& prot, const Type* retT) +{ + Function::LinkageTypes linkage = Function::ExternalLinkage; + + const vector<AST*>& texp = prot.type(cenv)->tup; + vector<const Type*> cprot; + for (size_t i = 0; i < texp.size(); ++i) { + const Type* t = texp[i]->type(cenv)->ctype; + if (!t) throw CompileError("Function prototype contains NULL"); + cprot.push_back(t); } - typedef Env<const ASTSymbol*, const Type*> Types; - typedef Env<const ASTSymbol*, AST*> Code; - typedef Env<const ASTSymbol*, Value*> Vals; - IRBuilder<> builder; - Module* module; - ExistingModuleProvider provider; - FunctionPassManager fpm; - size_t symID; - Types types; - Code code; - Vals vals; -}; + FunctionType* fT = FunctionType::get(retT, cprot, false); + Function* f = Function::Create(fT, linkage, name, cenv.module); -static void -lambdaLift(CEnv& env, AST* ast) -{ - if (ASTClosure* closure = dynamic_cast<ASTClosure*>(ast)) { - lambdaLift(env, closure->body); - closure->lift(env); - } else if (ASTCall* call = dynamic_cast<ASTCall*>(ast)) { - FOREACH(vector<AST*>::const_iterator, a, call->code) - lambdaLift(env, *a); + if (f->getName() != name) { + f->eraseFromParent(); + throw CompileError("Function redefined"); } + + // Set argument names in generated code + Function::arg_iterator a = f->arg_begin(); + for (size_t i = 0; i != prot.tup.size(); ++a, ++i) + a->setName(prot.tup[i]->str(cenv)); + + return f; } -const Type* -ASTSymbol::type(CEnv& cenv) const +AType* +ASTSymbol::type(CEnv& cenv) { - const Type** t = cenv.types.ref(this); - if (t) { - return *t; - } else { - //std::cerr << "WARNING: Untyped symbol: " << name << endl; - return Type::FloatTy; - } + AST** t = cenv.code.ref(this); + return t ? (*t)->type(cenv) : new AType(cenv.tID++); } Value* ASTSymbol::compile(CEnv& cenv) { - Value*const* v = cenv.vals.ref(this); + Value** v = cenv.vals.ref(this); if (v) return *v; - AST*const* c = cenv.code.ref(this); + AST** c = cenv.code.ref(this); if (c) { Value* v = (*c)->compile(cenv); cenv.vals.def(this, v); return v; } - throw SyntaxError((string("Undefined symbol '") + name + "'").c_str()); + throw SyntaxError((string("Undefined symbol '") + cppstr + "'").c_str()); } Value* ASTDefinition::compile(CEnv& cenv) { - if (code.size() != 3) throw SyntaxError("\"def\" takes exactly 2 arguments"); - const ASTSymbol* sym = dynamic_cast<const ASTSymbol*>(code[1]); + if (tup.size() != 3) throw SyntaxError("\"def\" takes exactly 2 arguments"); + const ASTSymbol* sym = dynamic_cast<const ASTSymbol*>(tup[1]); if (!sym) throw SyntaxError("Definition name is not a symbol"); - Value* val = code[2]->compile(cenv); - cenv.types.def(sym, code[2]->type(cenv)); - cenv.code.def(sym, code[2]); + Value* val = tup[2]->compile(cenv); + cenv.code.def(sym, tup[2]); cenv.vals.def(sym, val); return val; } +void +ASTCall::lift(CEnv& cenv) +{ + ASTClosure* c = dynamic_cast<ASTClosure*>(tup[0]); + if (!c) { + AST** val = cenv.code.ref(tup[0]); + c = (val) ? dynamic_cast<ASTClosure*>(*val) : c; + } + + if (!c) { + ASTTuple::lift(cenv); + return; + } + + std::cout << "Lifting call to closure" << endl; + + // Lift arguments + for (size_t i = 1; i < tup.size(); ++i) + tup[i]->lift(cenv); + + // Extend environment with bound and typed parameters + cenv.code.push(); + if (c->prot->tup.size() != tup.size() - 1) + throw CompileError("Call to closure with mismatched arguments"); + + for (size_t i = 1; i < tup.size(); ++i) + cenv.code.def(c->prot->tup[i-1], tup[i]); + + // Lift callee closure + tup[0]->lift(cenv); + + cenv.code.pop(); +} + Value* ASTCall::compile(CEnv& cenv) { - AST* func = code[0]; - AST** closure = cenv.code.ref((ASTSymbol*)func); - assert(closure); - ASTClosure* c = dynamic_cast<ASTClosure*>(*closure); - assert(c); - Function* f = dynamic_cast<Function*>(func->compile(cenv)); - if (!f) throw SyntaxError("Call to non-function"); + ASTClosure* c = dynamic_cast<ASTClosure*>(tup[0]); + if (!c) { + AST** val = cenv.code.ref(tup[0]); + c = (val) ? dynamic_cast<ASTClosure*>(*val) : c; + } + + if (!c) throw CompileError("Call to non-closure"); + Value* v = c->compile(cenv); + if (!v) throw SyntaxError("Callee failed to compile"); + Function* f = dynamic_cast<Function*>(c->compile(cenv)); + if (!f) throw SyntaxError("Callee compiled to non-function"); vector<Value*> params; - for (size_t i = 1; i < code.size(); ++i) - params.push_back(code[i]->compile(cenv)); - - for (size_t i = 0; i < c->bindings.size(); ++i) - std::cout << "BINDING: " << c->bindings[i]->name << endl; + for (size_t i = 1; i < tup.size(); ++i) + params.push_back(tup[i]->compile(cenv)); return cenv.builder.CreateCall(f, params.begin(), params.end(), "calltmp"); } @@ -492,7 +587,7 @@ ASTCall::compile(CEnv& cenv) Value* ASTIf::compile(CEnv& cenv) { - Value* condV = code[1]->compile(cenv); + Value* condV = tup[1]->compile(cenv); Function* parent = cenv.builder.GetInsertBlock()->getParent(); // Create blocks for the then and else cases. @@ -505,7 +600,7 @@ ASTIf::compile(CEnv& cenv) // Emit then value. cenv.builder.SetInsertPoint(thenBB); - Value* thenV = code[2]->compile(cenv); + Value* thenV = tup[2]->compile(cenv); cenv.builder.CreateBr(mergeBB); // compile of 'Then' can change the current block, update thenBB @@ -514,7 +609,7 @@ ASTIf::compile(CEnv& cenv) // Emit else block. parent->getBasicBlockList().push_back(elseBB); cenv.builder.SetInsertPoint(elseBB); - Value* elseV = code[3]->compile(cenv); + Value* elseV = tup[3]->compile(cenv); cenv.builder.CreateBr(mergeBB); // compile of 'Else' can change the current block, update elseBB @@ -523,7 +618,7 @@ ASTIf::compile(CEnv& cenv) // Emit merge block. parent->getBasicBlockList().push_back(mergeBB); cenv.builder.SetInsertPoint(mergeBB); - PHINode* pn = cenv.builder.CreatePHI(type(cenv), "iftmp"); + PHINode* pn = cenv.builder.CreatePHI(type(cenv)->ctype, "iftmp"); pn->addIncoming(thenV, thenBB); pn->addIncoming(elseV, elseBB); @@ -535,20 +630,10 @@ ASTClosure::lift(CEnv& cenv) { assert(!func); - //set<const ASTSymbol*> unbound; - const ASTCall* call = dynamic_cast<const ASTCall*>(body); - if (call) { - std::cout << "LIFT CALL BODY\n"; - } - -#if 0 - Env<const ASTSymbol*, const AST*> paramsEnv; - for (vector<const ASTSymbol*>::const_iterator p = prot->params.begin(); - p != prot->params.end(); ++p) { - //paramsEnv[*p] = NULL; - std::cout << "PARAM: " << (*p)->name << endl; - } -#endif + // Can't lift a closure with variable types (lift later when called) + for (size_t i = 0; i < prot->tup.size(); ++i) + if (prot->tup[i]->type(cenv)->var) + return; cenv.code.push(); @@ -557,21 +642,19 @@ ASTClosure::lift(CEnv& cenv) AST** obj = cenv.code.ref(sym); if (!obj) { std::cout << "UNDEFINED SYMBOL BODY\n"; - prot->params.push_back(sym); + prot->tup.push_back(sym); bindings.push_back(sym); } } // Write function declaration - Function* f = prot->compile(cenv, - FunctionType::get(body->type(cenv), prot->argsType(cenv), false), - cenv.gensym("_fn")); + Function* f = compileFunction(cenv, cenv.gensym("_fn"), *prot, body->type(cenv)->ctype); BasicBlock* bb = BasicBlock::Create("entry", f); cenv.builder.SetInsertPoint(bb); - + // Bind argument values in CEnv vector<Value*> args; - vector<AST*>::const_iterator p = prot->params.begin(); + vector<AST*>::const_iterator p = prot->tup.begin(); for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p) cenv.vals.def(dynamic_cast<ASTSymbol*>(*p), &*a); @@ -582,7 +665,7 @@ ASTClosure::lift(CEnv& cenv) verifyFunction(*f); // Validate generated code cenv.fpm.run(*f); // Optimize function func = f; - } catch (SyntaxError e) { + } catch (exception e) { f->eraseFromParent(); // Error reading body, remove function throw e; } @@ -594,77 +677,16 @@ Value* ASTClosure::compile(CEnv& cenv) { // Function was already compiled in the lifting pass - assert(func); return func; } -Function* -ASTPrototype::compile(CEnv& cenv, FunctionType* FT, const std::string& n) -{ - name = n; - Function::LinkageTypes linkage = Function::ExternalLinkage; - Function* f = Function::Create(FT, linkage, name, cenv.module); - - // If F conflicted, there was already something named 'Name'. - // If it has a body, don't allow redefinition. - if (f->getName() != name) { - // Delete the one we just made and get the existing one. - f->eraseFromParent(); - f = cenv.module->getFunction(name); - - // If F already has a body, reject this. - if (!f->empty()) throw SyntaxError("Function redefined"); - - // If F took a different number of args, reject. - if (f->arg_size() != params.size()) - throw SyntaxError("Function redefined with mismatched arguments"); - } - - // Set argument names in generated code - Function::arg_iterator a = f->arg_begin(); - for (size_t i = 0; i != params.size(); ++a, ++i) { - assert(params[i]); - ASTSymbol* sym = dynamic_cast<ASTSymbol*>(params[i]); - a->setName(sym ? sym->name : cenv.gensym("_a")); - } - - return f; -} - -Function* -ASTFunction::compile(CEnv& cenv, const string& name) -{ - const Type* bodyT = body->type(cenv); - if (dynamic_cast<const FunctionType*>(bodyT)) { - std::cout << "First class function alert" << endl; - bodyT = PointerType::get(bodyT, 0); - } - FunctionType* fT = FunctionType::get(bodyT, prot->argsType(cenv), false); - Function* f = prot->compile(cenv, fT, name); - - // Create a new basic block to start insertion into. - BasicBlock* bb = BasicBlock::Create("entry", f); - cenv.builder.SetInsertPoint(bb); - - try { - Value* retVal = body->compile(cenv); - cenv.builder.CreateRet(retVal); // Finish function - verifyFunction(*f); // Validate generated code - cenv.fpm.run(*f); // Optimize function - return f; - } catch (SyntaxError e) { - f->eraseFromParent(); // Error reading body, remove function - throw e; - } -} - Value* ASTPrimitive::compile(CEnv& cenv) { size_t np = 0; - vector<Value*> params(code.size() - 1); - vector<AST*>::const_iterator a = code.begin(); - for (++a; a != code.end(); ++a) + vector<Value*> params(tup.size() - 1); + vector<AST*>::const_iterator a = tup.begin(); + for (++a; a != tup.end(); ++a) params[np++] = (*a)->compile(cenv); switch (params.size()) { @@ -689,14 +711,10 @@ ASTPrimitive::compile(CEnv& cenv) int main() { - Module* module = new Module("interactive"); - ExecutionEngine* engine = ExecutionEngine::create(module); - CEnv cenv(module, engine->getTargetData()); - PEnv penv; + penv.reg(penv.sym("fn"), PEnv::Parser(parseFn, 0)); penv.reg(penv.sym("if"), PEnv::Parser(parseIf, 0)); penv.reg(penv.sym("def"), PEnv::Parser(parseDef, 0)); - penv.reg(penv.sym("fn"), PEnv::Parser(parseFn, 0)); penv.reg(penv.sym("+"), PEnv::Parser(parsePrim, Instruction::Add)); penv.reg(penv.sym("-"), PEnv::Parser(parsePrim, Instruction::Sub)); penv.reg(penv.sym("*"), PEnv::Parser(parsePrim, Instruction::Mul)); @@ -705,8 +723,13 @@ main() penv.reg(penv.sym("&"), PEnv::Parser(parsePrim, Instruction::And)); penv.reg(penv.sym("|"), PEnv::Parser(parsePrim, Instruction::Or)); penv.reg(penv.sym("^"), PEnv::Parser(parsePrim, Instruction::Xor)); - cenv.def(penv.sym("true"), new ASTLiteral<bool>(true)); - cenv.def(penv.sym("false"), new ASTLiteral<bool>(false)); + + Module* module = new Module("repl"); + ExecutionEngine* engine = ExecutionEngine::create(module); + CEnv cenv(penv, module, engine->getTargetData()); + + cenv.code.def(penv.sym("true"), new ASTLiteral<bool>(true)); + cenv.code.def(penv.sym("false"), new ASTLiteral<bool>(false)); while (1) { std::cout << "(=>) "; @@ -716,21 +739,47 @@ main() break; try { - AST* ast = parseExpression(penv, exp); - lambdaLift(cenv, ast); - ASTPrototype* proto = new ASTPrototype(); - ASTFunction* func = new ASTFunction(proto, ast); - Function* code = func->compile(cenv, cenv.gensym("_repl")); - void* fp = engine->getPointerToFunction(code); - code->dump(); - double (*f)() = (double (*)())fp; - std::cout << f() << " :: "; - func->body->type(cenv)->print(std::cout); - std::cout << endl; + AST* body = parseExpression(penv, exp); + ASTTuple* prot = new ASTTuple(); + AType* bodyT = body->type(cenv); + + if (!bodyT) throw TypeError("REPL call to untyped body"); + if (bodyT->var) throw TypeError("REPL call to variable typed body"); + + body->lift(cenv); + + if (bodyT->ctype) { + // Create anonymous function to insert code into. + Function* f = compileFunction(cenv, cenv.gensym("_repl"), *prot, bodyT->ctype); + BasicBlock* bb = BasicBlock::Create("entry", f); + cenv.builder.SetInsertPoint(bb); + + try { + Value* retVal = body->compile(cenv); + cenv.builder.CreateRet(retVal); // Finish function + verifyFunction(*f); // Validate generated code + cenv.fpm.run(*f); // Optimize function + } catch (SyntaxError e) { + f->eraseFromParent(); // Error reading body, remove function + throw e; + } + + void* fp = engine->getPointerToFunction(f); + double (*cfunc)() = (double (*)())fp; + std::cout << cfunc(); + + } else { + Value* val = body->compile(cenv); + std::cout << val; + } + std::cout << " :: " << body->type(cenv)->str(cenv) << endl; + } catch (SyntaxError e) { std::cerr << "Syntax error: " << e.what() << endl; } catch (TypeError e) { std::cerr << "Type error: " << e.what() << endl; + } catch (CompileError e) { + std::cerr << "Compile error: " << e.what() << endl; } } |