aboutsummaryrefslogtreecommitdiffstats
path: root/tuplr.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tuplr.cpp')
-rw-r--r--tuplr.cpp981
1 files changed, 981 insertions, 0 deletions
diff --git a/tuplr.cpp b/tuplr.cpp
new file mode 100644
index 0000000..ada3073
--- /dev/null
+++ b/tuplr.cpp
@@ -0,0 +1,981 @@
+/* Tuplr: A minimalist programming language
+ * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net>
+ *
+ * Tuplr is free software: you can redistribute it and/or modify it under
+ * the terms of the GNU Affero General Public License as published by the
+ * Free Software Foundation, either version 3 of the License, or (at your
+ * option) any later version.
+ *
+ * Tuplr is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
+ * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General
+ * Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with Tuplr. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include <stdarg.h>
+#include <iostream>
+#include <list>
+#include <map>
+#include <sstream>
+#include <stack>
+#include <string>
+#include <vector>
+#include "llvm/Analysis/Verifier.h"
+#include "llvm/DerivedTypes.h"
+#include "llvm/ExecutionEngine/ExecutionEngine.h"
+#include "llvm/Instructions.h"
+#include "llvm/Module.h"
+#include "llvm/ModuleProvider.h"
+#include "llvm/PassManager.h"
+#include "llvm/Support/IRBuilder.h"
+#include "llvm/Target/TargetData.h"
+#include "llvm/Transforms/Scalar.h"
+
+/// Loop over container `c` with an iterator named `i` of type `IT`.
+/// Note that `c` is evaluated on every iteration (for the `.end()` test).
+#define FOREACH(IT, i, c) for (IT i = (c).begin(); i != (c).end(); ++i)
+
+// Both namespaces are opened for the whole file, so unqualified names
+// below may come from either llvm or std.
+using namespace llvm;
+using namespace std;
+
+/// Base class for all errors raised by the lexer, parser, type checker and
+/// compiler.
+///
+/// The message text is copied into the exception, so callers may safely pass
+/// a pointer that does not outlive the throw expression (for example the
+/// result of c_str() on a temporary string).  A NULL message is stored as "".
+struct Error : public std::exception {
+    Error(const char* m) : msg(m ? m : "") {}
+    virtual ~Error() throw() {}
+    /// Returns the stored message; valid for the lifetime of this object
+    const char* what() const throw() { return msg.c_str(); }
+    std::string msg;
+};
+
+/// Generic S-expression node: either a single atom or a list of nodes.
+/// Grammar: Exp ::= Atom | (Exp*)
+template<typename A>
+struct Exp {
+    typedef std::vector< Exp<A> > List;
+
+    /// Tag saying which of `atom` / `list` carries the value
+    enum { ATOM, LIST } type;
+
+    /// Construct an (initially empty) list node
+    Exp()
+        : type(LIST)
+    {}
+
+    /// Construct an atom node holding a copy of @a a
+    Exp(const A& a)
+        : type(ATOM)
+        , atom(a)
+    {}
+
+    A    atom; ///< Atom value; default-initialized for list nodes
+    List list; ///< Children; empty for atom nodes
+};
+
+
+/***************************************************************************
+ * S-Expression Lexer :: text -> S-Expressions (SExp) *
+ ***************************************************************************/
+
+struct SyntaxError : public Error { SyntaxError(const char* m) : Error(m) {} };
+typedef Exp<string> SExp;
+
+/// Read a single S-expression from @a in.
+///
+/// Atoms are delimited by spaces, tabs, newlines and parentheses.  A double
+/// quoted string is read verbatim (quotes included) into the current atom.
+///
+/// @return The expression read.  An empty list is returned when the end of
+///         input is reached (the REPL uses this as its exit condition).
+/// @throws SyntaxError on an unexpected ')' or, if input ends with a NUL
+///         character, when more than one list is still open.
+static SExp
+readExpression(std::istream& in)
+{
+    stack<SExp> stk; // Lists that are open but not yet closed
+    string      tok; // Atom currently being accumulated
+
+    for (;;) {
+        // Keep the result as an int: narrowing to char makes EOF
+        // indistinguishable from a 0xFF byte (or never equal to EOF at all
+        // on platforms where plain char is unsigned).
+        const int c = in.get();
+        if (c == EOF)
+            return SExp();
+        if (c == '\0')
+            break; // A NUL character ends the input like the original loop test
+
+        switch (c) {
+        case ' ': case '\t': case '\n':
+            if (!tok.empty()) {
+                if (stk.empty())
+                    return SExp(tok);
+                stk.top().list.push_back(SExp(tok));
+                tok.clear();
+            }
+            break;
+        case '"': {
+            tok.push_back('"');
+            int s = in.get();
+            while (s != '"') {
+                if (s == EOF)
+                    return SExp(); // Input ended inside a string: treat as EOF
+                tok.push_back(static_cast<char>(s));
+                s = in.get();
+            }
+            tok.push_back('"');
+            if (stk.empty())
+                return SExp(tok);
+            stk.top().list.push_back(SExp(tok));
+            tok.clear(); // Start the next atom from scratch
+            break;
+        }
+        case '(':
+            stk.push(SExp());
+            break;
+        case ')':
+            if (stk.empty())
+                throw SyntaxError("Unexpected ')'");
+            if (!tok.empty()) {
+                stk.top().list.push_back(SExp(tok));
+                tok.clear();
+            }
+            if (stk.size() == 1)
+                return stk.top(); // Outermost list is complete
+            {
+                // Close a nested list and append it to its parent
+                const SExp closed = stk.top();
+                stk.pop();
+                stk.top().list.push_back(closed);
+            }
+            break;
+        default:
+            tok.push_back(static_cast<char>(c));
+        }
+    }
+
+    switch (stk.size()) {
+    case 0:  return SExp(tok);
+    case 1:  return stk.top();
+    default: throw SyntaxError("Missing ')'");
+    }
+}
+
+
+/***************************************************************************
+ * Abstract Syntax Tree *
+ ***************************************************************************/
+
+struct TEnv; ///< Type-Time Environment
+struct CEnv; ///< Compile-Time Environment
+
+/// Base class for all AST nodes.
+/// Subclasses must provide equality, a string form and compile(); the
+/// remaining hooks have do-nothing defaults.
+struct AST {
+ virtual ~AST() {}
+ // Default: a node contains no children
+ virtual bool contains(AST* child) const { return false; }
+ // Inequality is defined as the negation of the subclass's operator==
+ virtual bool operator!=(const AST& o) const { return !operator==(o); }
+ virtual bool operator==(const AST& o) const = 0;
+ virtual string str() const = 0;
+ // Default: adds no type constraints
+ virtual void constrain(TEnv& tenv) const {}
+ // Default: nothing to lift
+ virtual void lift(CEnv& cenv) {}
+ virtual Value* compile(CEnv& cenv) = 0;
+};
+
+/// Literal value of C++ type VT
+template<typename VT>
+struct ASTLiteral : public AST {
+    ASTLiteral(VT v)
+        : val(v)
+    {}
+
+    /// Equal only to another ASTLiteral<VT> holding the same value
+    bool operator==(const AST& rhs) const {
+        const ASTLiteral<VT>* other = dynamic_cast<const ASTLiteral<VT>*>(&rhs);
+        if (!other)
+            return false;
+        return val == other->val;
+    }
+
+    /// The value as formatted by operator<<
+    string str() const {
+        ostringstream out;
+        out << val;
+        return out.str();
+    }
+
+    void   constrain(TEnv& tenv) const;
+    Value* compile(CEnv& cenv);
+
+    const VT val;
+};
+
+/// Symbol, e.g. "a".
+/// Symbols compare by identity: two symbols are equal only if they are the
+/// same object.
+struct ASTSymbol : public AST {
+    ASTSymbol(const string& s)
+        : cppstr(s)
+    {}
+
+    string str() const { return cppstr; }
+    bool   operator==(const AST& rhs) const { return &rhs == this; }
+    Value* compile(CEnv& cenv);
+
+private:
+    const string cppstr; ///< The symbol's name
+};
+
+/// Tuple (heterogeneous sequence of fixed length), e.g. "(a b c)"
+struct ASTTuple : public AST, public vector<AST*> {
+    /// Copy the elements of @a t (empty by default)
+    ASTTuple(const vector<AST*>& t=vector<AST*>()) : vector<AST*>(t) {}
+
+    /// Tuple of @a size value-initialized element pointers
+    ASTTuple(size_t size) : vector<AST*>(size) {}
+
+    /// Build from @a ast followed by further AST pointers; the argument
+    /// list must be terminated by a null AST pointer.
+    ASTTuple(AST* ast, ...) {
+        va_list rest;
+        va_start(rest, ast);
+        push_back(ast);
+        AST* next = va_arg(rest, AST*);
+        while (next) {
+            push_back(next);
+            next = va_arg(rest, AST*);
+        }
+        va_end(rest);
+    }
+
+    /// Elements' string forms separated by single spaces, in parentheses
+    string str() const {
+        string out("(");
+        for (const_iterator e = begin(); e != end(); ++e) {
+            if (e != begin())
+                out += " ";
+            out += (*e)->str();
+        }
+        out += ")";
+        return out;
+    }
+
+    /// Equal to another tuple of the same length with pairwise equal elements
+    bool operator==(const AST& rhs) const {
+        const ASTTuple* other = dynamic_cast<const ASTTuple*>(&rhs);
+        if (!other || other->size() != size())
+            return false;
+        for (size_t i = 0; i < size(); ++i)
+            if (!(*(*this)[i] == *(*other)[i]))
+                return false;
+        return true;
+    }
+
+    /// Lift every element in order
+    void lift(CEnv& cenv) {
+        for (size_t i = 0; i < size(); ++i)
+            at(i)->lift(cenv);
+    }
+
+    /// True iff the tuple is non-empty and its head prints as @a f
+    bool isForm(const string& f) {
+        if (empty())
+            return false;
+        return at(0)->str() == f;
+    }
+
+    bool   contains(AST* child) const;
+    void   constrain(TEnv& tenv) const;
+    Value* compile(CEnv& cenv) { return NULL; }
+};
+
+/// Type Expression, e.g. "(Int)" or "(Fn ((Int)) (Float))"
+///
+/// A type is either a type variable (`var` is true and `id` identifies it)
+/// or a tuple of sub-expressions.  The constructors here leave `ctype` null
+/// except for the named-type constructor.
+struct AType : public ASTTuple {
+    /// Non-variable type made of the elements of @a t
+    AType(const ASTTuple& t) : ASTTuple(t), var(false), ctype(0), id(0) {}
+    /// Type variable number @a i
+    AType(unsigned i) : var(true), ctype(0), id(i) {}
+    /// Named type @a n backed by the LLVM type @a t
+    AType(ASTSymbol* n, const Type* t) : var(false), ctype(t), id(0) {
+        push_back(n);
+    }
+    /// Variables print as "?<id>", everything else as a tuple
+    string str() const {
+        if (var) {
+            ostringstream s; s << "?" << id; return s.str();
+        } else {
+            return ASTTuple::str();
+        }
+    }
+    void constrain(TEnv& tenv) const {}
+    Value* compile(CEnv& cenv) { return NULL; }
+    /// True iff neither this type nor any nested AType is a variable
+    bool concrete() const {
+        if (var) return false;
+        FOREACH(const_iterator, t, *this) {
+            AType* kid = dynamic_cast<AType*>(*t);
+            if (kid && !kid->concrete())
+                return false;
+        }
+        return true;
+    }
+    /// Variables are equal when their ids match; non-variables compare as
+    /// tuples; a variable never equals a non-variable.
+    bool operator==(const AST& rhs) const {
+        const AType* rt = dynamic_cast<const AType*>(&rhs);
+        if (!rt)
+            return false;
+        else if (var && rt->var)
+            return id == rt->id;
+        else if (!var && !rt->var)
+            return ASTTuple::operator==(rhs);
+        return false;
+    }
+    bool var;          ///< True for type variables
+    const Type* ctype; ///< LLVM type, or null
+    unsigned id;       ///< Variable number (0 for non-variables)
+};
+
+/// Closure (first-class function with captured lexical bindings)
+///
+/// Stored as the tuple (NULL prot body): element 0 is a null placeholder,
+/// element 1 the prototype and element 2 the body.
+struct ASTClosure : public ASTTuple {
+    // The variadic ASTTuple constructor reads AST* arguments until it finds
+    // a null pointer, so every argument is passed as an AST* and the list is
+    // explicitly terminated.
+    ASTClosure(ASTTuple* p, AST* b)
+        : ASTTuple(static_cast<AST*>(0), static_cast<AST*>(p), b,
+                   static_cast<AST*>(0))
+        , prot(p)
+        , func(0)
+    {}
+    /// Closures compare by identity
+    bool operator==(const AST& rhs) const { return this == &rhs; }
+    /// Prints the object's address
+    string str() const { ostringstream s; s << this; return s.str(); }
+    void constrain(TEnv& tenv) const;
+    void lift(CEnv& cenv);
+    Value* compile(CEnv& cenv);
+    ASTTuple* const prot; ///< Prototype (parameter list)
+private:
+    Function* func; ///< Compiled function, set by lift()
+};
+
+/// Function call/application, e.g. "(func arg1 arg2)"
+/// Element 0 is the callee, the remaining elements are the arguments.
+struct ASTCall : public ASTTuple {
+    ASTCall(const ASTTuple& t)
+        : ASTTuple(t)
+    {}
+
+    void   constrain(TEnv& tenv) const;
+    void   lift(CEnv& cenv);
+    Value* compile(CEnv& cenv);
+};
+
+/// Definition special form, e.g. "(def x 2)" or "(def (next y) (+ y 1))"
+/// Element 1 is the defined name, element 2 the value.
+struct ASTDefinition : public ASTCall {
+    ASTDefinition(const ASTTuple& t)
+        : ASTCall(t)
+    {}
+
+    void   constrain(TEnv& tenv) const;
+    void   lift(CEnv& cenv);
+    Value* compile(CEnv& cenv);
+};
+
+/// Conditional special form, e.g. "(if cond thenexp elseexp)"
+struct ASTIf : public ASTCall {
+    ASTIf(const ASTTuple& t)
+        : ASTCall(t)
+    {}
+
+    void   constrain(TEnv& tenv) const;
+    Value* compile(CEnv& cenv);
+};
+
+/// Primitive (builtin arithmetic function), e.g. "(+ 2 3)"
+struct ASTPrimitive : public ASTCall {
+    ASTPrimitive(const ASTTuple& t, int o, int a=0)
+        : ASTCall(t)
+        , op(o)
+        , arg(a)
+    {}
+
+    void   constrain(TEnv& tenv) const;
+    Value* compile(CEnv& cenv);
+
+    unsigned op;  ///< Operation code, compared against LLVM Instruction opcodes
+    unsigned arg; ///< Extra operand; used as the predicate for comparisons
+};
+
+
+/***************************************************************************
+ * Parser - S-Expressions (SExp) -> AST Nodes (AST) *
+ ***************************************************************************/
+
+/// LLVM Operation: an opcode plus one integer argument
+struct Op {
+    Op(int o=0, int a=0)
+        : op(o)
+        , arg(a)
+    {}
+    int op;
+    int arg;
+};
+
+/// User Data argument for parse functions
+typedef Op UD;
+
+/// Parse Time Environment (symbol table)
+///
+/// Maps names to interned ASTSymbol objects and keeps the table of parse
+/// functions registered for special forms.
+struct PEnv : private map<const string, ASTSymbol*> {
+    typedef AST* (*PF)(PEnv&, const SExp::List&, UD); // Parse Function
+
+    /// A parse function together with the user data passed to it
+    struct Parser {
+        Parser(PF f, UD d) : pf(f), ud(d) {}
+        PF pf;
+        UD ud;
+    };
+
+    map<string, Parser> parsers;
+
+    /// Register parser @a p for the form named @a s (also interns @a s).
+    /// An existing registration for the same name is kept.
+    void reg(const string& s, const Parser& p) {
+        string key = sym(s)->str();
+        parsers.insert(make_pair(key, p));
+    }
+
+    /// Parser registered for @a s, or NULL if there is none
+    const Parser* parser(const string& s) const {
+        const map<string, Parser>::const_iterator hit = parsers.find(s);
+        if (hit == parsers.end())
+            return NULL;
+        return &hit->second;
+    }
+
+    /// The unique symbol named @a s, created on first use
+    ASTSymbol* sym(const string& s) {
+        const const_iterator known = find(s);
+        if (known != end())
+            return known->second;
+        ASTSymbol* fresh = new ASTSymbol(s);
+        insert(make_pair(s, fresh));
+        return fresh;
+    }
+};
+
+/// The fundamental parser method
+static AST* parseExpression(PEnv& penv, const SExp& exp);
+
+/// Parse every element of @a l, in order, into a tuple of the same length
+static ASTTuple
+pmap(PEnv& penv, const SExp::List& l)
+{
+    ASTTuple parsed(l.size());
+    for (size_t i = 0; i < l.size(); ++i)
+        parsed[i] = parseExpression(penv, l[i]);
+    return parsed;
+}
+
+/// Parse one S-expression into an AST node.
+///
+/// - A list whose head atom has a registered parser is handed to that
+///   parser; any other non-empty list becomes an ASTCall.
+/// - An atom starting with a digit becomes an int32_t literal, or a float
+///   literal if it contains a '.'.
+/// - Any other atom becomes an (interned) symbol.
+///
+/// @throws SyntaxError for an empty list.
+static AST*
+parseExpression(PEnv& penv, const SExp& exp)
+{
+    if (exp.type == SExp::LIST) {
+        if (exp.list.empty())
+            throw SyntaxError("Call to empty list");
+        const SExp& head = exp.list.front();
+        if (head.type == SExp::ATOM) {
+            const PEnv::Parser* handler = penv.parser(head.atom);
+            if (handler) // Dispatch to parse function
+                return handler->pf(penv, exp.list, handler->ud);
+        }
+        return new ASTCall(pmap(penv, exp.list)); // Parse as regular call
+    }
+
+    // isdigit() is only defined for values representable as unsigned char
+    // (or EOF), so a plain, possibly negative, char must not be passed as is.
+    if (isdigit(static_cast<unsigned char>(exp.atom[0]))) {
+        if (exp.atom.find('.') == string::npos)
+            return new ASTLiteral<int32_t>(strtol(exp.atom.c_str(), NULL, 10));
+        else
+            return new ASTLiteral<float>(strtod(exp.atom.c_str(), NULL));
+    }
+    return penv.sym(exp.atom);
+}
+
+// Special forms
+
+/// Parse function for "if": every element of @a c becomes a child of an ASTIf
+static AST*
+parseIf(PEnv& penv, const SExp::List& c, UD)
+{
+    const ASTTuple parts = pmap(penv, c);
+    return new ASTIf(parts);
+}
+
+/// Parse function for "def": every element of @a c becomes a child of an
+/// ASTDefinition
+static AST*
+parseDef(PEnv& penv, const SExp::List& c, UD)
+{
+    const ASTTuple parts = pmap(penv, c);
+    return new ASTDefinition(parts);
+}
+
+/// Parse function for primitives: @a data carries the operation code and
+/// its argument for the resulting ASTPrimitive
+static AST*
+parsePrim(PEnv& penv, const SExp::List& c, UD data)
+{
+    const ASTTuple parts = pmap(penv, c);
+    return new ASTPrimitive(parts, data.op, data.arg);
+}
+
+/// Parse function for "fn": (fn prototype body).
+///
+/// Element 1 of @a c is the parameter list and element 2 the body; any
+/// further elements are ignored.
+///
+/// @throws SyntaxError if the prototype or the body is missing.
+static AST*
+parseFn(PEnv& penv, const SExp::List& c, UD)
+{
+    if (c.size() < 3)
+        throw SyntaxError("\"fn\" not passed a prototype and a body");
+
+    // Parse in separate statements: two increments of one iterator inside a
+    // single argument list are not ordered, so prototype and body could be
+    // swapped.
+    ASTTuple* prot = new ASTTuple(pmap(penv, c[1].list));
+    AST*      body = parseExpression(penv, c[2]);
+    return new ASTClosure(prot, body);
+}
+
+
+/***************************************************************************
+ * Generic Lexical Environment *
+ ***************************************************************************/
+
+/// Stack of frames mapping keys to values; the innermost frame is at the
+/// front.  A new environment starts with one (empty) frame.
+template<typename K, typename V>
+struct Env : public list< map<K,V> > {
+    typedef map<K,V> Frame;
+
+    Env() : list<Frame>(1) {}
+
+    /// Open a new, empty innermost frame
+    void push_front() { list<Frame>::push_front(Frame()); }
+
+    /// Bind @a k to @a v in the innermost frame and return the stored value.
+    /// @throws SyntaxError if @a k is already bound there to a different value.
+    const V& def(const K& k, const V& v) {
+        Frame& top = this->front();
+        const typename Frame::iterator existing = top.find(k);
+        if (existing != top.end() && existing->second != v)
+            throw SyntaxError("Redefinition");
+        return (top[k] = v);
+    }
+
+    /// Pointer to the value bound to @a name in the innermost frame that
+    /// binds it, or 0 if no frame does
+    V* ref(const K& name) {
+        for (typename Env::iterator frame = this->begin(); frame != this->end(); ++frame) {
+            const typename Frame::iterator hit = frame->find(name);
+            if (hit != frame->end())
+                return &hit->second;
+        }
+        return 0;
+    }
+};
+
+
+/***************************************************************************
+ * Typing *
+ ***************************************************************************/
+
+struct TypeError : public Error { TypeError (const char* m) : Error(m) {} };
+
+struct TSubst : public map<AType*, AType*> {
+ TSubst(AType* s=0, AType* t=0) { if (s && t) insert(make_pair(s, t)); }
+};
+
+/// Type-Time Environment
+///
+/// Records the type assigned to each AST node, the named types, and the
+/// list of constraints still to be solved.
+struct TEnv {
+    typedef map<const AST*, AType*>      Types;
+    typedef list< pair<AType*, AType*> > Constraints;
+
+    TEnv(PEnv& p) : penv(p), varID(1) {}
+
+    /// Create a fresh type variable
+    AType* var() {
+        AType* fresh = new AType(varID);
+        ++varID;
+        return fresh;
+    }
+
+    /// Type recorded for @a ast; a fresh variable is recorded on first use
+    AType* type(const AST* ast) {
+        const Types::iterator known = types.find(ast);
+        if (known != types.end())
+            return known->second;
+        AType* fresh = var();
+        types[ast] = fresh;
+        return fresh;
+    }
+
+    /// Look up a named type.  @throws TypeError if @a name is not registered.
+    AType* named(const string& name) const {
+        const Types::const_iterator hit = namedTypes.find(penv.sym(name));
+        if (hit == namedTypes.end())
+            throw TypeError("Unknown named type");
+        return hit->second;
+    }
+
+    /// Register (or replace) the named type @a name backed by LLVM @a type
+    void name(const string& name, const Type* type) {
+        ASTSymbol* sym = penv.sym(name);
+        namedTypes[sym] = new AType(sym, type);
+    }
+
+    /// Add the constraint "type of @a o equals @a t"
+    void constrain(const AST* o, AType* t) {
+        AType* objT = type(o);
+        constraints.push_back(make_pair(objT, t));
+    }
+
+    /// Unify all constraints and apply the resulting substitution
+    void solve() { apply(unify(constraints)); }
+
+    void apply(const TSubst& substs);
+    static TSubst unify(const Constraints& c);
+
+    PEnv&       penv;
+    Types       types;
+    Types       namedTypes;
+    Constraints constraints;
+    unsigned    varID;
+};
+
+/// True iff opcode `o` lies in the half-open range [t##Begin, t##End),
+/// e.g. OP_IS_A(op, Instruction::BinaryOps) tests against
+/// Instruction::BinaryOpsBegin and Instruction::BinaryOpsEnd.
+#define OP_IS_A(o, t) ((o) >= t ## Begin && (o) < t ## End)
+
+/// Constrain every element, then constrain this tuple's type to be the
+/// tuple of its elements' types.
+void
+ASTTuple::constrain(TEnv& tenv) const
+{
+    AType* t = new AType(ASTTuple());
+    FOREACH(const_iterator, p, *this) {
+        (*p)->constrain(tenv);
+        t->push_back(tenv.type(*p));
+    }
+    // TEnv::constrain() looks up the type of its first argument itself, so
+    // the node is passed here.  Passing tenv.type(this) would constrain the
+    // type of the type variable and leave this tuple unconstrained.
+    tenv.constrain(this, t);
+}
+
+/// Constrain the prototype and the body, then constrain this closure's type
+/// to (Fn <prototype type> <body type>).
+void
+ASTClosure::constrain(TEnv& tenv) const
+{
+    prot->constrain(tenv);
+    at(2)->constrain(tenv);
+    AType* bodyT = tenv.type(at(2));
+
+    // The variadic ASTTuple constructor reads each argument as an AST* and
+    // stops at a null pointer, so every argument (including the terminator)
+    // is given exactly that type.  A bare `0` would be passed as an int.
+    AST* const fnSym = tenv.penv.sym("Fn");
+    AST* const protT = tenv.type(prot);
+    AST* const retT  = bodyT;
+    AST* const end   = 0;
+    tenv.constrain(this, new AType(ASTTuple(fnSym, protT, retT, end)));
+}
+
+/// Constrain every element, then constrain the callee (element 0) to be a
+/// function of some fresh argument type returning this call's type.
+void
+ASTCall::constrain(TEnv& tenv) const
+{
+    FOREACH(const_iterator, p, *this)
+        (*p)->constrain(tenv);
+    AType* retT = tenv.type(this);
+
+    // Pass every variadic argument, including the terminator, as an AST*:
+    // that is the type the ASTTuple constructor reads them back as.
+    AST* const fnSym = tenv.penv.sym("Fn");
+    AST* const argT  = tenv.var();
+    AST* const resT  = retT;
+    AST* const end   = 0;
+    tenv.constrain(at(0), new AType(ASTTuple(fnSym, argT, resT, end)));
+}
+
+/// Check the shape of a "def" form and give the defined name, the value
+/// and the form itself one common type.
+/// @throws SyntaxError unless the form is (def <symbol> <value>).
+void
+ASTDefinition::constrain(TEnv& tenv) const
+{
+    if (size() != 3)
+        throw SyntaxError("\"def\" not passed 2 arguments");
+
+    const ASTSymbol* name = dynamic_cast<const ASTSymbol*>(at(1));
+    if (name == NULL)
+        throw SyntaxError("\"def\" name is not a symbol");
+
+    for (const_iterator child = begin(); child != end(); ++child)
+        (*child)->constrain(tenv);
+
+    AType* defT = tenv.type(this);
+    tenv.constrain(at(1), defT);
+    tenv.constrain(at(2), defT);
+}
+
+/// Constrain an "if" form of the shape
+///   (if cond1 then1 [cond2 then2 ...] else)
+/// which is the layout ASTIf::compile() walks: every condition must be a
+/// Bool and every branch must have the type of the whole form.
+/// @throws SyntaxError if the form does not have that shape.
+void
+ASTIf::constrain(TEnv& tenv) const
+{
+    // "if" + N (condition, then) pairs + else => an even size of at least 4
+    if (size() < 4 || size() % 2 != 0)
+        throw SyntaxError("\"if\" not passed condition/then pairs and an else");
+
+    FOREACH(const_iterator, p, *this)
+        (*p)->constrain(tenv);
+
+    AType* tvar = tenv.type(this);
+    for (size_t i = 1; i < size() - 1; i += 2) {
+        tenv.constrain(at(i), tenv.named("Bool"));
+        tenv.constrain(at(i + 1), tvar);
+    }
+    tenv.constrain(at(size() - 1), tvar);
+}
+
+/// Constrain a primitive call.
+/// - Binary operations: all arguments share the type of the call itself.
+/// - Comparisons: exactly two arguments of the same type; the call is Bool.
+/// @throws SyntaxError on a wrong argument count, TypeError on an unknown op.
+void
+ASTPrimitive::constrain(TEnv& tenv) const
+{
+    for (const_iterator child = begin(); child != end(); ++child)
+        (*child)->constrain(tenv);
+
+    if (OP_IS_A(op, Instruction::BinaryOps)) {
+        if (size() <= 1)
+            throw SyntaxError("Primitive call with 0 args");
+        AType* resultT = tenv.type(this);
+        for (size_t i = 1; i < size(); ++i)
+            tenv.constrain(at(i), resultT);
+        return;
+    }
+
+    if (op == Instruction::ICmp) {
+        if (size() != 3)
+            throw SyntaxError("Comparison call with != 2 args");
+        tenv.constrain(at(1), tenv.type(at(2)));
+        tenv.constrain(this, tenv.named("Bool"));
+        return;
+    }
+
+    throw TypeError("Unknown primitive");
+}
+
+/// Replace, recursively, every element of @a tup that equals @a from by
+/// @a to.  Elements that are replaced are not descended into; a NULL
+/// @a tup is ignored.
+static void
+substitute(ASTTuple* tup, AST* from, AST* to)
+{
+    if (tup == NULL)
+        return;
+    for (size_t i = 0; i < tup->size(); ++i) {
+        if (*tup->at(i) == *from) {
+            tup->at(i) = to;
+            continue;
+        }
+        substitute(dynamic_cast<ASTTuple*>(tup->at(i)), from, to);
+    }
+}
+
+/// True iff this tuple equals @a child, or some element equals or
+/// (recursively) contains it.
+bool
+ASTTuple::contains(AST* child) const
+{
+    if (*this == *child)
+        return true;
+    for (const_iterator elem = begin(); elem != end(); ++elem) {
+        if (**elem == *child)
+            return true;
+        if ((*elem)->contains(child))
+            return true;
+    }
+    return false;
+}
+
+/// Compose two substitutions (TAPL 22.1.1).
+/// Every mapping of @a gamma is kept with its target rewritten through
+/// @a delta where delta maps it; mappings of @a delta whose key is not in
+/// @a gamma are then added unchanged.
+TSubst
+compose(const TSubst& delta, const TSubst& gamma) // TAPL 22.1.1
+{
+    TSubst result;
+
+    for (TSubst::const_iterator g = gamma.begin(); g != gamma.end(); ++g) {
+        const TSubst::const_iterator rewritten = delta.find(g->second);
+        AType* target = g->second;
+        if (rewritten != delta.end())
+            target = rewritten->second;
+        result.insert(make_pair(g->first, target));
+    }
+
+    for (TSubst::const_iterator d = delta.begin(); d != delta.end(); ++d)
+        if (gamma.count(d->first) == 0)
+            result.insert(*d);
+
+    return result;
+}
+
+/// Replace @a s by @a t on both sides of every constraint, at the top
+/// level and inside nested type expressions.
+void
+substConstraints(TEnv::Constraints& constraints, AType* s, AType* t)
+{
+    typedef TEnv::Constraints::iterator Iter;
+    for (Iter c = constraints.begin(); c != constraints.end(); ++c) {
+        if (*c->first == *s)
+            c->first = t;
+        if (*c->second == *s)
+            c->second = t;
+        substitute(c->first, s, t);
+        substitute(c->second, s, t);
+    }
+}
+
+/// Solve a list of type constraints (TAPL 22.4).
+///
+/// The first constraint (s, t) is taken off the list and handled by case:
+/// equal types are dropped; a variable that does not occur in the other side
+/// is substituted away; two "Fn" forms are split into constraints on their
+/// argument and return types.
+///
+/// @return The substitution that satisfies all constraints.
+/// @throws TypeError if a constraint cannot be satisfied.
+TSubst
+TEnv::unify(const Constraints& constraints) // TAPL 22.4
+{
+    if (constraints.empty())
+        return TSubst();
+
+    Constraints::const_iterator head = constraints.begin();
+    AType* s = head->first;
+    AType* t = head->second;
+    ++head;
+    Constraints rest(head, constraints.end()); // Everything but the first
+
+    if (*s == *t)
+        return unify(rest);
+
+    if (s->var && !t->contains(s)) {
+        substConstraints(rest, s, t);
+        return compose(unify(rest), TSubst(s, t));
+    }
+
+    if (t->var && !s->contains(t)) {
+        substConstraints(rest, t, s);
+        return compose(unify(rest), TSubst(t, s));
+    }
+
+    if (s->isForm("Fn") && t->isForm("Fn")) {
+        AType* s1 = dynamic_cast<AType*>(s->at(1));
+        AType* t1 = dynamic_cast<AType*>(t->at(1));
+        AType* s2 = dynamic_cast<AType*>(s->at(2));
+        AType* t2 = dynamic_cast<AType*>(t->at(2));
+        assert(s1 && t1 && s2 && t2);
+        rest.push_back(make_pair(s1, t1)); // Argument types must agree
+        rest.push_back(make_pair(s2, t2)); // Return types must agree
+        return unify(rest);
+    }
+
+    throw TypeError("Type unification failed");
+}
+
+/// Apply @a substs to the recorded types: for each mapping in turn, every
+/// recorded type equal to the mapping's source is replaced by its target.
+void
+TEnv::apply(const TSubst& substs)
+{
+    for (TSubst::const_iterator s = substs.begin(); s != substs.end(); ++s) {
+        for (Types::iterator t = types.begin(); t != types.end(); ++t) {
+            if (*t->second == *s->first)
+                t->second = s->second;
+        }
+    }
+}
+
+
+/***************************************************************************
+ * Code Generation *
+ ***************************************************************************/
+
+struct CompileError : public Error { CompileError(const char* m) : Error(m) {} };
+
+class PEnv;
+
+/// Compile-Time Environment
+///
+/// Bundles everything code generation needs: the parse and type
+/// environments, the LLVM builder/module/optimizer, and two lexical
+/// environments (`code`: AST node -> bound AST, `vals`: AST node ->
+/// compiled Value).
+struct CEnv {
+ CEnv(PEnv& p, Module* m, const TargetData* target)
+ : penv(p), tenv(p), module(m), emp(module), opt(&emp), symID(0)
+ {
+ // Set up the optimizer pipeline:
+ opt.add(new TargetData(*target)); // Register target arch
+ opt.add(createInstructionCombiningPass()); // Simple optimizations
+ opt.add(createReassociatePass()); // Reassociate expressions
+ opt.add(createGVNPass()); // Eliminate Common Subexpressions
+ opt.add(createCFGSimplificationPass()); // Simplify control flow
+ }
+ // Returns base followed by a counter that increases with every call
+ string gensym(const char* base="_") {
+ ostringstream s; s << base << symID++; return s.str();
+ }
+ // Open / close a frame in both lexical environments together
+ void push() { code.push_front(); vals.push_front(); }
+ void pop() { code.pop_front(); vals.pop_front(); }
+ // Returns the value already recorded for obj if there is one, otherwise
+ // compiles obj and records the result in the innermost frame
+ Value* compile(AST* obj) {
+ Value** v = vals.ref(obj);
+ return (v) ? *v : vals.def(obj, obj->compile(*this));
+ }
+ // Records value for obj before obj is compiled; obj must have no value yet
+ void precompile(AST* obj, Value* value) {
+ assert(!vals.ref(obj));
+ vals.def(obj, value);
+ }
+ // Verifies f, then runs the optimizer pipeline set up in the constructor
+ void optimise(Function& f) { verifyFunction(f); opt.run(f); }
+ typedef Env<const AST*, AST*> Code;
+ typedef Env<const AST*, Value*> Vals;
+ // The constructor's initializer list relies on this declaration order:
+ // module is initialized before emp, and emp before opt
+ PEnv& penv;
+ TEnv tenv;
+ IRBuilder<> builder;
+ Module* module;
+ ExistingModuleProvider emp;
+ FunctionPassManager opt;
+ unsigned symID;
+ Code code;
+ Vals vals;
+};
+
+/// Define the two ASTLiteral<CT> specializations for one literal type:
+/// compile() returns the expression COMPILED (which may refer to `val`), and
+/// constrain() constrains the literal to the named type NAME.
+#define LITERAL(CT, NAME, COMPILED) \
+template<> Value* \
+ASTLiteral<CT>::compile(CEnv& cenv) { return (COMPILED); } \
+template<> void \
+ASTLiteral<CT>::constrain(TEnv& tenv) const { tenv.constrain(this, tenv.named(NAME)); }
+
+/// Literal template instantiations
+LITERAL(int32_t, "Int", ConstantInt::get(Type::Int32Ty, val, true));
+LITERAL(float, "Float", ConstantFP::get(Type::FloatTy, val));
+LITERAL(bool, "Bool", ConstantInt::get(Type::Int1Ty, val, false));
+
+/// Declare a function called @a name in the module of @a cenv.
+///
+/// @param prot Parameter list; each parameter must have a non-variable
+///             type with an LLVM type attached.
+/// @param retT LLVM return type.
+/// @return The new function, with its arguments named after @a prot.
+/// @throws CompileError if a parameter or the return is untyped, or if
+///         the function did not receive exactly the requested name.
+static Function*
+compileFunction(CEnv& cenv, const std::string& name, ASTTuple& prot, const Type* retT)
+{
+    // Collect the LLVM type of every parameter
+    vector<const Type*> paramTypes;
+    for (ASTTuple::const_iterator p = prot.begin(); p != prot.end(); ++p) {
+        const AType* paramT = cenv.tenv.type(*p);
+        if (paramT->var || !paramT->ctype)
+            throw CompileError("Parameter is untyped");
+        paramTypes.push_back(paramT->ctype);
+    }
+
+    if (!retT)
+        throw CompileError("Return is untyped");
+
+    FunctionType* signature = FunctionType::get(retT, paramTypes, false);
+    Function* fn = Function::Create(
+        signature, Function::ExternalLinkage, name, cenv.module);
+
+    // A differing name means the requested one was not available
+    if (fn->getName() != name) {
+        fn->eraseFromParent();
+        throw CompileError("Function redefined");
+    }
+
+    // Set argument names in generated code
+    size_t i = 0;
+    for (Function::arg_iterator a = fn->arg_begin(); i != prot.size(); ++a, ++i)
+        a->setName(prot.at(i)->str());
+
+    return fn;
+}
+
+/// Compile the expression this symbol is bound to in the code environment
+/// and record the result as this symbol's value.
+/// @throws SyntaxError if the symbol is not bound.
+Value*
+ASTSymbol::compile(CEnv& cenv)
+{
+    AST** bound = cenv.code.ref(this);
+    if (bound) {
+        Value* compiled = cenv.compile(*bound);
+        return cenv.vals.def(this, compiled);
+    }
+    // NOTE(review): the message pointer comes from a temporary string; this
+    // is only safe if the exception copies the text - confirm against Error.
+    throw SyntaxError((string("Undefined symbol: ") + cppstr).c_str());
+}
+
+/// Compile this closure into a new LLVM function named "_fn<N>".
+///
+/// The body and all parameters must already have non-variable types.  The
+/// function is registered as this closure's value before the body is
+/// compiled so that the body may refer to the closure recursively.
+///
+/// @throws CompileError if the body or a parameter is untyped, or if a
+///         parameter is not a symbol.  Any error raised while compiling is
+///         passed on unchanged after the partial function has been removed
+///         and the environment frame has been closed again.
+void
+ASTClosure::lift(CEnv& cenv)
+{
+    if (cenv.tenv.type(at(2))->var)
+        throw CompileError("Closure with untyped body lifted");
+    for (size_t i = 0; i < prot->size(); ++i)
+        if (cenv.tenv.type(prot->at(i))->var)
+            throw CompileError("Closure with untyped parameter lifted");
+
+    assert(!func);
+    cenv.push();
+
+    Function* f = NULL;
+    try {
+        // Write function declaration
+        f = compileFunction(cenv, cenv.gensym("_fn"), *prot, cenv.tenv.type(at(2))->ctype);
+        BasicBlock* bb = BasicBlock::Create("entry", f);
+        cenv.builder.SetInsertPoint(bb);
+
+        // Bind argument values in CEnv
+        const_iterator p = prot->begin();
+        for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p) {
+            ASTSymbol* param = dynamic_cast<ASTSymbol*>(*p);
+            if (!param)
+                throw CompileError("Closure parameter is not a symbol");
+            cenv.vals.def(param, &*a);
+        }
+
+        // Write function body
+        cenv.precompile(this, f); // Define our value first for recursion
+        Value* retVal = cenv.compile(at(2));
+        cenv.builder.CreateRet(retVal); // Finish function
+        cenv.optimise(*f);
+    } catch (...) {
+        if (f)
+            f->eraseFromParent(); // Error reading body, remove function
+        cenv.pop();               // Do not leave the frame behind
+        throw;                    // Rethrow the original exception, unsliced
+    }
+
+    func = f;
+    cenv.pop();
+}
+
+/// Return the function produced by lift(); lift() must have run first.
+Value*
+ASTClosure::compile(CEnv& cenv)
+{
+    // Function was already compiled in the lifting pass
+    assert(func);
+    Value* compiled = func;
+    return compiled;
+}
+
+/// Lift the arguments of this call and, if the callee is a closure (given
+/// directly or bound to the callee in the code environment), lift the
+/// callee with its parameters bound to the call's arguments.
+///
+/// @throws CompileError if the closure's parameter count differs from the
+///         number of arguments.
+void
+ASTCall::lift(CEnv& cenv)
+{
+    ASTClosure* c = dynamic_cast<ASTClosure*>(at(0));
+    if (!c) {
+        AST** val = cenv.code.ref(at(0));
+        c = (val) ? dynamic_cast<ASTClosure*>(*val) : c;
+    }
+
+    // Lift arguments
+    for (size_t i = 1; i < size(); ++i)
+        at(i)->lift(cenv);
+
+    if (!c) return;
+
+    // Check the arity before opening a frame so that a failure leaves the
+    // environment untouched
+    if (c->prot->size() != size() - 1)
+        throw CompileError("Call to closure with mismatched arguments");
+
+    // Extend environment with bound and typed parameters
+    cenv.push();
+    try {
+        for (size_t i = 1; i < size(); ++i)
+            cenv.code.def(c->prot->at(i-1), at(i));
+
+        at(0)->lift(cenv); // Lift called closure
+    } catch (...) {
+        cenv.pop(); // Restore environment on failure, too
+        throw;
+    }
+    cenv.pop(); // Restore environment
+}
+
+/// Emit a call to the closure named by element 0 with the compiled
+/// remaining elements as arguments.
+///
+/// @throws CompileError if the callee is not a closure (neither directly
+///         nor via the code environment) or did not compile to a function.
+Value*
+ASTCall::compile(CEnv& cenv)
+{
+    ASTClosure* c = dynamic_cast<ASTClosure*>(at(0));
+    if (!c) {
+        AST** val = cenv.code.ref(at(0));
+        c = (val) ? dynamic_cast<ASTClosure*>(*val) : c;
+    }
+
+    // Calling something that is not a closure is a user error, not an
+    // internal invariant: report it instead of asserting.
+    if (!c) throw CompileError("Call to non-closure");
+
+    Function* f = dynamic_cast<Function*>(cenv.compile(c));
+    if (!f) throw CompileError("Callee failed to compile");
+
+    vector<Value*> params(size() - 1);
+    for (size_t i = 1; i < size(); ++i)
+        params[i-1] = cenv.compile(at(i));
+
+    return cenv.builder.CreateCall(f, params.begin(), params.end(), "calltmp");
+}
+
+/// Bind the defined name to its value in the code environment, then lift
+/// the value.  The binding comes first so the value can refer to itself.
+void
+ASTDefinition::lift(CEnv& cenv)
+{
+    ASTSymbol* name  = static_cast<ASTSymbol*>(at(1));
+    AST*       value = at(2);
+    cenv.code.def(name, value); // Define first for recursion
+    value->lift(cenv);
+}
+
+/// A definition compiles to the value being defined (element 2)
+Value*
+ASTDefinition::compile(CEnv& cenv)
+{
+    AST* value = at(2);
+    return cenv.compile(value);
+}
+
+/// Emit the control flow for an "if" form.
+///
+/// Elements are walked as (condition, then) pairs starting at index 1; the
+/// last element is the else expression.  Every branch jumps to a common
+/// "endif" block whose PHI node selects the value of the branch taken.
+Value*
+ASTIf::compile(CEnv& cenv)
+{
+ typedef vector< pair<Value*, BasicBlock*> > Branches;
+ Function* parent = cenv.builder.GetInsertBlock()->getParent();
+ BasicBlock* mergeBB = BasicBlock::Create("endif");
+ BasicBlock* nextBB = NULL;
+ Branches branches;
+ ostringstream ss;
+ for (size_t i = 1; i < size() - 1; i += 2) {
+ Value* condV = cenv.compile(at(i));
+
+ // Blocks are numbered 1, 2, ... per (condition, then) pair
+ ss.str(""); ss << "then" << ((i + 1) / 2);
+ BasicBlock* thenBB = BasicBlock::Create(ss.str());
+
+ ss.str(""); ss << "else" << ((i + 1) / 2);
+ nextBB = BasicBlock::Create(ss.str());
+
+ cenv.builder.CreateCondBr(condV, thenBB, nextBB);
+
+ // Emit then block for this condition
+ parent->getBasicBlockList().push_back(thenBB);
+ cenv.builder.SetInsertPoint(thenBB);
+ Value* thenV = cenv.compile(at(i + 1));
+ cenv.builder.CreateBr(mergeBB);
+ // Record the block the builder ends up in, which is the PHI predecessor
+ branches.push_back(make_pair(thenV, cenv.builder.GetInsertBlock()));
+
+ // The next condition (or the final else) is emitted into the else block
+ parent->getBasicBlockList().push_back(nextBB);
+ cenv.builder.SetInsertPoint(nextBB);
+ }
+
+ // Emit else block
+ // NOTE(review): nextBB is still NULL here if the loop body never ran
+ cenv.builder.SetInsertPoint(nextBB);
+ Value* elseV = cenv.compile(at(size() - 1));
+ cenv.builder.CreateBr(mergeBB);
+ branches.push_back(make_pair(elseV, cenv.builder.GetInsertBlock()));
+
+ // Emit merge block (Phi node)
+ parent->getBasicBlockList().push_back(mergeBB);
+ cenv.builder.SetInsertPoint(mergeBB);
+ PHINode* pn = cenv.builder.CreatePHI(cenv.tenv.type(this)->ctype, "ifval");
+
+ for (Branches::iterator i = branches.begin(); i != branches.end(); ++i)
+ pn->addIncoming(i->first, i->second);
+
+ return pn;
+}
+
+/// Emit the instruction(s) for a primitive call.
+///
+/// - Binary operations fold their arguments left to right; with a single
+///   argument the argument itself is the result.
+/// - Comparisons take exactly the first two arguments.  If the first
+///   argument's type prints as "(Int)" an integer comparison is emitted,
+///   otherwise the predicate is translated to its ordered floating point
+///   counterpart.
+///
+/// @throws SyntaxError if there are too few arguments, CompileError for an
+///         unknown operation or predicate.
+Value*
+ASTPrimitive::compile(CEnv& cenv)
+{
+    if (OP_IS_A(op, Instruction::BinaryOps)) {
+        if (size() < 2) throw SyntaxError("Too few arguments");
+        const Instruction::BinaryOps bo = (Instruction::BinaryOps)op;
+        Value* val = cenv.compile(at(1));
+        // With one argument the loop does not run and val is returned as is
+        for (size_t i = 2; i < size(); ++i) {
+            Value* rhs = cenv.compile(at(i));
+            val = cenv.builder.CreateBinOp(bo, val, rhs);
+        }
+        return val;
+    } else if (op == Instruction::ICmp) {
+        if (size() < 3) throw SyntaxError("Too few arguments");
+        Value* a = cenv.compile(at(1));
+        Value* b = cenv.compile(at(2));
+        bool isInt = cenv.tenv.type(at(1))->str() == "(Int)";
+        if (isInt)
+            return cenv.builder.CreateICmp((CmpInst::Predicate)arg, a, b);
+
+        // Translate to floating point operation.  The result is kept in a
+        // local so that the node itself is not modified and can be compiled
+        // again.
+        unsigned fpred = 0;
+        switch (arg) {
+        case CmpInst::ICMP_EQ:  fpred = CmpInst::FCMP_OEQ; break;
+        case CmpInst::ICMP_NE:  fpred = CmpInst::FCMP_ONE; break;
+        case CmpInst::ICMP_SGT: fpred = CmpInst::FCMP_OGT; break;
+        case CmpInst::ICMP_SGE: fpred = CmpInst::FCMP_OGE; break;
+        case CmpInst::ICMP_SLT: fpred = CmpInst::FCMP_OLT; break;
+        case CmpInst::ICMP_SLE: fpred = CmpInst::FCMP_OLE; break;
+        default: throw CompileError("Unknown primitive");
+        }
+        return cenv.builder.CreateFCmp((CmpInst::Predicate)fpred, a, b);
+    }
+    throw CompileError("Unknown primitive");
+}
+
+
+/***************************************************************************
+ * REPL *
+ ***************************************************************************/
+
+/// REPL: read expressions from standard input until end of input (or an
+/// empty list), and for each one parse it, infer its type, compile it and
+/// print the result followed by its type.  Errors are reported on standard
+/// error and the loop continues.  The generated module is dumped on exit.
+int
+main()
+{
+#define PRIM(O, A) PEnv::Parser(parsePrim, Op(Instruction:: O, A))
+    PEnv penv;
+    penv.reg("fn",  PEnv::Parser(parseFn,  Op()));
+    penv.reg("if",  PEnv::Parser(parseIf,  Op()));
+    penv.reg("def", PEnv::Parser(parseDef, Op()));
+    penv.reg("+",  PRIM(Add, 0));
+    penv.reg("-",  PRIM(Sub, 0));
+    penv.reg("*",  PRIM(Mul, 0));
+    // NOTE(review): "/" and "%" are registered as FDiv/FRem whatever the
+    // operand type is - confirm this is intended for Int operands.
+    penv.reg("/",  PRIM(FDiv, 0));
+    penv.reg("%",  PRIM(FRem, 0));
+    penv.reg("&",  PRIM(And, 0));
+    penv.reg("|",  PRIM(Or,  0));
+    penv.reg("^",  PRIM(Xor, 0));
+    penv.reg("=",  PRIM(ICmp, CmpInst::ICMP_EQ));
+    penv.reg("!=", PRIM(ICmp, CmpInst::ICMP_NE));
+    penv.reg(">",  PRIM(ICmp, CmpInst::ICMP_SGT));
+    penv.reg(">=", PRIM(ICmp, CmpInst::ICMP_SGE));
+    penv.reg("<",  PRIM(ICmp, CmpInst::ICMP_SLT));
+    penv.reg("<=", PRIM(ICmp, CmpInst::ICMP_SLE));
+
+    Module* module = new Module("repl");
+    ExecutionEngine* engine = ExecutionEngine::create(module);
+    if (!engine) {
+        std::cerr << "Error: " << "Unable to create execution engine" << endl;
+        return 1;
+    }
+    CEnv cenv(penv, module, engine->getTargetData());
+
+    cenv.tenv.name("Bool",  Type::Int1Ty);
+    cenv.tenv.name("Int",   Type::Int32Ty);
+    cenv.tenv.name("Float", Type::FloatTy);
+    cenv.code.def(penv.sym("true"),  new ASTLiteral<bool>(true));
+    cenv.code.def(penv.sym("false"), new ASTLiteral<bool>(false));
+
+    while (1) {
+        std::cout << "() ";
+        std::cout.flush();
+
+        try {
+            // Read inside the try block so that lexer errors are reported
+            // like any other error instead of terminating the program
+            SExp exp = readExpression(std::cin);
+            if (exp.type == SExp::LIST && exp.list.empty())
+                break;
+
+            AST* body = parseExpression(penv, exp); // Parse input
+            body->constrain(cenv.tenv);             // Constrain types
+            cenv.tenv.solve(); // Solve and apply type constraints
+
+            AType* bodyT = cenv.tenv.type(body);
+            if (!bodyT) throw TypeError("REPL call to untyped body");
+            if (bodyT->var) throw TypeError("REPL call to variable typed body");
+
+            body->lift(cenv);
+
+            if (bodyT->ctype) {
+                // Create anonymous function to insert code into.
+                ASTTuple* prot = new ASTTuple();
+                Function* f = compileFunction(cenv, cenv.gensym("_repl"), *prot, bodyT->ctype);
+                BasicBlock* bb = BasicBlock::Create("entry", f);
+                cenv.builder.SetInsertPoint(bb);
+                try {
+                    Value* retVal = cenv.compile(body);
+                    cenv.builder.CreateRet(retVal); // Finish function
+                    cenv.optimise(*f);
+                } catch (...) {
+                    // Whatever went wrong, do not leave a half-built
+                    // function in the module; pass the error on unchanged
+                    f->eraseFromParent();
+                    throw;
+                }
+                void* fp = engine->getPointerToFunction(f);
+                if (bodyT->ctype == Type::Int32Ty)
+                    std::cout << "; " << ((int32_t (*)())fp)();
+                else if (bodyT->ctype == Type::FloatTy)
+                    std::cout << "; " << ((float (*)())fp)();
+                else if (bodyT->ctype == Type::Int1Ty)
+                    std::cout << "; " << ((bool (*)())fp)();
+            } else {
+                Value* val = cenv.compile(body);
+                std::cout << "; " << val;
+            }
+            std::cout << " : " << cenv.tenv.type(body)->str() << endl;
+
+        } catch (const Error& e) {
+            std::cerr << "Error: " << e.what() << endl;
+        } catch (const std::exception& e) {
+            // e.g. std::out_of_range from a form with too few elements
+            std::cerr << "Error: " << e.what() << endl;
+        }
+    }
+
+    std::cout << endl << "Generated code:" << endl;
+    module->dump();
+    return 0;
+}
+