/* A Trivial LLVM LISP * Copyright (C) 2008-2009 David Robillard * * Parts from the Kaleidoscope tutorial * by Chris Lattner and Erick Tryzelaar * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with This program. If not, see . */ #include #include #include #include #include #include #include #include #include "llvm/Analysis/Verifier.h" #include "llvm/DerivedTypes.h" #include "llvm/ExecutionEngine/ExecutionEngine.h" #include "llvm/Instructions.h" #include "llvm/Module.h" #include "llvm/ModuleProvider.h" #include "llvm/PassManager.h" #include "llvm/Support/IRBuilder.h" #include "llvm/Target/TargetData.h" #include "llvm/Transforms/Scalar.h" #define FOREACH(IT, i, c) for (IT i = (c).begin(); i != (c).end(); ++i) using namespace llvm; using namespace std; struct Error : public std::exception { Error(const char* m) : msg(m) {} const char* what() const throw() { return msg; } const char* msg; }; template struct Exp { // ::= Atom | (Exp*) Exp() : type(LIST) {} Exp(const A& a) : type(ATOM), atom(a) {} enum { ATOM, LIST } type; typedef std::vector< Exp > List; A atom; List list; }; /*************************************************************************** * S-Expression Lexer :: text -> S-Expressions (SExp) * ***************************************************************************/ struct SyntaxError : public Error { SyntaxError(const char* m) : Error(m) {} }; typedef Exp SExp; static SExp readExpression(std::istream& in) { #define PUSH(s, t) { if (t != "") { s.top().list.push_back(t); t = ""; } } #define YIELD(s, t) { if (s.empty()) return t; else PUSH(s, t) } stack stk; string tok; while (char ch = in.get()) { switch (ch) { case EOF: return SExp(); case ' ': case '\t': case '\n': if (tok != "") YIELD(stk, tok); break; case '"': do { tok.push_back(ch); } while ((ch = in.get()) != '"'); YIELD(stk, tok + '"'); break; case '(': stk.push(SExp()); break; case ')': switch (stk.size()) { case 0: throw SyntaxError("Unexpected ')'"); case 1: PUSH(stk, tok); return stk.top(); default: PUSH(stk, tok); SExp l = stk.top(); stk.pop(); stk.top().list.push_back(l); } break; default: tok += ch; } } switch (stk.size()) { case 0: return tok; case 1: return stk.top(); default: throw SyntaxError("Missing ')'"); } return SExp(); } /*************************************************************************** * Abstract Syntax Tree * ***************************************************************************/ struct TEnv; ///< Type-Time Environment struct CEnv; ///< Compile-Time Environment struct AType; ///< Abstract Type /// Base class for all AST nodes struct AST { virtual ~AST() {} virtual bool contains(AST* child) const { return false; } virtual bool operator!=(const AST& o) const { return !operator==(o); } virtual bool operator==(const AST& o) const = 0; virtual string str() const = 0; virtual void constrain(TEnv& tenv) const {} virtual void lift(CEnv& cenv) {} virtual Value* compile(CEnv& cenv) = 0; }; /// Literal template struct ASTLiteral : public AST { ASTLiteral(VT v) : val(v) {} bool operator==(const AST& rhs) const { const ASTLiteral* r = dynamic_cast*>(&rhs); return r && val == r->val; } string str() const { ostringstream s; s << val; return s.str(); } void constrain(TEnv& tenv) const; Value* compile(CEnv& cenv); const VT val; }; /// Symbol, e.g. "a" struct ASTSymbol : public AST { ASTSymbol(const string& s) : cppstr(s) {} bool operator==(const AST& rhs) const { return this == &rhs; } string str() const { return cppstr; } Value* compile(CEnv& cenv); private: const string cppstr; }; typedef vector TupV; /// Tuple (heterogeneous sequence of known length), e.g. "(a b c)" struct ASTTuple : public AST { ASTTuple(const TupV& t=TupV()) : tup(t) {} string str() const { string ret = "("; for (size_t i = 0; i != tup.size(); ++i) ret += tup[i]->str() + ((i != tup.size() - 1) ? " " : ""); return ret + ")"; } bool operator==(const AST& rhs) const { const ASTTuple* rt = dynamic_cast(&rhs); if (!rt) return false; if (rt->tup.size() != tup.size()) return false; TupV::const_iterator l = tup.begin(); FOREACH(TupV::const_iterator, r, rt->tup) { AST* mine = *l++; AST* other = *r; if (!(*mine == *other)) return false; } return true; } void lift(CEnv& cenv) { FOREACH(TupV::iterator, t, tup) (*t)->lift(cenv); } bool isForm(const string& f) { return !tup.empty() && tup[0]->str() == f; } bool contains(AST* child) const; void constrain(TEnv& tenv) const; Value* compile(CEnv& cenv) { return NULL; } TupV tup; }; static TupV tuple(AST* ast, ...) { TupV tup(1, ast); va_list args; va_start(args, ast); for (AST* a = va_arg(args, AST*); a; a = va_arg(args, AST*)) tup.push_back(a); va_end(args); return tup; } /// Type Expression ::= (TName TExpr*) | ?Num struct AType : public ASTTuple { AType(const TupV& t) : ASTTuple(t), var(false), ctype(0) {} AType(unsigned i) : var(true), ctype(0), id(i) {} AType(ASTSymbol* n, const Type* t) : var(false), ctype(t) { tup.push_back(n); } string str() const { if (var) { ostringstream s; s << "?" << id; return s.str(); } else { return ASTTuple::str(); } } void constrain(TEnv& tenv) const {} Value* compile(CEnv& cenv) { return NULL; } bool concrete() const { if (var) return false; FOREACH(TupV::const_iterator, t, tup) { AType* kid = dynamic_cast(*t); if (kid && !kid->concrete()) return false; } return true; } bool operator==(const AST& rhs) const { const AType* rt = dynamic_cast(&rhs); if (!rt) return false; else if (var && rt->var) return id == rt->id; else if (!var && !rt->var) return ASTTuple::operator==(rhs); return false; } bool var; const Type* ctype; unsigned id; }; /// Closure (first-class function with captured lexical bindings) struct ASTClosure : public ASTTuple { ASTClosure(ASTTuple* p, AST* b) : ASTTuple(tuple(0, p, b)), prot(p), func(0) {} bool operator==(const AST& rhs) const { return this == &rhs; } string str() const { ostringstream s; s << this; return s.str(); } void constrain(TEnv& tenv) const; void lift(CEnv& cenv); Value* compile(CEnv& cenv); ASTTuple* const prot; private: Function* func; }; /// Function call/application, e.g. "(func arg1 arg2)" struct ASTCall : public ASTTuple { ASTCall(const TupV& t) : ASTTuple(t) {} void constrain(TEnv& tenv) const; void lift(CEnv& cenv); Value* compile(CEnv& cenv); }; /// Definition special form, e.g. "(def x 2)" or "(def (next y) (+ y 1))" struct ASTDefinition : public ASTCall { ASTDefinition(const TupV& t) : ASTCall(t) {} void constrain(TEnv& tenv) const; Value* compile(CEnv& cenv); }; /// Conditional special form, e.g. "(if cond thenexp elseexp)" struct ASTIf : public ASTCall { ASTIf(const TupV& t) : ASTCall(t) {} void constrain(TEnv& tenv) const; Value* compile(CEnv& cenv); }; /// Primitive (builtin arithmetic function), e.g. "(+ 2 3)" struct ASTPrimitive : public ASTCall { ASTPrimitive(const TupV& t, int o, int a=0) : ASTCall(t), op(o), arg(a) {} void constrain(TEnv& tenv) const; Value* compile(CEnv& cenv); unsigned op; unsigned arg; }; /*************************************************************************** * Parser - S-Expressions (SExp) -> AST Nodes (AST) * ***************************************************************************/ /// LLVM Operation struct Op { Op(int o=0, int a=0) : op(o), arg(a) {} int op; int arg; }; typedef Op UD; // User Data argument for parse functions // Parse Time Environment (symbol table) struct PEnv : private map { typedef AST* (*PF)(PEnv&, const SExp::List&, UD); // Parse Function struct Parser { Parser(PF f, UD d) : pf(f), ud(d) {} PF pf; UD ud; }; map parsers; void reg(const string& s, const Parser& p) { parsers.insert(make_pair(sym(s)->str(), p)); } const Parser* parser(const string& s) const { map::const_iterator i = parsers.find(s); return (i != parsers.end()) ? &i->second : NULL; } ASTSymbol* sym(const string& s) { const const_iterator i = find(s); return ((i != end()) ? i->second : insert(make_pair(s, new ASTSymbol(s))).first->second); } }; /// The fundamental parser method static AST* parseExpression(PEnv& penv, const SExp& exp); static TupV pmap(PEnv& penv, const SExp::List& l) { TupV ret(l.size()); size_t n = 0; FOREACH(SExp::List::const_iterator, i, l) ret[n++] = parseExpression(penv, *i); return ret; } static AST* parseExpression(PEnv& penv, const SExp& exp) { if (exp.type == SExp::LIST) { if (exp.list.empty()) throw SyntaxError("Call to empty list"); if (exp.list.front().type == SExp::ATOM) { const PEnv::Parser* handler = penv.parser(exp.list.front().atom); if (handler) // Dispatch to parse function return handler->pf(penv, exp.list, handler->ud); } return new ASTCall(pmap(penv, exp.list)); // Parse as regular call } else if (isdigit(exp.atom[0])) { if (exp.atom.find('.') == string::npos) return new ASTLiteral(strtol(exp.atom.c_str(), NULL, 10)); else return new ASTLiteral(strtod(exp.atom.c_str(), NULL)); } return penv.sym(exp.atom); } // Special forms static AST* parseIf(PEnv& penv, const SExp::List& c, UD) { return new ASTIf(pmap(penv, c)); } static AST* parseDef(PEnv& penv, const SExp::List& c, UD) { return new ASTDefinition(pmap(penv, c)); } static AST* parsePrim(PEnv& penv, const SExp::List& c, UD data) { return new ASTPrimitive(pmap(penv, c), data.op, data.arg); } static AST* parseFn(PEnv& penv, const SExp::List& c, UD) { SExp::List::const_iterator a = c.begin(); ++a; return new ASTClosure( new ASTTuple(pmap(penv, (*a++).list)), parseExpression(penv, *a++)); } /*************************************************************************** * Lexical Environment * ***************************************************************************/ template struct Env : public list< map > { typedef map Frame; Env() : list(1) {} void push_front() { list::push_front(Frame()); } void def(const K& k, const V& v) { if (this->front().find(k) != this->front().end()) throw SyntaxError("Redefinition"); this->front()[k] = v; } V* ref(const K& name) { typename Frame::iterator s; for (typename Env::iterator i = this->begin(); i != this->end(); ++i) if ((s = i ->find(name)) != i->end()) return &s->second; return 0; } }; /*************************************************************************** * Typing * ***************************************************************************/ struct TypeError : public Error { TypeError (const char* m) : Error(m) {} }; struct TSubst : public map { TSubst(AType* s=0, AType* t=0) { if (s && t) insert(make_pair(s, t)); } }; /// Type-Time Environment struct TEnv { TEnv(PEnv& p) : penv(p), varID(1) {} typedef map Types; typedef list< pair > Constraints; AType* var() { return new AType(varID++); } AType* type(const AST* ast) { Types::iterator t = types.find(ast); return (t != types.end()) ? t->second : (types[ast] = var()); } AType* named(const string& name) const { Types::const_iterator i = namedTypes.find(penv.sym(name)); if (i == namedTypes.end()) throw TypeError("Unknown named type"); return i->second; } void name(const string& name, const Type* type) { ASTSymbol* sym = penv.sym(name); namedTypes[sym] = new AType(penv.sym(name), type); } void constrain(const AST* o, AType* t) { constraints.push_back(make_pair(type(o), t)); } void solve() { apply(unify(constraints)); } void apply(const TSubst& substs); static TSubst unify(const Constraints& c); PEnv& penv; Types types; Types namedTypes; Constraints constraints; unsigned varID; }; #define OP_IS_A(o, t) ((o) >= t ## Begin && (o) < t ## End) void ASTTuple::constrain(TEnv& tenv) const { TupV texp; FOREACH(TupV::const_iterator, p, tup) { (*p)->constrain(tenv); texp.push_back(tenv.type(*p)); } AType* t = tenv.type(this); t->var = false; t->tup = texp; } void ASTClosure::constrain(TEnv& tenv) const { prot->constrain(tenv); tup[2]->constrain(tenv); AType* bodyT = tenv.type(tup[2]); tenv.constrain(this, new AType(tuple( tenv.penv.sym("Fn"), tenv.type(prot), bodyT, 0))); } void ASTCall::constrain(TEnv& tenv) const { FOREACH(TupV::const_iterator, p, tup) (*p)->constrain(tenv); AType* retT = tenv.type(this); TupV texp = tuple(tenv.penv.sym("Fn"), tenv.var(), retT, NULL); tenv.constrain(tup[0], new AType(texp)); } void ASTDefinition::constrain(TEnv& tenv) const { FOREACH(TupV::const_iterator, p, tup) (*p)->constrain(tenv); AType* tvar = tenv.type(this); tenv.constrain(tup[1], tvar); tenv.constrain(tup[2], tvar); } void ASTIf::constrain(TEnv& tenv) const { FOREACH(TupV::const_iterator, p, tup) (*p)->constrain(tenv); AType* tvar = tenv.type(this); tenv.constrain(tup[1], tenv.named("Bool")); tenv.constrain(tup[2], tvar); tenv.constrain(tup[3], tvar); } void ASTPrimitive::constrain(TEnv& tenv) const { FOREACH(TupV::const_iterator, p, tup) (*p)->constrain(tenv); if (OP_IS_A(op, Instruction::BinaryOps)) { if (tup.size() <= 1) throw SyntaxError("Primitive call with 0 args"); AType* tvar = tenv.type(this); for (size_t i = 1; i < tup.size(); ++i) tenv.constrain(tup[i], tvar); } else if (op == Instruction::ICmp) { if (tup.size() != 3) throw SyntaxError("Comparison call with != 2 args"); tenv.constrain(tup[1], tenv.type(tup[2])); tenv.constrain(this, tenv.named("Bool")); } else { throw TypeError("Unknown primitive"); } } static void substitute(ASTTuple* tup, AST* from, AST* to) { if (!tup) return; for (size_t i = 0; i < tup->tup.size(); ++i) if (*tup->tup[i] == *from) tup->tup[i] = to; else substitute(dynamic_cast(tup->tup[i]), from, to); } bool ASTTuple::contains(AST* child) const { if (*this == *child) return true; FOREACH(TupV::const_iterator, p, tup) if (**p == *child || (*p)->contains(child)) return true; return false; } TSubst compose(const TSubst& delta, const TSubst& gamma) // TAPL 22.1.1 { TSubst r; for (TSubst::const_iterator g = gamma.begin(); g != gamma.end(); ++g) { TSubst::const_iterator i = delta.find(g->second); if (i != delta.end()) r.insert(make_pair(g->first, ((i != delta.end()) ? i : g)->second)); else r.insert(make_pair(g->first, g->second)); } for (TSubst::const_iterator d = delta.begin(); d != delta.end(); ++d) { if (gamma.find(d->first) == gamma.end()) r.insert(make_pair(d->first, d->second)); } return r; } void substConstraints(TEnv::Constraints& constraints, AType* s, AType* t) { for (TEnv::Constraints::iterator c = constraints.begin(); c != constraints.end();) { TEnv::Constraints::iterator next = c; ++next; if (*c->first == *s) c->first = t; if (*c->second == *s) c->second = t; substitute(c->first, s, t); substitute(c->second, s, t); c = next; } } TSubst TEnv::unify(const Constraints& constraints) // TAPL 22.4 { if (constraints.empty()) return TSubst(); AType* s = constraints.begin()->first; AType* t = constraints.begin()->second; Constraints cp = constraints; cp.erase(cp.begin()); if (*s == *t) { return unify(cp); } else if (s->var && !t->contains(s)) { substConstraints(cp, s, t); return compose(unify(cp), TSubst(s, t)); } else if (t->var && !s->contains(t)) { substConstraints(cp, t, s); return compose(unify(cp), TSubst(t, s)); } else if (s->isForm("Fn") && t->isForm("Fn")) { AType* s1 = dynamic_cast(s->tup[1]); AType* t1 = dynamic_cast(t->tup[1]); AType* s2 = dynamic_cast(s->tup[2]); AType* t2 = dynamic_cast(t->tup[2]); assert(s1 && t1 && s2 && t2); cp.push_back(make_pair(s1, t1)); cp.push_back(make_pair(s2, t2)); return unify(cp); } else { throw TypeError("Type unification failed"); } } void TEnv::apply(const TSubst& substs) { FOREACH(TSubst::const_iterator, s, substs) FOREACH(Types::iterator, t, types) if (*t->second == *s->first) t->second = s->second; } /*************************************************************************** * Code Generation * ***************************************************************************/ struct CompileError : public Error { CompileError(const char* m) : Error(m) {} }; class PEnv; /// Compile-Time Environment struct CEnv { CEnv(PEnv& p, Module* m, const TargetData* target) : penv(p), tenv(p), module(m), emp(module), fpm(&emp), symID(0) { // Set up the optimizer pipeline: fpm.add(new TargetData(*target)); // Register target arch fpm.add(createInstructionCombiningPass()); // Simple optimizations fpm.add(createReassociatePass()); // Reassociate expressions fpm.add(createGVNPass()); // Eleminate Common Subexpressions fpm.add(createCFGSimplificationPass()); // Simplify control flow } string gensym(const char* base="_") { ostringstream s; s << base << symID++; return s.str(); } typedef Env Code; typedef Env Vals; PEnv& penv; TEnv tenv; IRBuilder<> builder; Module* module; ExistingModuleProvider emp; FunctionPassManager fpm; unsigned symID; Code code; Vals vals; }; #define LITERAL(CT, NAME, COMPILED) \ template<> Value* \ ASTLiteral::compile(CEnv& cenv) { return (COMPILED); } \ template<> void \ ASTLiteral::constrain(TEnv& tenv) const { tenv.constrain(this, tenv.named(NAME)); } /// Literal template instantiations LITERAL(int32_t, "Int", ConstantInt::get(Type::Int32Ty, val, true)); LITERAL(float, "Float", ConstantFP::get(Type::FloatTy, val)); LITERAL(bool, "Bool", ConstantInt::get(Type::Int1Ty, val, false)); static Function* compileFunction(CEnv& cenv, const std::string& name, ASTTuple& prot, const Type* retT) { Function::LinkageTypes linkage = Function::ExternalLinkage; const TupV& texp = cenv.tenv.type(&prot)->tup; vector cprot; for (size_t i = 0; i < texp.size(); ++i) { const Type* t = cenv.tenv.type(texp[i])->ctype; if (!t) throw CompileError("Function prototype contains NULL"); cprot.push_back(t); } if (!retT) throw CompileError("Function return value type is NULL"); FunctionType* fT = FunctionType::get(retT, cprot, false); Function* f = Function::Create(fT, linkage, name, cenv.module); if (f->getName() != name) { f->eraseFromParent(); throw CompileError("Function redefined"); } // Set argument names in generated code Function::arg_iterator a = f->arg_begin(); for (size_t i = 0; i != prot.tup.size(); ++a, ++i) a->setName(prot.tup[i]->str()); return f; } Value* ASTSymbol::compile(CEnv& cenv) { Value** v = cenv.vals.ref(this); if (v) return *v; AST** c = cenv.code.ref(this); if (c) { Value* v = (*c)->compile(cenv); cenv.vals.def(this, v); return v; } throw SyntaxError((string("Undefined symbol '") + cppstr + "'").c_str()); } Value* ASTDefinition::compile(CEnv& cenv) { if (tup.size() != 3) throw SyntaxError("\"def\" takes exactly 2 arguments"); const ASTSymbol* sym = dynamic_cast(tup[1]); if (!sym) throw SyntaxError("Definition name is not a symbol"); Value* val = tup[2]->compile(cenv); cenv.code.def(sym, tup[2]); cenv.vals.def(sym, val); return val; } void ASTCall::lift(CEnv& cenv) { ASTClosure* c = dynamic_cast(tup[0]); if (!c) { AST** val = cenv.code.ref(tup[0]); c = (val) ? dynamic_cast(*val) : c; } // Lift arguments for (size_t i = 1; i < tup.size(); ++i) tup[i]->lift(cenv); if (!c) return; // Extend environment with bound and typed parameters cenv.code.push_front(); if (c->prot->tup.size() != tup.size() - 1) throw CompileError("Call to closure with mismatched arguments"); for (size_t i = 1; i < tup.size(); ++i) cenv.code.def(c->prot->tup[i-1], tup[i]); tup[0]->lift(cenv); // Lift called closure cenv.code.pop_front(); // Restore environment } Value* ASTCall::compile(CEnv& cenv) { ASTClosure* c = dynamic_cast(tup[0]); if (!c) { AST** val = cenv.code.ref(tup[0]); c = (val) ? dynamic_cast(*val) : c; } if (!c) throw CompileError("Call to non-closure"); Value* v = c->compile(cenv); if (!v) throw CompileError("Callee failed to compile"); Function* f = dynamic_cast(c->compile(cenv)); if (!f) throw CompileError("Callee compiled to non-function"); vector params; for (size_t i = 1; i < tup.size(); ++i) params.push_back(tup[i]->compile(cenv)); return cenv.builder.CreateCall(f, params.begin(), params.end(), "calltmp"); } Value* ASTIf::compile(CEnv& cenv) { Value* condV = tup[1]->compile(cenv); Function* parent = cenv.builder.GetInsertBlock()->getParent(); // Create blocks for the then and else cases. // Insert the 'then' block at the end of the function. BasicBlock* thenBB = BasicBlock::Create("then", parent); BasicBlock* elseBB = BasicBlock::Create("else"); BasicBlock* mergeBB = BasicBlock::Create("ifcont"); cenv.builder.CreateCondBr(condV, thenBB, elseBB); // Emit then block cenv.builder.SetInsertPoint(thenBB); Value* thenV = tup[2]->compile(cenv); // Can change current block, so... cenv.builder.CreateBr(mergeBB); thenBB = cenv.builder.GetInsertBlock(); // ... update thenBB afterwards // Emit else block parent->getBasicBlockList().push_back(elseBB); cenv.builder.SetInsertPoint(elseBB); Value* elseV = tup[3]->compile(cenv); // Can change current block, so... cenv.builder.CreateBr(mergeBB); elseBB = cenv.builder.GetInsertBlock(); // ... update elseBB afterwards // Emit merge block parent->getBasicBlockList().push_back(mergeBB); cenv.builder.SetInsertPoint(mergeBB); PHINode* pn = cenv.builder.CreatePHI(cenv.tenv.type(this)->ctype, "iftmp"); pn->addIncoming(thenV, thenBB); pn->addIncoming(elseV, elseBB); return pn; } void ASTClosure::lift(CEnv& cenv) { // Can't lift a closure with variable types (lift later when called) if (cenv.tenv.type(tup[2])->var) return; for (size_t i = 0; i < prot->tup.size(); ++i) if (cenv.tenv.type(prot->tup[i])->var) return; assert(!func); cenv.code.push_front(); // Write function declaration Function* f = compileFunction(cenv, cenv.gensym("_fn"), *prot, cenv.tenv.type(tup[2])->ctype); BasicBlock* bb = BasicBlock::Create("entry", f); cenv.builder.SetInsertPoint(bb); // Bind argument values in CEnv vector args; TupV::const_iterator p = prot->tup.begin(); for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p) cenv.vals.def(dynamic_cast(*p), &*a); // Write function body try { Value* retVal = tup[2]->compile(cenv); cenv.builder.CreateRet(retVal); // Finish function verifyFunction(*f); // Validate generated code cenv.fpm.run(*f); // Optimize function func = f; } catch (exception e) { f->eraseFromParent(); // Error reading body, remove function throw e; } cenv.code.pop_front(); } Value* ASTClosure::compile(CEnv& cenv) { return func; // Function was already compiled in the lifting pass } Value* ASTPrimitive::compile(CEnv& cenv) { if (tup.size() < 3) throw SyntaxError("Too few arguments"); Value* a = tup[1]->compile(cenv); Value* b = tup[2]->compile(cenv); if (OP_IS_A(op, Instruction::BinaryOps)) { const Instruction::BinaryOps bo = (Instruction::BinaryOps)op; if (tup.size() == 2) return tup[1]->compile(cenv); Value* val = cenv.builder.CreateBinOp(bo, a, b); for (size_t i = 3; i < tup.size(); ++i) val = cenv.builder.CreateBinOp(bo, val, tup[i]->compile(cenv)); return val; } else if (op == Instruction::ICmp) { bool isInt = cenv.tenv.type(tup[1])->str() == "(Int)"; if (isInt) { return cenv.builder.CreateICmp((CmpInst::Predicate)arg, a, b); } else { // Translate to floating point operation switch (arg) { case CmpInst::ICMP_EQ: arg = CmpInst::FCMP_OEQ; break; case CmpInst::ICMP_NE: arg = CmpInst::FCMP_ONE; break; case CmpInst::ICMP_SGT: arg = CmpInst::FCMP_OGT; break; case CmpInst::ICMP_SGE: arg = CmpInst::FCMP_OGE; break; case CmpInst::ICMP_SLT: arg = CmpInst::FCMP_OLT; break; case CmpInst::ICMP_SLE: arg = CmpInst::FCMP_OLE; break; default: throw CompileError("Unknown primitive"); } return cenv.builder.CreateFCmp((CmpInst::Predicate)arg, a, b); } } throw CompileError("Unknown primitive"); } /*************************************************************************** * REPL * ***************************************************************************/ int main() { #define PRIM(O, A) PEnv::Parser(parsePrim, Op(Instruction:: O, A)) PEnv penv; penv.reg("fn", PEnv::Parser(parseFn, Op())); penv.reg("if", PEnv::Parser(parseIf, Op())); penv.reg("def", PEnv::Parser(parseDef, Op())); penv.reg("+", PRIM(Add, 0)); penv.reg("-", PRIM(Sub, 0)); penv.reg("*", PRIM(Mul, 0)); penv.reg("/", PRIM(FDiv, 0)); penv.reg("%", PRIM(FRem, 0)); penv.reg("&", PRIM(And, 0)); penv.reg("|", PRIM(Or, 0)); penv.reg("^", PRIM(Xor, 0)); penv.reg("=", PRIM(ICmp, CmpInst::ICMP_EQ)); penv.reg("!=", PRIM(ICmp, CmpInst::ICMP_NE)); penv.reg(">", PRIM(ICmp, CmpInst::ICMP_SGT)); penv.reg(">=", PRIM(ICmp, CmpInst::ICMP_SGE)); penv.reg("<", PRIM(ICmp, CmpInst::ICMP_SLT)); penv.reg("<=", PRIM(ICmp, CmpInst::ICMP_SLE)); Module* module = new Module("repl"); ExecutionEngine* engine = ExecutionEngine::create(module); CEnv cenv(penv, module, engine->getTargetData()); cenv.tenv.name("Bool", Type::Int1Ty); cenv.tenv.name("Int", Type::Int32Ty); cenv.tenv.name("Float", Type::FloatTy); cenv.code.def(penv.sym("true"), new ASTLiteral(true)); cenv.code.def(penv.sym("false"), new ASTLiteral(false)); while (1) { std::cout << "() "; std::cout.flush(); SExp exp = readExpression(std::cin); if (exp.type == SExp::LIST && exp.list.empty()) break; try { AST* body = parseExpression(penv, exp); // Parse input body->constrain(cenv.tenv); // Constrain types cenv.tenv.solve(); // Solve and apply type constraints AType* bodyT = cenv.tenv.type(body); if (!bodyT) throw TypeError("REPL call to untyped body"); if (bodyT->var) throw TypeError("REPL call to variable typed body"); body->lift(cenv); if (bodyT->ctype) { // Create anonymous function to insert code into. ASTTuple* prot = new ASTTuple(); Function* f = compileFunction(cenv, cenv.gensym("_repl"), *prot, bodyT->ctype); BasicBlock* bb = BasicBlock::Create("entry", f); cenv.builder.SetInsertPoint(bb); try { Value* retVal = body->compile(cenv); cenv.builder.CreateRet(retVal); // Finish function verifyFunction(*f); // Validate generated code cenv.fpm.run(*f); // Optimize function } catch (SyntaxError e) { f->eraseFromParent(); // Error reading body, remove function throw e; } void* fp = engine->getPointerToFunction(f); if (bodyT->ctype == Type::Int32Ty) std::cout << "; " << ((int32_t (*)())fp)(); else if (bodyT->ctype == Type::FloatTy) std::cout << "; " << ((float (*)())fp)(); else if (bodyT->ctype == Type::Int1Ty) std::cout << "; " << ((bool (*)())fp)(); } else { Value* val = body->compile(cenv); std::cout << "; " << val; } std::cout << " : " << cenv.tenv.type(body)->str() << endl; } catch (Error e) { std::cerr << "Error: " << e.what() << endl; } } std::cout << endl << "Generated code:" << endl; module->dump(); return 0; }