/* Tuplr: A minimalist programming language * Copyright (C) 2008-2009 David Robillard * * Tuplr is free software: you can redistribute it and/or modify it under * the terms of the GNU Affero General Public License as published by the * Free Software Foundation, either version 3 of the License, or (at your * option) any later version. * * Tuplr is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General * Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with Tuplr. If not, see . */ #include #include #include #include #include #include #include #include #include #include #include "llvm/Analysis/Verifier.h" #include "llvm/DerivedTypes.h" #include "llvm/ExecutionEngine/ExecutionEngine.h" #include "llvm/Instructions.h" #include "llvm/Module.h" #include "llvm/ModuleProvider.h" #include "llvm/PassManager.h" #include "llvm/Support/IRBuilder.h" #include "llvm/Target/TargetData.h" #include "llvm/Transforms/Scalar.h" #define FOREACH(IT, i, c) for (IT i = (c).begin(); i != (c).end(); ++i) using namespace llvm; using namespace std; using boost::format; struct Cursor { Cursor(const string& n="", unsigned l=1, unsigned c=0) : name(n), line(l), col(c) {} string str() const { return (format("%1%:%2%:%3%") % name % line % col).str(); } string name; unsigned line; unsigned col; }; struct Error { Error(const string& m, Cursor c=Cursor()) : msg(m), loc(c) {} const string what() const throw() { return loc.str() + ": error: " + msg; } string msg; Cursor loc; }; template struct Exp { // ::= Atom | (Exp*) Exp(Cursor c) : loc(c), type(LIST) {} Exp(Cursor c, const A& a) : loc(c), type(ATOM), atom(a) {} Cursor loc; enum { ATOM, LIST } type; typedef std::vector< Exp > List; A atom; List list; }; /*************************************************************************** * S-Expression Lexer :: text -> S-Expressions (SExp) * ***************************************************************************/ typedef Exp SExp; static SExp readExpression(Cursor& loc, std::istream& in) { #define PUSH(s, t) { if (t != "") { s.top().list.push_back(SExp(loc, t)); t = ""; } } #define YIELD(s, t) { if (s.empty()) return SExp(loc, t); else PUSH(s, t) } stack stk; string tok; while (char ch = in.get()) { ++loc.col; switch (ch) { case EOF: if (!stk.empty()) throw Error("unexpected end of file", loc); return SExp(loc); case '\n': ++loc.line; loc.col = 0; case ' ': case '\t': if (tok != "") YIELD(stk, tok); break; case '"': do { tok.push_back(ch); ++loc.col; } while ((ch = in.get()) != '"'); YIELD(stk, tok + '"'); break; case '(': stk.push(SExp(loc)); break; case ')': switch (stk.size()) { case 0: throw Error("unexpected `)'", loc); case 1: PUSH(stk, tok); return stk.top(); default: PUSH(stk, tok); SExp l = stk.top(); stk.pop(); stk.top().list.push_back(l); } break; default: tok += ch; } } switch (stk.size()) { case 0: return SExp(loc, tok); case 1: return stk.top(); default: throw Error("missing `)'", loc); } return SExp(loc); } /*************************************************************************** * Abstract Syntax Tree * ***************************************************************************/ struct TEnv; ///< Type-Time Environment struct CEnv; ///< Compile-Time Environment /// Base class for all AST nodes struct AST { virtual ~AST() {} virtual bool contains(AST* child) const { return false; } virtual bool operator!=(const AST& o) const { return !operator==(o); } virtual bool operator==(const AST& o) const = 0; virtual string str() const = 0; virtual void constrain(TEnv& tenv) const {} virtual void lift(CEnv& cenv) {} virtual Value* compile(CEnv& cenv) = 0; }; /// Literal template struct ASTLiteral : public AST { ASTLiteral(VT v) : val(v) {} bool operator==(const AST& rhs) const { const ASTLiteral* r = dynamic_cast*>(&rhs); return r && val == r->val; } string str() const { return (format("%1%") % val).str(); } void constrain(TEnv& tenv) const; Value* compile(CEnv& cenv); const VT val; }; /// Symbol, e.g. "a" struct ASTSymbol : public AST { ASTSymbol(const string& s) : cppstr(s) {} bool operator==(const AST& rhs) const { return this == &rhs; } string str() const { return cppstr; } Value* compile(CEnv& cenv); private: const string cppstr; }; /// Tuple (heterogeneous sequence of fixed length), e.g. "(a b c)" struct ASTTuple : public AST, public vector { ASTTuple(const vector& t=vector()) : vector(t) {} ASTTuple(size_t size) : vector(size) {} ASTTuple(AST* ast, ...) { push_back(ast); va_list args; va_start(args, ast); for (AST* a = va_arg(args, AST*); a; a = va_arg(args, AST*)) push_back(a); va_end(args); } string str() const { string ret = "("; for (size_t i = 0; i != size(); ++i) ret += at(i)->str() + ((i != size() - 1) ? " " : ""); return ret + ")"; } bool operator==(const AST& rhs) const { const ASTTuple* rt = dynamic_cast(&rhs); if (!rt) return false; if (rt->size() != size()) return false; const_iterator l = begin(); FOREACH(const_iterator, r, *rt) { AST* mine = *l++; AST* other = *r; if (!(*mine == *other)) return false; } return true; } void lift(CEnv& cenv) { FOREACH(iterator, t, *this) (*t)->lift(cenv); } bool isForm(const string& f) { return !empty() && at(0)->str() == f; } bool contains(AST* child) const; void constrain(TEnv& tenv) const; Value* compile(CEnv& cenv) { throw Error("tuple compiled"); } }; /// Type Expression, e.g. "(Int)" or "(Fn ((Int)) (Float))" struct AType : public ASTTuple { AType(const ASTTuple& t) : ASTTuple(t), var(false), ctype(0) {} AType(unsigned i) : var(true), id(i), ctype(0) {} AType(ASTSymbol* n, const Type* t) : var(false), ctype(t) { push_back(n); } string str() const { if (var) { return (format("?%1%") % id).str(); } else { return ASTTuple::str(); } } void constrain(TEnv& tenv) const {} Value* compile(CEnv& cenv) { return NULL; } bool concrete() const { if (var) return false; FOREACH(const_iterator, t, *this) { AType* kid = dynamic_cast(*t); if (kid && !kid->concrete()) return false; } return true; } bool operator==(const AST& rhs) const { const AType* rt = dynamic_cast(&rhs); if (!rt) return false; else if (var && rt->var) return id == rt->id; else if (!var && !rt->var) return ASTTuple::operator==(rhs); return false; } const Type* type() { if (at(0)->str() == "Pair") { vector types; for (size_t i = 1; i < size(); ++i) { assert(dynamic_cast(at(i))); types.push_back(((AType*)at(i))->type()); } return PointerType::get(StructType::get(types, false), 0); } else { return ctype; } }; bool var; unsigned id; private: const Type* ctype; }; /// Possibly several lifted LLVM functions for a single Tuplr function struct Funcs : public list< pair > { Function* find(AType* type) const { for (const_iterator f = begin(); f != end(); ++f) if (*f->first == *type) return f->second; return NULL; } void insert(AType* type, Function* func) { push_back(make_pair(type, func)); } }; /// Closure (first-class function with captured lexical bindings) struct ASTClosure : public ASTTuple { ASTClosure(ASTTuple* p, AST* b, const string& n="") : ASTTuple(0, p, b), prot(p), func(0), name(n) {} bool operator==(const AST& rhs) const { return this == &rhs; } string str() const { return (format("%1%") % this).str(); } void constrain(TEnv& tenv) const; void lift(CEnv& cenv); Value* compile(CEnv& cenv); ASTTuple* const prot; private: Function* func; string name; }; /// Function call/application, e.g. "(func arg1 arg2)" struct ASTCall : public ASTTuple { ASTCall(const SExp& e, const ASTTuple& t) : ASTTuple(t), exp(e) {} void constrain(TEnv& tenv) const; void lift(CEnv& cenv); Value* compile(CEnv& cenv); const SExp& exp; }; /// Definition special form, e.g. "(def x 2)" struct ASTDefinition : public ASTCall { ASTDefinition(const SExp& e, const ASTTuple& t) : ASTCall(e, t) {} void constrain(TEnv& tenv) const; void lift(CEnv& cenv); Value* compile(CEnv& cenv); }; /// Conditional special form, e.g. "(if cond thenexp elseexp)" struct ASTIf : public ASTCall { ASTIf(const SExp& e, const ASTTuple& t) : ASTCall(e, t) {} void constrain(TEnv& tenv) const; Value* compile(CEnv& cenv); }; /// Primitive (builtin arithmetic function), e.g. "(+ 2 3)" struct ASTPrimitive : public ASTCall { ASTPrimitive(const SExp& e, const ASTTuple& t, int o, int a=0) : ASTCall(e, t), op(o), arg(a) {} void constrain(TEnv& tenv) const; Value* compile(CEnv& cenv); unsigned op; unsigned arg; }; /// Cons special form, e.g. "(cons 1 2)" struct ASTConsCall : public ASTCall { ASTConsCall(const SExp& e, const ASTTuple& t) : ASTCall(e, t) {} AType* functionType(CEnv& cenv); void constrain(TEnv& tenv) const; void lift(CEnv& cenv); Value* compile(CEnv& cenv); static Funcs funcs; }; Funcs ASTConsCall::funcs; /// Car special form, e.g. "(car p)" struct ASTCarCall : public ASTCall { ASTCarCall(const SExp& e, const ASTTuple& t) : ASTCall(e, t) {} void constrain(TEnv& tenv) const; Value* compile(CEnv& cenv); }; /// Cdr special form, e.g. "(cdr p)" struct ASTCdrCall : public ASTCall { ASTCdrCall(const SExp& e, const ASTTuple& t) : ASTCall(e, t) {} void constrain(TEnv& tenv) const; Value* compile(CEnv& cenv); }; /*************************************************************************** * Parser - S-Expressions (SExp) -> AST Nodes (AST) * ***************************************************************************/ /// LLVM Operation struct Op { Op(int o=0, int a=0) : op(o), arg(a) {} int op; int arg; }; typedef Op UD; // User Data argument for parse functions // Parse Time Environment (symbol table) struct PEnv : private map { typedef AST* (*PF)(PEnv&, const SExp&, UD); // Parse Function struct Parser { Parser(PF f, UD d) : pf(f), ud(d) {} PF pf; UD ud; }; map parsers; void reg(const string& s, const Parser& p) { parsers.insert(make_pair(sym(s)->str(), p)); } const Parser* parser(const string& s) const { map::const_iterator i = parsers.find(s); return (i != parsers.end()) ? &i->second : NULL; } ASTSymbol* sym(const string& s) { const const_iterator i = find(s); return ((i != end()) ? i->second : insert(make_pair(s, new ASTSymbol(s))).first->second); } }; /// The fundamental parser method static AST* parseExpression(PEnv& penv, const SExp& exp); static ASTTuple pmap(PEnv& penv, const SExp::List& l) { ASTTuple ret(l.size()); size_t n = 0; FOREACH(SExp::List::const_iterator, i, l) ret[n++] = parseExpression(penv, *i); return ret; } static AST* parseExpression(PEnv& penv, const SExp& exp) { if (exp.type == SExp::LIST) { if (exp.list.empty()) throw Error("call to empty list", exp.loc); if (exp.list.front().type == SExp::ATOM) { const PEnv::Parser* handler = penv.parser(exp.list.front().atom); if (handler) // Dispatch to parse function return handler->pf(penv, exp, handler->ud); } return new ASTCall(exp, pmap(penv, exp.list)); // Parse as regular call } else if (isdigit(exp.atom[0])) { if (exp.atom.find('.') == string::npos) return new ASTLiteral(strtol(exp.atom.c_str(), NULL, 10)); else return new ASTLiteral(strtod(exp.atom.c_str(), NULL)); } return penv.sym(exp.atom); } // Special forms static AST* parseIf(PEnv& penv, const SExp& exp, UD) { return new ASTIf(exp, pmap(penv, exp.list)); } static AST* parseDef(PEnv& penv, const SExp& exp, UD) { return new ASTDefinition(exp, pmap(penv, exp.list)); } static AST* parsePrim(PEnv& penv, const SExp& exp, UD data) { return new ASTPrimitive(exp, pmap(penv, exp.list), data.op, data.arg); } static AST* parseFn(PEnv& penv, const SExp& exp, UD) { SExp::List::const_iterator a = exp.list.begin(); ++a; return new ASTClosure( new ASTTuple(pmap(penv, (*a++).list)), parseExpression(penv, *a++)); } static AST* parseCons(PEnv& penv, const SExp& exp, UD) { return new ASTConsCall(exp, pmap(penv, exp.list)); } static AST* parseCar(PEnv& penv, const SExp& exp, UD) { return new ASTCarCall(exp, pmap(penv, exp.list)); } static AST* parseCdr(PEnv& penv, const SExp& exp, UD) { return new ASTCdrCall(exp, pmap(penv, exp.list)); } /*************************************************************************** * Generic Lexical Environment * ***************************************************************************/ template struct Env : public list< map > { typedef map Frame; Env() : list(1) {} void push_front() { list::push_front(Frame()); } const V& def(const K& k, const V& v) { typename Frame::iterator existing = this->front().find(k); if (existing != this->front().end() && existing->second != v) throw Error("redefinition"); return (this->front()[k] = v); } V* ref(const K& name) { typename Frame::iterator s; for (typename Env::iterator i = this->begin(); i != this->end(); ++i) if ((s = i ->find(name)) != i->end()) return &s->second; return 0; } }; /*************************************************************************** * Typing * ***************************************************************************/ struct TSubst : public map { TSubst(AType* s=0, AType* t=0) { if (s && t) insert(make_pair(s, t)); } }; /// Type-Time Environment struct TEnv { TEnv(PEnv& p) : penv(p), varID(1) {} typedef map Types; typedef list< pair > Constraints; AType* var() { return new AType(varID++); } AType* type(const AST* ast) { Types::iterator t = types.find(ast); return (t != types.end()) ? t->second : (types[ast] = var()); } AType* named(const string& name) const { Types::const_iterator i = namedTypes.find(penv.sym(name)); if (i == namedTypes.end()) throw Error("Unknown named type"); return i->second; } void name(const string& name, const Type* type) { ASTSymbol* sym = penv.sym(name); namedTypes[sym] = new AType(penv.sym(name), type); } void constrain(const AST* o, AType* t) { constraints.push_back(make_pair(type(o), t)); } void solve() { apply(unify(constraints)); } void apply(const TSubst& substs); static TSubst unify(const Constraints& c); PEnv& penv; Types types; Types namedTypes; Constraints constraints; unsigned varID; }; #define OP_IS_A(o, t) ((o) >= t ## Begin && (o) < t ## End) void ASTTuple::constrain(TEnv& tenv) const { AType* t = new AType(ASTTuple()); FOREACH(const_iterator, p, *this) { (*p)->constrain(tenv); t->push_back(tenv.type(*p)); } tenv.constrain(tenv.type(this), t); } void ASTClosure::constrain(TEnv& tenv) const { prot->constrain(tenv); at(2)->constrain(tenv); AType* bodyT = tenv.type(at(2)); tenv.constrain(this, new AType( ASTTuple(tenv.penv.sym("Fn"), tenv.type(prot), bodyT, 0))); } void ASTCall::constrain(TEnv& tenv) const { FOREACH(const_iterator, p, *this) (*p)->constrain(tenv); AType* retT = tenv.type(this); tenv.constrain(at(0), new AType(ASTTuple( tenv.penv.sym("Fn"), tenv.var(), retT, NULL))); } void ASTDefinition::constrain(TEnv& tenv) const { if (size() != 3) throw Error("`def' requires exactly 2 arguments", exp.loc); if (!dynamic_cast(at(1))) throw Error("`def' name is not a symbol", exp.loc); FOREACH(const_iterator, p, *this) (*p)->constrain(tenv); AType* tvar = tenv.type(this); tenv.constrain(at(1), tvar); tenv.constrain(at(2), tvar); } void ASTIf::constrain(TEnv& tenv) const { FOREACH(const_iterator, p, *this) (*p)->constrain(tenv); AType* tvar = tenv.type(this); tenv.constrain(at(1), tenv.named("Bool")); tenv.constrain(at(2), tvar); tenv.constrain(at(3), tvar); } void ASTPrimitive::constrain(TEnv& tenv) const { FOREACH(const_iterator, p, *this) (*p)->constrain(tenv); if (OP_IS_A(op, Instruction::BinaryOps)) { if (size() <= 2) throw Error((format("`%1%' requires at least 2 arguments") % at(0)->str()).str(), exp.loc); AType* tvar = tenv.type(this); for (size_t i = 1; i < size(); ++i) tenv.constrain(at(i), tvar); } else if (op == Instruction::ICmp) { if (size() != 3) throw Error((format("`%1%' requires exactly 2 arguments") % at(0)->str()).str(), exp.loc); tenv.constrain(at(1), tenv.type(at(2))); tenv.constrain(this, tenv.named("Bool")); } else { throw Error((format("unknown primitive `%1%'") % at(0)->str()).str(), exp.loc); } } void ASTConsCall::constrain(TEnv& tenv) const { AType* t = new AType(ASTTuple(tenv.penv.sym("Pair"), 0)); for (size_t i = 1; i < size(); ++i) { at(i)->constrain(tenv); t->push_back(tenv.type(at(i))); } tenv.constrain(this, t); } void ASTCarCall::constrain(TEnv& tenv) const { at(1)->constrain(tenv); AType* ct = tenv.var(); AType* tt = new AType(ASTTuple(tenv.penv.sym("Pair"), ct, tenv.var(), 0)); tenv.constrain(at(1), tt); tenv.constrain(this, ct); } void ASTCdrCall::constrain(TEnv& tenv) const { at(1)->constrain(tenv); AType* ct = tenv.var(); AType* tt = new AType(ASTTuple(tenv.penv.sym("Pair"), tenv.var(), ct, 0)); tenv.constrain(at(1), tt); tenv.constrain(this, ct); } static void substitute(ASTTuple* tup, AST* from, AST* to) { if (!tup) return; for (size_t i = 0; i < tup->size(); ++i) if (*tup->at(i) == *from) tup->at(i) = to; else substitute(dynamic_cast(tup->at(i)), from, to); } bool ASTTuple::contains(AST* child) const { if (*this == *child) return true; FOREACH(const_iterator, p, *this) if (**p == *child || (*p)->contains(child)) return true; return false; } TSubst compose(const TSubst& delta, const TSubst& gamma) // TAPL 22.1.1 { TSubst r; for (TSubst::const_iterator g = gamma.begin(); g != gamma.end(); ++g) { TSubst::const_iterator d = delta.find(g->second); r.insert(make_pair(g->first, ((d != delta.end()) ? d : g)->second)); } for (TSubst::const_iterator d = delta.begin(); d != delta.end(); ++d) { if (gamma.find(d->first) == gamma.end()) r.insert(*d); } return r; } void substConstraints(TEnv::Constraints& constraints, AType* s, AType* t) { for (TEnv::Constraints::iterator c = constraints.begin(); c != constraints.end();) { TEnv::Constraints::iterator next = c; ++next; if (*c->first == *s) c->first = t; if (*c->second == *s) c->second = t; substitute(c->first, s, t); substitute(c->second, s, t); c = next; } } TSubst TEnv::unify(const Constraints& constraints) // TAPL 22.4 { if (constraints.empty()) return TSubst(); AType* s = constraints.begin()->first; AType* t = constraints.begin()->second; Constraints cp = constraints; cp.erase(cp.begin()); if (*s == *t) { return unify(cp); } else if (s->var && !t->contains(s)) { substConstraints(cp, s, t); return compose(unify(cp), TSubst(s, t)); } else if (t->var && !s->contains(t)) { substConstraints(cp, t, s); return compose(unify(cp), TSubst(t, s)); } else if ((s->isForm("Fn") && t->isForm("Fn")) || (s->isForm("Pair") && t->isForm("Pair"))) { AType* s1 = dynamic_cast(s->at(1)); AType* t1 = dynamic_cast(t->at(1)); AType* s2 = dynamic_cast(s->at(2)); AType* t2 = dynamic_cast(t->at(2)); assert(s1 && t1 && s2 && t2); cp.push_back(make_pair(s1, t1)); cp.push_back(make_pair(s2, t2)); return unify(cp); } else { throw Error("Type unification failed"); } } void TEnv::apply(const TSubst& substs) { FOREACH(TSubst::const_iterator, s, substs) FOREACH(Types::iterator, t, types) if (*t->second == *s->first) t->second = s->second; } /*************************************************************************** * Code Generation * ***************************************************************************/ class PEnv; /// Compile-Time Environment struct CEnv { CEnv(PEnv& p, Module* m, const TargetData* target) : penv(p), tenv(p), module(m), emp(module), opt(&emp), symID(0) { // Set up the optimizer pipeline: opt.add(new TargetData(*target)); // Register target arch opt.add(createInstructionCombiningPass()); // Simple optimizations opt.add(createReassociatePass()); // Reassociate expressions opt.add(createGVNPass()); // Eliminate Common Subexpressions opt.add(createCFGSimplificationPass()); // Simplify control flow } string gensym(const char* s="_") { return (format("%1%%2%") % s % symID++).str(); } void push() { code.push_front(); vals.push_front(); } void pop() { code.pop_front(); vals.pop_front(); } Value* compile(AST* obj) { Value** v = vals.ref(obj); return (v) ? *v : vals.def(obj, obj->compile(*this)); } void precompile(AST* obj, Value* value) { assert(!vals.ref(obj)); vals.def(obj, value); } void optimise(Function& f) { return; verifyFunction(f); opt.run(f); } typedef Env Code; typedef Env Vals; PEnv& penv; TEnv tenv; IRBuilder<> builder; Module* module; ExistingModuleProvider emp; FunctionPassManager opt; unsigned symID; Code code; Vals vals; Function* alloc; }; #define LITERAL(CT, NAME, COMPILED) \ template<> Value* \ ASTLiteral::compile(CEnv& cenv) { return (COMPILED); } \ template<> void \ ASTLiteral::constrain(TEnv& tenv) const { tenv.constrain(this, tenv.named(NAME)); } /// Literal template instantiations LITERAL(int32_t, "Int", ConstantInt::get(Type::Int32Ty, val, true)); LITERAL(float, "Float", ConstantFP::get(Type::FloatTy, val)); LITERAL(bool, "Bool", ConstantInt::get(Type::Int1Ty, val, false)); static Function* compileFunction(CEnv& cenv, const std::string& name, const Type* retT, ASTTuple& prot, const vector argNames=vector()) { Function::LinkageTypes linkage = Function::ExternalLinkage; vector cprot; for (size_t i = 0; i < prot.size(); ++i) { AType* at = cenv.tenv.type(prot.at(i)); if (!at->type() || at->var) throw Error("function parameter is untyped"); cprot.push_back(at->type()); } if (!retT) throw Error("function return is untyped"); FunctionType* fT = FunctionType::get(retT, cprot, false); Function* f = Function::Create(fT, linkage, name, cenv.module); if (f->getName() != name) { f->eraseFromParent(); throw Error("function redefined"); } // Set argument names in generated code Function::arg_iterator a = f->arg_begin(); if (!argNames.empty()) for (size_t i = 0; i != prot.size(); ++a, ++i) a->setName(argNames.at(i)); else for (size_t i = 0; i != prot.size(); ++a, ++i) a->setName(prot.at(i)->str()); return f; } Value* ASTSymbol::compile(CEnv& cenv) { AST** c = cenv.code.ref(this); if (!c) throw Error((string("Undefined symbol: ") + cppstr).c_str()); return cenv.compile(*c); } void ASTClosure::lift(CEnv& cenv) { if (cenv.tenv.type(at(2))->var) throw Error("Closure with untyped body lifted"); for (size_t i = 0; i < prot->size(); ++i) if (cenv.tenv.type(prot->at(i))->var) throw Error("Closure with untyped parameter lifted"); assert(!func); cenv.push(); // Write function declaration string name = this->name == "" ? cenv.gensym("_fn") : this->name; Function* f = compileFunction(cenv, name, cenv.tenv.type(at(2))->type(), *prot); BasicBlock* bb = BasicBlock::Create("entry", f); cenv.builder.SetInsertPoint(bb); // Bind argument values in CEnv vector args; const_iterator p = prot->begin(); for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p) cenv.vals.def(dynamic_cast(*p), &*a); // Write function body try { cenv.precompile(this, f); // Define our value first for recursion Value* retVal = cenv.compile(at(2)); cenv.builder.CreateRet(retVal); // Finish function cenv.optimise(*f); func = f; } catch (Error& e) { f->eraseFromParent(); // Error reading body, remove function throw e; } assert(func); cenv.pop(); } Value* ASTClosure::compile(CEnv& cenv) { assert(func); return func; // Function was already compiled in the lifting pass } void ASTCall::lift(CEnv& cenv) { ASTClosure* c = dynamic_cast(at(0)); if (!c) { AST** val = cenv.code.ref(at(0)); c = (val) ? dynamic_cast(*val) : c; } // Lift arguments for (size_t i = 1; i < size(); ++i) at(i)->lift(cenv); if (!c) return; // Extend environment with bound and typed parameters cenv.push(); if (c->prot->size() < size() - 1) throw Error((format("too many arguments to function `%1%'") % at(0)->str()).str(), exp.loc); if (c->prot->size() > size() - 1) throw Error((format("too few arguments to function `%1%'") % at(0)->str()).str(), exp.loc); for (size_t i = 1; i < size(); ++i) cenv.code.def(c->prot->at(i-1), at(i)); at(0)->lift(cenv); // Lift called closure cenv.pop(); // Restore environment } Value* ASTCall::compile(CEnv& cenv) { ASTClosure* c = dynamic_cast(at(0)); if (!c) { AST** val = cenv.code.ref(at(0)); c = (val) ? dynamic_cast(*val) : c; } assert(c); Function* f = dynamic_cast(cenv.compile(c)); if (!f) throw Error("callee failed to compile", exp.loc); vector params(size() - 1); for (size_t i = 1; i < size(); ++i) params[i-1] = cenv.compile(at(i)); return cenv.builder.CreateCall(f, params.begin(), params.end(), "calltmp"); } void ASTDefinition::lift(CEnv& cenv) { if (cenv.code.ref((ASTSymbol*)at(1))) throw Error(string("`") + at(1)->str() + "' redefined", exp.loc); cenv.code.def((ASTSymbol*)at(1), at(2)); // Define first for recursion at(2)->lift(cenv); } Value* ASTDefinition::compile(CEnv& cenv) { return cenv.compile(at(2)); } Value* ASTIf::compile(CEnv& cenv) { typedef vector< pair > Branches; Function* parent = cenv.builder.GetInsertBlock()->getParent(); BasicBlock* mergeBB = BasicBlock::Create("endif"); BasicBlock* nextBB = NULL; Branches branches; ostringstream ss; for (size_t i = 1; i < size() - 1; i += 2) { Value* condV = cenv.compile(at(i)); ss.str(""); ss << "then" << ((i + 1) / 2); BasicBlock* thenBB = BasicBlock::Create(ss.str()); ss.str(""); ss << "else" << ((i + 1) / 2); nextBB = BasicBlock::Create(ss.str()); cenv.builder.CreateCondBr(condV, thenBB, nextBB); // Emit then block for this condition parent->getBasicBlockList().push_back(thenBB); cenv.builder.SetInsertPoint(thenBB); Value* thenV = cenv.compile(at(i + 1)); cenv.builder.CreateBr(mergeBB); branches.push_back(make_pair(thenV, cenv.builder.GetInsertBlock())); parent->getBasicBlockList().push_back(nextBB); cenv.builder.SetInsertPoint(nextBB); } // Emit else block cenv.builder.SetInsertPoint(nextBB); Value* elseV = cenv.compile(at(size() - 1)); cenv.builder.CreateBr(mergeBB); branches.push_back(make_pair(elseV, cenv.builder.GetInsertBlock())); // Emit merge block (Phi node) parent->getBasicBlockList().push_back(mergeBB); cenv.builder.SetInsertPoint(mergeBB); PHINode* pn = cenv.builder.CreatePHI(cenv.tenv.type(this)->type(), "ifval"); for (Branches::iterator i = branches.begin(); i != branches.end(); ++i) pn->addIncoming(i->first, i->second); return pn; } Value* ASTPrimitive::compile(CEnv& cenv) { Value* a = cenv.compile(at(1)); Value* b = cenv.compile(at(2)); if (OP_IS_A(op, Instruction::BinaryOps)) { const Instruction::BinaryOps bo = (Instruction::BinaryOps)op; if (size() == 2) return cenv.compile(at(1)); Value* val = cenv.builder.CreateBinOp(bo, a, b); for (size_t i = 3; i < size(); ++i) val = cenv.builder.CreateBinOp(bo, val, cenv.compile(at(i))); return val; } else if (op == Instruction::ICmp) { bool isInt = cenv.tenv.type(at(1))->str() == "(Int)"; if (isInt) { return cenv.builder.CreateICmp((CmpInst::Predicate)arg, a, b); } else { // Translate to floating point operation switch (arg) { case CmpInst::ICMP_EQ: arg = CmpInst::FCMP_OEQ; break; case CmpInst::ICMP_NE: arg = CmpInst::FCMP_ONE; break; case CmpInst::ICMP_SGT: arg = CmpInst::FCMP_OGT; break; case CmpInst::ICMP_SGE: arg = CmpInst::FCMP_OGE; break; case CmpInst::ICMP_SLT: arg = CmpInst::FCMP_OLT; break; case CmpInst::ICMP_SLE: arg = CmpInst::FCMP_OLE; break; default: throw Error("Unknown primitive", exp.loc); } return cenv.builder.CreateFCmp((CmpInst::Predicate)arg, a, b); } } throw Error("Unknown primitive", exp.loc); } AType* ASTConsCall::functionType(CEnv& cenv) { ASTTuple* protTypes = new ASTTuple(cenv.tenv.type(at(1)), cenv.tenv.type(at(2)), NULL); AType* cellType = new AType(ASTTuple(cenv.penv.sym("Pair"), cenv.tenv.type(at(1)), cenv.tenv.type(at(2)), NULL)); return new AType(ASTTuple(cenv.penv.sym("Fn"), protTypes, cellType, NULL)); } void ASTConsCall::lift(CEnv& cenv) { AType* funcType = functionType(cenv); if (funcs.find(functionType(cenv))) return; ASTCall::lift(cenv); ASTTuple* prot = new ASTTuple(at(1), at(2), NULL); vector types; size_t sz = 0; for (size_t i = 1; i < size(); ++i) { const Type* t = cenv.tenv.type(at(i))->type(); types.push_back(t); sz += t->getPrimitiveSizeInBits(); } sz = (sz % 8 == 0) ? sz / 8 : sz / 8 + 1; StructType* sT = StructType::get(types, false); Type* pT = PointerType::get(sT, 0); // Write function declaration vector argNames; argNames.push_back("car"); argNames.push_back("cdr"); Function* func = compileFunction(cenv, cenv.gensym("cons"), pT, *prot, argNames); BasicBlock* bb = BasicBlock::Create("entry", func); cenv.builder.SetInsertPoint(bb); Value* mem = cenv.builder.CreateCall(cenv.alloc, ConstantInt::get(Type::Int32Ty, sz), "mem"); Value* cell = cenv.builder.CreateBitCast(mem, pT, "cell"); Value* s = cenv.builder.CreateGEP(cell, ConstantInt::get(Type::Int32Ty, 0), "pair"); Value* carP = cenv.builder.CreateStructGEP(s, 0, "car"); Value* cdrP = cenv.builder.CreateStructGEP(s, 1, "cdr"); Function::arg_iterator ai = func->arg_begin(); Value& carArg = *ai++; Value& cdrArg = *ai++; cenv.builder.CreateStore(&carArg, carP); cenv.builder.CreateStore(&cdrArg, cdrP); cenv.builder.CreateRet(cell); cenv.optimise(*func); funcs.insert(funcType, func); } Value* ASTConsCall::compile(CEnv& cenv) { vector params(size() - 1); for (size_t i = 1; i < size(); ++i) params[i-1] = cenv.compile(at(i)); return cenv.builder.CreateCall(funcs.find(functionType(cenv)), params.begin(), params.end()); } Value* ASTCarCall::compile(CEnv& cenv) { AST** arg = cenv.code.ref(at(1)); Value* sP = arg ? (*arg)->compile(cenv) : at(1)->compile(cenv); Value* s = cenv.builder.CreateGEP(sP, ConstantInt::get(Type::Int32Ty, 0), "pair"); Value* carP = cenv.builder.CreateStructGEP(s, 0, "car"); return cenv.builder.CreateLoad(carP); } Value* ASTCdrCall::compile(CEnv& cenv) { AST** arg = cenv.code.ref(at(1)); Value* sP = arg ? (*arg)->compile(cenv) : at(1)->compile(cenv); Value* s = cenv.builder.CreateGEP(sP, ConstantInt::get(Type::Int32Ty, 0), "pair"); Value* cdrP = cenv.builder.CreateStructGEP(s, 1, "cdr"); return cenv.builder.CreateLoad(cdrP); } /*************************************************************************** * EVAL/REPL/MAIN * ***************************************************************************/ std::string call(AType* retT, void* fp) { std::stringstream ss; if (retT->type() == Type::Int32Ty) ss << ((int32_t (*)())fp)(); else if (retT->type() == Type::FloatTy) ss << ((float (*)())fp)(); else if (retT->type() == Type::Int1Ty) ss << ((bool (*)())fp)(); else ss << ((void* (*)())fp)(); return ss.str(); } int eval(CEnv& cenv, ExecutionEngine* engine, const string& name, istream& is) { AST* result = NULL; AType* resultType = NULL; list< pair > exprs; Cursor cursor(name); try { while (true) { SExp exp = readExpression(cursor, is); if (exp.type == SExp::LIST && exp.list.empty()) break; result = parseExpression(cenv.penv, exp); // Parse input result->constrain(cenv.tenv); // Constrain types cenv.tenv.solve(); // Solve and apply type constraints resultType = cenv.tenv.type(result); result->lift(cenv); // Lift functions exprs.push_back(make_pair(exp, result)); } if (!resultType || resultType->var) throw Error("body is undefined/untyped", cursor); } catch (Error& e) { std::cerr << e.what() << endl; return 1; } // Create function for top-level of program ASTTuple prot; const Type* ctype = resultType->type(); Function* f = compileFunction(cenv, cenv.gensym("input"), ctype, prot); BasicBlock* bb = BasicBlock::Create("entry", f); cenv.builder.SetInsertPoint(bb); // Compile all expressions into it Value* val = NULL; for (list< pair >::const_iterator i = exprs.begin(); i != exprs.end(); ++i) val = cenv.compile(i->second); // Finish function cenv.builder.CreateRet(val); cenv.optimise(*f); string resultStr = call(resultType, engine->getPointerToFunction(f)); std::cout << resultStr << " : " << resultType->str() << endl; return 0; } int repl(CEnv& cenv, ExecutionEngine* engine) { while (1) { std::cout << "() "; std::cout.flush(); Cursor cursor("(stdin)"); SExp exp = readExpression(cursor, std::cin); if (exp.type == SExp::LIST && exp.list.empty()) break; try { AST* body = parseExpression(cenv.penv, exp); // Parse input body->constrain(cenv.tenv); // Constrain types cenv.tenv.solve(); // Solve and apply type constraints AType* bodyT = cenv.tenv.type(body); if (!bodyT) throw Error("call to untyped body", cursor); if (bodyT->var) throw Error("call to variable typed body", cursor); body->lift(cenv); if (bodyT->type()) { // Create anonymous function to insert code into. ASTTuple* prot = new ASTTuple(); Function* f = compileFunction(cenv, cenv.gensym("_repl"), bodyT->type(), *prot); BasicBlock* bb = BasicBlock::Create("entry", f); cenv.builder.SetInsertPoint(bb); try { Value* retVal = cenv.compile(body); cenv.builder.CreateRet(retVal); // Finish function cenv.optimise(*f); } catch (Error& e) { f->eraseFromParent(); // Error reading body, remove function throw e; } std::cout << call(bodyT, engine->getPointerToFunction(f)); } else { Value* val = cenv.compile(body); std::cout << "; " << val; } std::cout << " : " << cenv.tenv.type(body)->str() << endl; } catch (Error& e) { std::cerr << e.what() << endl; } } return 0; } int main(int argc, char** argv) { #define PRIM(O, A) PEnv::Parser(parsePrim, Op(Instruction:: O, A)) PEnv penv; penv.reg("fn", PEnv::Parser(parseFn, Op())); penv.reg("if", PEnv::Parser(parseIf, Op())); penv.reg("def", PEnv::Parser(parseDef, Op())); penv.reg("cons", PEnv::Parser(parseCons, Op())); penv.reg("car", PEnv::Parser(parseCar, Op())); penv.reg("cdr", PEnv::Parser(parseCdr, Op())); penv.reg("+", PRIM(Add, 0)); penv.reg("-", PRIM(Sub, 0)); penv.reg("*", PRIM(Mul, 0)); penv.reg("/", PRIM(FDiv, 0)); penv.reg("%", PRIM(FRem, 0)); penv.reg("&", PRIM(And, 0)); penv.reg("|", PRIM(Or, 0)); penv.reg("^", PRIM(Xor, 0)); penv.reg("=", PRIM(ICmp, CmpInst::ICMP_EQ)); penv.reg("!=", PRIM(ICmp, CmpInst::ICMP_NE)); penv.reg(">", PRIM(ICmp, CmpInst::ICMP_SGT)); penv.reg(">=", PRIM(ICmp, CmpInst::ICMP_SGE)); penv.reg("<", PRIM(ICmp, CmpInst::ICMP_SLT)); penv.reg("<=", PRIM(ICmp, CmpInst::ICMP_SLE)); Module* module = new Module("interactive"); ExecutionEngine* engine = ExecutionEngine::create(module); CEnv cenv(penv, module, engine->getTargetData()); cenv.tenv.name("Bool", Type::Int1Ty); cenv.tenv.name("Int", Type::Int32Ty); cenv.tenv.name("Float", Type::FloatTy); cenv.code.def(penv.sym("true"), new ASTLiteral(true)); cenv.code.def(penv.sym("false"), new ASTLiteral(false)); // Host provided allocation primitive prototypes std::vector argsT(1, Type::Int32Ty); FunctionType* funcT = FunctionType::get(PointerType::get(Type::VoidTy, 0), argsT, false); cenv.alloc = Function::Create(funcT, Function::ExternalLinkage, "malloc", module); int ret; if (argc > 2 && !strncmp(argv[1], "-e", 3)) { std::istringstream is(argv[2]); ret = eval(cenv, engine, "(command line)", is); } else if (argc > 2 && !strncmp(argv[1], "-f", 3)) { std::ifstream is(argv[2]); ret = eval(cenv, engine, argv[2], is); is.close(); } else { ret = repl(cenv, engine); } //std::cout << endl << "*** Generated Code ***" << endl; //cenv.module->dump(); return ret; }