/* Tuplr: A minimalist programming language * Copyright (C) 2008-2009 David Robillard * * Tuplr is free software: you can redistribute it and/or modify it under * the terms of the GNU Affero General Public License as published by the * Free Software Foundation, either version 3 of the License, or (at your * option) any later version. * * Tuplr is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General * Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with Tuplr. If not, see . */ #include #include #include #include #include #include #include #include #include "llvm/Analysis/Verifier.h" #include "llvm/DerivedTypes.h" #include "llvm/ExecutionEngine/ExecutionEngine.h" #include "llvm/Instructions.h" #include "llvm/Module.h" #include "llvm/ModuleProvider.h" #include "llvm/PassManager.h" #include "llvm/Support/IRBuilder.h" #include "llvm/Target/TargetData.h" #include "llvm/Transforms/Scalar.h" #define FOREACH(IT, i, c) for (IT i = (c).begin(); i != (c).end(); ++i) using namespace llvm; using namespace std; struct Error : public std::exception { Error(const char* m) : msg(m) {} const char* what() const throw() { return msg; } const char* msg; }; template struct Exp { // ::= Atom | (Exp*) Exp() : type(LIST) {} Exp(const A& a) : type(ATOM), atom(a) {} enum { ATOM, LIST } type; typedef std::vector< Exp > List; A atom; List list; }; /*************************************************************************** * S-Expression Lexer :: text -> S-Expressions (SExp) * ***************************************************************************/ struct SyntaxError : public Error { SyntaxError(const char* m) : Error(m) {} }; typedef Exp SExp; static SExp readExpression(std::istream& in) { #define PUSH(s, t) { if (t != "") { s.top().list.push_back(t); t = ""; } } #define YIELD(s, t) { if (s.empty()) return t; else PUSH(s, t) } stack stk; string tok; while (char ch = in.get()) { switch (ch) { case EOF: return SExp(); case ' ': case '\t': case '\n': if (tok != "") YIELD(stk, tok); break; case '"': do { tok.push_back(ch); } while ((ch = in.get()) != '"'); YIELD(stk, tok + '"'); break; case '(': stk.push(SExp()); break; case ')': switch (stk.size()) { case 0: throw SyntaxError("Unexpected ')'"); case 1: PUSH(stk, tok); return stk.top(); default: PUSH(stk, tok); SExp l = stk.top(); stk.pop(); stk.top().list.push_back(l); } break; default: tok += ch; } } switch (stk.size()) { case 0: return tok; case 1: return stk.top(); default: throw SyntaxError("Missing ')'"); } return SExp(); } /*************************************************************************** * Abstract Syntax Tree * ***************************************************************************/ struct TEnv; ///< Type-Time Environment struct CEnv; ///< Compile-Time Environment /// Base class for all AST nodes struct AST { virtual ~AST() {} virtual bool contains(AST* child) const { return false; } virtual bool operator!=(const AST& o) const { return !operator==(o); } virtual bool operator==(const AST& o) const = 0; virtual string str() const = 0; virtual void constrain(TEnv& tenv) const {} virtual void lift(CEnv& cenv) {} virtual Value* compile(CEnv& cenv) = 0; }; /// Literal template struct ASTLiteral : public AST { ASTLiteral(VT v) : val(v) {} bool operator==(const AST& rhs) const { const ASTLiteral* r = dynamic_cast*>(&rhs); return r && val == r->val; } string str() const { ostringstream s; s << val; return s.str(); } void constrain(TEnv& tenv) const; Value* compile(CEnv& cenv); const VT val; }; /// Symbol, e.g. "a" struct ASTSymbol : public AST { ASTSymbol(const string& s) : cppstr(s) {} bool operator==(const AST& rhs) const { return this == &rhs; } string str() const { return cppstr; } Value* compile(CEnv& cenv); private: const string cppstr; }; /// Tuple (heterogeneous sequence of fixed length), e.g. "(a b c)" struct ASTTuple : public AST, public vector { ASTTuple(const vector& t=vector()) : vector(t) {} ASTTuple(size_t size) : vector(size) {} ASTTuple(AST* ast, ...) { push_back(ast); va_list args; va_start(args, ast); for (AST* a = va_arg(args, AST*); a; a = va_arg(args, AST*)) push_back(a); va_end(args); } string str() const { string ret = "("; for (size_t i = 0; i != size(); ++i) ret += at(i)->str() + ((i != size() - 1) ? " " : ""); return ret + ")"; } bool operator==(const AST& rhs) const { const ASTTuple* rt = dynamic_cast(&rhs); if (!rt) return false; if (rt->size() != size()) return false; const_iterator l = begin(); FOREACH(const_iterator, r, *rt) { AST* mine = *l++; AST* other = *r; if (!(*mine == *other)) return false; } return true; } void lift(CEnv& cenv) { FOREACH(iterator, t, *this) (*t)->lift(cenv); } bool isForm(const string& f) { return !empty() && at(0)->str() == f; } bool contains(AST* child) const; void constrain(TEnv& tenv) const; Value* compile(CEnv& cenv) { return NULL; } }; /// Type Expression, e.g. "(Int)" or "(Fn ((Int)) (Float))" struct AType : public ASTTuple { AType(const ASTTuple& t) : ASTTuple(t), var(false), ctype(0) {} AType(unsigned i) : var(true), ctype(0), id(i) {} AType(ASTSymbol* n, const Type* t) : var(false), ctype(t) { push_back(n); } string str() const { if (var) { ostringstream s; s << "?" << id; return s.str(); } else { return ASTTuple::str(); } } void constrain(TEnv& tenv) const {} Value* compile(CEnv& cenv) { return NULL; } bool concrete() const { if (var) return false; FOREACH(const_iterator, t, *this) { AType* kid = dynamic_cast(*t); if (kid && !kid->concrete()) return false; } return true; } bool operator==(const AST& rhs) const { const AType* rt = dynamic_cast(&rhs); if (!rt) return false; else if (var && rt->var) return id == rt->id; else if (!var && !rt->var) return ASTTuple::operator==(rhs); return false; } bool var; const Type* ctype; unsigned id; }; /// Closure (first-class function with captured lexical bindings) struct ASTClosure : public ASTTuple { ASTClosure(ASTTuple* p, AST* b) : ASTTuple(0, p, b), prot(p), func(0) {} bool operator==(const AST& rhs) const { return this == &rhs; } string str() const { ostringstream s; s << this; return s.str(); } void constrain(TEnv& tenv) const; void lift(CEnv& cenv); Value* compile(CEnv& cenv); ASTTuple* const prot; private: Function* func; }; /// Function call/application, e.g. "(func arg1 arg2)" struct ASTCall : public ASTTuple { ASTCall(const ASTTuple& t) : ASTTuple(t) {} void constrain(TEnv& tenv) const; void lift(CEnv& cenv); Value* compile(CEnv& cenv); }; /// Definition special form, e.g. "(def x 2)" or "(def (next y) (+ y 1))" struct ASTDefinition : public ASTCall { ASTDefinition(const ASTTuple& t) : ASTCall(t) {} void constrain(TEnv& tenv) const; void lift(CEnv& cenv); Value* compile(CEnv& cenv); }; /// Conditional special form, e.g. "(if cond thenexp elseexp)" struct ASTIf : public ASTCall { ASTIf(const ASTTuple& t) : ASTCall(t) {} void constrain(TEnv& tenv) const; Value* compile(CEnv& cenv); }; /// Primitive (builtin arithmetic function), e.g. "(+ 2 3)" struct ASTPrimitive : public ASTCall { ASTPrimitive(const ASTTuple& t, int o, int a=0) : ASTCall(t), op(o), arg(a) {} void constrain(TEnv& tenv) const; Value* compile(CEnv& cenv); unsigned op; unsigned arg; }; /*************************************************************************** * Parser - S-Expressions (SExp) -> AST Nodes (AST) * ***************************************************************************/ /// LLVM Operation struct Op { Op(int o=0, int a=0) : op(o), arg(a) {} int op; int arg; }; typedef Op UD; // User Data argument for parse functions // Parse Time Environment (symbol table) struct PEnv : private map { typedef AST* (*PF)(PEnv&, const SExp::List&, UD); // Parse Function struct Parser { Parser(PF f, UD d) : pf(f), ud(d) {} PF pf; UD ud; }; map parsers; void reg(const string& s, const Parser& p) { parsers.insert(make_pair(sym(s)->str(), p)); } const Parser* parser(const string& s) const { map::const_iterator i = parsers.find(s); return (i != parsers.end()) ? &i->second : NULL; } ASTSymbol* sym(const string& s) { const const_iterator i = find(s); return ((i != end()) ? i->second : insert(make_pair(s, new ASTSymbol(s))).first->second); } }; /// The fundamental parser method static AST* parseExpression(PEnv& penv, const SExp& exp); static ASTTuple pmap(PEnv& penv, const SExp::List& l) { ASTTuple ret(l.size()); size_t n = 0; FOREACH(SExp::List::const_iterator, i, l) ret[n++] = parseExpression(penv, *i); return ret; } static AST* parseExpression(PEnv& penv, const SExp& exp) { if (exp.type == SExp::LIST) { if (exp.list.empty()) throw SyntaxError("Call to empty list"); if (exp.list.front().type == SExp::ATOM) { const PEnv::Parser* handler = penv.parser(exp.list.front().atom); if (handler) // Dispatch to parse function return handler->pf(penv, exp.list, handler->ud); } return new ASTCall(pmap(penv, exp.list)); // Parse as regular call } else if (isdigit(exp.atom[0])) { if (exp.atom.find('.') == string::npos) return new ASTLiteral(strtol(exp.atom.c_str(), NULL, 10)); else return new ASTLiteral(strtod(exp.atom.c_str(), NULL)); } return penv.sym(exp.atom); } // Special forms static AST* parseIf(PEnv& penv, const SExp::List& c, UD) { return new ASTIf(pmap(penv, c)); } static AST* parseDef(PEnv& penv, const SExp::List& c, UD) { return new ASTDefinition(pmap(penv, c)); } static AST* parsePrim(PEnv& penv, const SExp::List& c, UD data) { return new ASTPrimitive(pmap(penv, c), data.op, data.arg); } static AST* parseFn(PEnv& penv, const SExp::List& c, UD) { SExp::List::const_iterator a = c.begin(); ++a; return new ASTClosure( new ASTTuple(pmap(penv, (*a++).list)), parseExpression(penv, *a++)); } /*************************************************************************** * Generic Lexical Environment * ***************************************************************************/ template struct Env : public list< map > { typedef map Frame; Env() : list(1) {} void push_front() { list::push_front(Frame()); } const V& def(const K& k, const V& v) { typename Frame::iterator existing = this->front().find(k); if (existing != this->front().end() && existing->second != v) throw SyntaxError("Redefinition"); return (this->front()[k] = v); } V* ref(const K& name) { typename Frame::iterator s; for (typename Env::iterator i = this->begin(); i != this->end(); ++i) if ((s = i ->find(name)) != i->end()) return &s->second; return 0; } }; /*************************************************************************** * Typing * ***************************************************************************/ struct TypeError : public Error { TypeError (const char* m) : Error(m) {} }; struct TSubst : public map { TSubst(AType* s=0, AType* t=0) { if (s && t) insert(make_pair(s, t)); } }; /// Type-Time Environment struct TEnv { TEnv(PEnv& p) : penv(p), varID(1) {} typedef map Types; typedef list< pair > Constraints; AType* var() { return new AType(varID++); } AType* type(const AST* ast) { Types::iterator t = types.find(ast); return (t != types.end()) ? t->second : (types[ast] = var()); } AType* named(const string& name) const { Types::const_iterator i = namedTypes.find(penv.sym(name)); if (i == namedTypes.end()) throw TypeError("Unknown named type"); return i->second; } void name(const string& name, const Type* type) { ASTSymbol* sym = penv.sym(name); namedTypes[sym] = new AType(penv.sym(name), type); } void constrain(const AST* o, AType* t) { constraints.push_back(make_pair(type(o), t)); } void solve() { apply(unify(constraints)); } void apply(const TSubst& substs); static TSubst unify(const Constraints& c); PEnv& penv; Types types; Types namedTypes; Constraints constraints; unsigned varID; }; #define OP_IS_A(o, t) ((o) >= t ## Begin && (o) < t ## End) void ASTTuple::constrain(TEnv& tenv) const { AType* t = new AType(ASTTuple()); FOREACH(const_iterator, p, *this) { (*p)->constrain(tenv); t->push_back(tenv.type(*p)); } tenv.constrain(tenv.type(this), t); } void ASTClosure::constrain(TEnv& tenv) const { prot->constrain(tenv); at(2)->constrain(tenv); AType* bodyT = tenv.type(at(2)); tenv.constrain(this, new AType( ASTTuple(tenv.penv.sym("Fn"), tenv.type(prot), bodyT, 0))); } void ASTCall::constrain(TEnv& tenv) const { FOREACH(const_iterator, p, *this) (*p)->constrain(tenv); AType* retT = tenv.type(this); tenv.constrain(at(0), new AType(ASTTuple( tenv.penv.sym("Fn"), tenv.var(), retT, NULL))); } void ASTDefinition::constrain(TEnv& tenv) const { if (size() != 3) throw SyntaxError("\"def\" not passed 2 arguments"); if (!dynamic_cast(at(1))) throw SyntaxError("\"def\" name is not a symbol"); FOREACH(const_iterator, p, *this) (*p)->constrain(tenv); AType* tvar = tenv.type(this); tenv.constrain(at(1), tvar); tenv.constrain(at(2), tvar); } void ASTIf::constrain(TEnv& tenv) const { FOREACH(const_iterator, p, *this) (*p)->constrain(tenv); AType* tvar = tenv.type(this); tenv.constrain(at(1), tenv.named("Bool")); tenv.constrain(at(2), tvar); tenv.constrain(at(3), tvar); } void ASTPrimitive::constrain(TEnv& tenv) const { FOREACH(const_iterator, p, *this) (*p)->constrain(tenv); if (OP_IS_A(op, Instruction::BinaryOps)) { if (size() <= 1) throw SyntaxError("Primitive call with 0 args"); AType* tvar = tenv.type(this); for (size_t i = 1; i < size(); ++i) tenv.constrain(at(i), tvar); } else if (op == Instruction::ICmp) { if (size() != 3) throw SyntaxError("Comparison call with != 2 args"); tenv.constrain(at(1), tenv.type(at(2))); tenv.constrain(this, tenv.named("Bool")); } else { throw TypeError("Unknown primitive"); } } static void substitute(ASTTuple* tup, AST* from, AST* to) { if (!tup) return; for (size_t i = 0; i < tup->size(); ++i) if (*tup->at(i) == *from) tup->at(i) = to; else substitute(dynamic_cast(tup->at(i)), from, to); } bool ASTTuple::contains(AST* child) const { if (*this == *child) return true; FOREACH(const_iterator, p, *this) if (**p == *child || (*p)->contains(child)) return true; return false; } TSubst compose(const TSubst& delta, const TSubst& gamma) // TAPL 22.1.1 { TSubst r; for (TSubst::const_iterator g = gamma.begin(); g != gamma.end(); ++g) { TSubst::const_iterator d = delta.find(g->second); r.insert(make_pair(g->first, ((d != delta.end()) ? d : g)->second)); } for (TSubst::const_iterator d = delta.begin(); d != delta.end(); ++d) { if (gamma.find(d->first) == gamma.end()) r.insert(*d); } return r; } void substConstraints(TEnv::Constraints& constraints, AType* s, AType* t) { for (TEnv::Constraints::iterator c = constraints.begin(); c != constraints.end();) { TEnv::Constraints::iterator next = c; ++next; if (*c->first == *s) c->first = t; if (*c->second == *s) c->second = t; substitute(c->first, s, t); substitute(c->second, s, t); c = next; } } TSubst TEnv::unify(const Constraints& constraints) // TAPL 22.4 { if (constraints.empty()) return TSubst(); AType* s = constraints.begin()->first; AType* t = constraints.begin()->second; Constraints cp = constraints; cp.erase(cp.begin()); if (*s == *t) { return unify(cp); } else if (s->var && !t->contains(s)) { substConstraints(cp, s, t); return compose(unify(cp), TSubst(s, t)); } else if (t->var && !s->contains(t)) { substConstraints(cp, t, s); return compose(unify(cp), TSubst(t, s)); } else if (s->isForm("Fn") && t->isForm("Fn")) { AType* s1 = dynamic_cast(s->at(1)); AType* t1 = dynamic_cast(t->at(1)); AType* s2 = dynamic_cast(s->at(2)); AType* t2 = dynamic_cast(t->at(2)); assert(s1 && t1 && s2 && t2); cp.push_back(make_pair(s1, t1)); cp.push_back(make_pair(s2, t2)); return unify(cp); } else { throw TypeError("Type unification failed"); } } void TEnv::apply(const TSubst& substs) { FOREACH(TSubst::const_iterator, s, substs) FOREACH(Types::iterator, t, types) if (*t->second == *s->first) t->second = s->second; } /*************************************************************************** * Code Generation * ***************************************************************************/ struct CompileError : public Error { CompileError(const char* m) : Error(m) {} }; class PEnv; /// Compile-Time Environment struct CEnv { CEnv(PEnv& p, Module* m, const TargetData* target) : penv(p), tenv(p), module(m), emp(module), opt(&emp), symID(0) { // Set up the optimizer pipeline: opt.add(new TargetData(*target)); // Register target arch opt.add(createInstructionCombiningPass()); // Simple optimizations opt.add(createReassociatePass()); // Reassociate expressions opt.add(createGVNPass()); // Eliminate Common Subexpressions opt.add(createCFGSimplificationPass()); // Simplify control flow } string gensym(const char* base="_") { ostringstream s; s << base << symID++; return s.str(); } void push() { code.push_front(); vals.push_front(); } void pop() { code.pop_front(); vals.pop_front(); } Value* compile(AST* obj) { Value** v = vals.ref(obj); return (v) ? *v : vals.def(obj, obj->compile(*this)); } void precompile(AST* obj, Value* value) { assert(!vals.ref(obj)); vals.def(obj, value); } void optimise(Function& f) { verifyFunction(f); opt.run(f); } typedef Env Code; typedef Env Vals; PEnv& penv; TEnv tenv; IRBuilder<> builder; Module* module; ExistingModuleProvider emp; FunctionPassManager opt; unsigned symID; Code code; Vals vals; }; #define LITERAL(CT, NAME, COMPILED) \ template<> Value* \ ASTLiteral::compile(CEnv& cenv) { return (COMPILED); } \ template<> void \ ASTLiteral::constrain(TEnv& tenv) const { tenv.constrain(this, tenv.named(NAME)); } /// Literal template instantiations LITERAL(int32_t, "Int", ConstantInt::get(Type::Int32Ty, val, true)); LITERAL(float, "Float", ConstantFP::get(Type::FloatTy, val)); LITERAL(bool, "Bool", ConstantInt::get(Type::Int1Ty, val, false)); static Function* compileFunction(CEnv& cenv, const std::string& name, ASTTuple& prot, const Type* retT) { Function::LinkageTypes linkage = Function::ExternalLinkage; vector cprot; for (size_t i = 0; i < prot.size(); ++i) { const AType* at = cenv.tenv.type(prot.at(i)); if (!at->ctype || at->var) throw CompileError("Parameter is untyped"); cprot.push_back(at->ctype); } if (!retT) throw CompileError("Return is untyped"); FunctionType* fT = FunctionType::get(retT, cprot, false); Function* f = Function::Create(fT, linkage, name, cenv.module); if (f->getName() != name) { f->eraseFromParent(); throw CompileError("Function redefined"); } // Set argument names in generated code Function::arg_iterator a = f->arg_begin(); for (size_t i = 0; i != prot.size(); ++a, ++i) a->setName(prot.at(i)->str()); return f; } Value* ASTSymbol::compile(CEnv& cenv) { AST** c = cenv.code.ref(this); if (!c) throw SyntaxError((string("Undefined symbol: ") + cppstr).c_str()); return cenv.vals.def(this, cenv.compile(*c)); } void ASTClosure::lift(CEnv& cenv) { if (cenv.tenv.type(at(2))->var) throw CompileError("Closure with untyped body lifted"); for (size_t i = 0; i < prot->size(); ++i) if (cenv.tenv.type(prot->at(i))->var) throw CompileError("Closure with untyped parameter lifted"); assert(!func); cenv.push(); // Write function declaration Function* f = compileFunction(cenv, cenv.gensym("_fn"), *prot, cenv.tenv.type(at(2))->ctype); BasicBlock* bb = BasicBlock::Create("entry", f); cenv.builder.SetInsertPoint(bb); // Bind argument values in CEnv vector args; const_iterator p = prot->begin(); for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p) cenv.vals.def(dynamic_cast(*p), &*a); // Write function body try { cenv.precompile(this, f); // Define our value first for recursion Value* retVal = cenv.compile(at(2)); cenv.builder.CreateRet(retVal); // Finish function cenv.optimise(*f); func = f; } catch (exception e) { f->eraseFromParent(); // Error reading body, remove function throw e; } assert(func); cenv.pop(); } Value* ASTClosure::compile(CEnv& cenv) { assert(func); return func; // Function was already compiled in the lifting pass } void ASTCall::lift(CEnv& cenv) { ASTClosure* c = dynamic_cast(at(0)); if (!c) { AST** val = cenv.code.ref(at(0)); c = (val) ? dynamic_cast(*val) : c; } // Lift arguments for (size_t i = 1; i < size(); ++i) at(i)->lift(cenv); if (!c) return; // Extend environment with bound and typed parameters cenv.push(); if (c->prot->size() != size() - 1) throw CompileError("Call to closure with mismatched arguments"); for (size_t i = 1; i < size(); ++i) cenv.code.def(c->prot->at(i-1), at(i)); at(0)->lift(cenv); // Lift called closure cenv.pop(); // Restore environment } Value* ASTCall::compile(CEnv& cenv) { ASTClosure* c = dynamic_cast(at(0)); if (!c) { AST** val = cenv.code.ref(at(0)); c = (val) ? dynamic_cast(*val) : c; } assert(c); Function* f = dynamic_cast(cenv.compile(c)); if (!f) throw CompileError("Callee failed to compile"); vector params(size() - 1); for (size_t i = 1; i < size(); ++i) params[i-1] = cenv.compile(at(i)); return cenv.builder.CreateCall(f, params.begin(), params.end(), "calltmp"); } void ASTDefinition::lift(CEnv& cenv) { cenv.code.def((ASTSymbol*)at(1), at(2)); // Define first for recursion at(2)->lift(cenv); } Value* ASTDefinition::compile(CEnv& cenv) { return cenv.compile(at(2)); } Value* ASTIf::compile(CEnv& cenv) { typedef vector< pair > Branches; Function* parent = cenv.builder.GetInsertBlock()->getParent(); BasicBlock* mergeBB = BasicBlock::Create("endif"); BasicBlock* nextBB = NULL; Branches branches; ostringstream ss; for (size_t i = 1; i < size() - 1; i += 2) { Value* condV = cenv.compile(at(i)); ss.str(""); ss << "then" << ((i + 1) / 2); BasicBlock* thenBB = BasicBlock::Create(ss.str()); ss.str(""); ss << "else" << ((i + 1) / 2); nextBB = BasicBlock::Create(ss.str()); cenv.builder.CreateCondBr(condV, thenBB, nextBB); // Emit then block for this condition parent->getBasicBlockList().push_back(thenBB); cenv.builder.SetInsertPoint(thenBB); Value* thenV = cenv.compile(at(i + 1)); cenv.builder.CreateBr(mergeBB); branches.push_back(make_pair(thenV, cenv.builder.GetInsertBlock())); parent->getBasicBlockList().push_back(nextBB); cenv.builder.SetInsertPoint(nextBB); } // Emit else block cenv.builder.SetInsertPoint(nextBB); Value* elseV = cenv.compile(at(size() - 1)); cenv.builder.CreateBr(mergeBB); branches.push_back(make_pair(elseV, cenv.builder.GetInsertBlock())); // Emit merge block (Phi node) parent->getBasicBlockList().push_back(mergeBB); cenv.builder.SetInsertPoint(mergeBB); PHINode* pn = cenv.builder.CreatePHI(cenv.tenv.type(this)->ctype, "ifval"); for (Branches::iterator i = branches.begin(); i != branches.end(); ++i) pn->addIncoming(i->first, i->second); return pn; } Value* ASTPrimitive::compile(CEnv& cenv) { if (size() < 3) throw SyntaxError("Too few arguments"); Value* a = cenv.compile(at(1)); Value* b = cenv.compile(at(2)); if (OP_IS_A(op, Instruction::BinaryOps)) { const Instruction::BinaryOps bo = (Instruction::BinaryOps)op; if (size() == 2) return cenv.compile(at(1)); Value* val = cenv.builder.CreateBinOp(bo, a, b); for (size_t i = 3; i < size(); ++i) val = cenv.builder.CreateBinOp(bo, val, cenv.compile(at(i))); return val; } else if (op == Instruction::ICmp) { bool isInt = cenv.tenv.type(at(1))->str() == "(Int)"; if (isInt) { return cenv.builder.CreateICmp((CmpInst::Predicate)arg, a, b); } else { // Translate to floating point operation switch (arg) { case CmpInst::ICMP_EQ: arg = CmpInst::FCMP_OEQ; break; case CmpInst::ICMP_NE: arg = CmpInst::FCMP_ONE; break; case CmpInst::ICMP_SGT: arg = CmpInst::FCMP_OGT; break; case CmpInst::ICMP_SGE: arg = CmpInst::FCMP_OGE; break; case CmpInst::ICMP_SLT: arg = CmpInst::FCMP_OLT; break; case CmpInst::ICMP_SLE: arg = CmpInst::FCMP_OLE; break; default: throw CompileError("Unknown primitive"); } return cenv.builder.CreateFCmp((CmpInst::Predicate)arg, a, b); } } throw CompileError("Unknown primitive"); } /*************************************************************************** * REPL * ***************************************************************************/ int main() { #define PRIM(O, A) PEnv::Parser(parsePrim, Op(Instruction:: O, A)) PEnv penv; penv.reg("fn", PEnv::Parser(parseFn, Op())); penv.reg("if", PEnv::Parser(parseIf, Op())); penv.reg("def", PEnv::Parser(parseDef, Op())); penv.reg("+", PRIM(Add, 0)); penv.reg("-", PRIM(Sub, 0)); penv.reg("*", PRIM(Mul, 0)); penv.reg("/", PRIM(FDiv, 0)); penv.reg("%", PRIM(FRem, 0)); penv.reg("&", PRIM(And, 0)); penv.reg("|", PRIM(Or, 0)); penv.reg("^", PRIM(Xor, 0)); penv.reg("=", PRIM(ICmp, CmpInst::ICMP_EQ)); penv.reg("!=", PRIM(ICmp, CmpInst::ICMP_NE)); penv.reg(">", PRIM(ICmp, CmpInst::ICMP_SGT)); penv.reg(">=", PRIM(ICmp, CmpInst::ICMP_SGE)); penv.reg("<", PRIM(ICmp, CmpInst::ICMP_SLT)); penv.reg("<=", PRIM(ICmp, CmpInst::ICMP_SLE)); Module* module = new Module("repl"); ExecutionEngine* engine = ExecutionEngine::create(module); CEnv cenv(penv, module, engine->getTargetData()); cenv.tenv.name("Bool", Type::Int1Ty); cenv.tenv.name("Int", Type::Int32Ty); cenv.tenv.name("Float", Type::FloatTy); cenv.code.def(penv.sym("true"), new ASTLiteral(true)); cenv.code.def(penv.sym("false"), new ASTLiteral(false)); while (1) { std::cout << "() "; std::cout.flush(); SExp exp = readExpression(std::cin); if (exp.type == SExp::LIST && exp.list.empty()) break; try { AST* body = parseExpression(penv, exp); // Parse input body->constrain(cenv.tenv); // Constrain types cenv.tenv.solve(); // Solve and apply type constraints AType* bodyT = cenv.tenv.type(body); if (!bodyT) throw TypeError("REPL call to untyped body"); if (bodyT->var) throw TypeError("REPL call to variable typed body"); body->lift(cenv); if (bodyT->ctype) { // Create anonymous function to insert code into. ASTTuple* prot = new ASTTuple(); Function* f = compileFunction(cenv, cenv.gensym("_repl"), *prot, bodyT->ctype); BasicBlock* bb = BasicBlock::Create("entry", f); cenv.builder.SetInsertPoint(bb); try { Value* retVal = cenv.compile(body); cenv.builder.CreateRet(retVal); // Finish function cenv.optimise(*f); } catch (SyntaxError e) { f->eraseFromParent(); // Error reading body, remove function throw e; } void* fp = engine->getPointerToFunction(f); if (bodyT->ctype == Type::Int32Ty) std::cout << "; " << ((int32_t (*)())fp)(); else if (bodyT->ctype == Type::FloatTy) std::cout << "; " << ((float (*)())fp)(); else if (bodyT->ctype == Type::Int1Ty) std::cout << "; " << ((bool (*)())fp)(); } else { Value* val = cenv.compile(body); std::cout << "; " << val; } std::cout << " : " << cenv.tenv.type(body)->str() << endl; } catch (Error e) { std::cerr << "Error: " << e.what() << endl; } } std::cout << endl << "Generated code:" << endl; module->dump(); return 0; }