diff options
author | David Robillard <d@drobilla.net> | 2009-01-28 23:26:50 +0000 |
---|---|---|
committer | David Robillard <d@drobilla.net> | 2009-01-28 23:26:50 +0000 |
commit | 043d037c5e4d7b5e86b257458351ff9293afca19 (patch) | |
tree | fdee31a57f6cc0e3793a7c43b950c828e5d67518 /tuplr.cpp | |
parent | bb4c9c5fd9b47092eb526fa177771c3eac7815e0 (diff) | |
download | resp-043d037c5e4d7b5e86b257458351ff9293afca19.tar.gz resp-043d037c5e4d7b5e86b257458351ff9293afca19.tar.bz2 resp-043d037c5e4d7b5e86b257458351ff9293afca19.zip |
Move stuff.
git-svn-id: http://svn.drobilla.net/resp/tuplr@37 ad02d1e2-f140-0410-9f75-f8b11f17cedd
Diffstat (limited to 'tuplr.cpp')
-rw-r--r-- | tuplr.cpp | 981 |
1 files changed, 981 insertions, 0 deletions
diff --git a/tuplr.cpp b/tuplr.cpp new file mode 100644 index 0000000..ada3073 --- /dev/null +++ b/tuplr.cpp @@ -0,0 +1,981 @@ +/* Tuplr: A minimalist programming language + * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net> + * + * Tuplr is free software: you can redistribute it and/or modify it under + * the terms of the GNU Affero General Public License as published by the + * Free Software Foundation, either version 3 of the License, or (at your + * option) any later version. + * + * Tuplr is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY + * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General + * Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with Tuplr. If not, see <http://www.gnu.org/licenses/>. + */ + +#include <stdarg.h> +#include <iostream> +#include <list> +#include <map> +#include <sstream> +#include <stack> +#include <string> +#include <vector> +#include "llvm/Analysis/Verifier.h" +#include "llvm/DerivedTypes.h" +#include "llvm/ExecutionEngine/ExecutionEngine.h" +#include "llvm/Instructions.h" +#include "llvm/Module.h" +#include "llvm/ModuleProvider.h" +#include "llvm/PassManager.h" +#include "llvm/Support/IRBuilder.h" +#include "llvm/Target/TargetData.h" +#include "llvm/Transforms/Scalar.h" + +#define FOREACH(IT, i, c) for (IT i = (c).begin(); i != (c).end(); ++i) + +using namespace llvm; +using namespace std; + +struct Error : public std::exception { + Error(const char* m) : msg(m) {} + const char* what() const throw() { return msg; } + const char* msg; +}; + +template<typename A> +struct Exp { // ::= Atom | (Exp*) + Exp() : type(LIST) {} + Exp(const A& a) : type(ATOM), atom(a) {} + enum { ATOM, LIST } type; + typedef std::vector< Exp<A> > List; + A atom; + List list; +}; + + +/*************************************************************************** + * S-Expression Lexer :: text -> S-Expressions (SExp) * + ***************************************************************************/ + +struct SyntaxError : public Error { SyntaxError(const char* m) : Error(m) {} }; +typedef Exp<string> SExp; + +static SExp +readExpression(std::istream& in) +{ +#define PUSH(s, t) { if (t != "") { s.top().list.push_back(t); t = ""; } } +#define YIELD(s, t) { if (s.empty()) return t; else PUSH(s, t) } + stack<SExp> stk; + string tok; + while (char ch = in.get()) { + switch (ch) { + case EOF: + return SExp(); + case ' ': case '\t': case '\n': + if (tok != "") YIELD(stk, tok); + break; + case '"': + do { tok.push_back(ch); } while ((ch = in.get()) != '"'); + YIELD(stk, tok + '"'); + break; + case '(': + stk.push(SExp()); + break; + case ')': + switch (stk.size()) { + case 0: + throw SyntaxError("Unexpected ')'"); + case 1: + PUSH(stk, tok); + return stk.top(); + default: + PUSH(stk, tok); + SExp l = stk.top(); + stk.pop(); + stk.top().list.push_back(l); + } + break; + default: + tok += ch; + } + } + switch (stk.size()) { + case 0: return tok; + case 1: return stk.top(); + default: throw SyntaxError("Missing ')'"); + } + return SExp(); +} + + +/*************************************************************************** + * Abstract Syntax Tree * + ***************************************************************************/ + +struct TEnv; ///< Type-Time Environment +struct CEnv; ///< Compile-Time Environment + +/// Base class for all AST nodes +struct AST { + virtual ~AST() {} + virtual bool contains(AST* child) const { return false; } + virtual bool operator!=(const AST& o) const { return !operator==(o); } + virtual bool operator==(const AST& o) const = 0; + virtual string str() const = 0; + virtual void constrain(TEnv& tenv) const {} + virtual void lift(CEnv& cenv) {} + virtual Value* compile(CEnv& cenv) = 0; +}; + +/// Literal +template<typename VT> +struct ASTLiteral : public AST { + ASTLiteral(VT v) : val(v) {} + bool operator==(const AST& rhs) const { + const ASTLiteral<VT>* r = dynamic_cast<const ASTLiteral<VT>*>(&rhs); + return r && val == r->val; + } + string str() const { ostringstream s; s << val; return s.str(); } + void constrain(TEnv& tenv) const; + Value* compile(CEnv& cenv); + const VT val; +}; + +/// Symbol, e.g. "a" +struct ASTSymbol : public AST { + ASTSymbol(const string& s) : cppstr(s) {} + bool operator==(const AST& rhs) const { return this == &rhs; } + string str() const { return cppstr; } + Value* compile(CEnv& cenv); +private: + const string cppstr; +}; + +/// Tuple (heterogeneous sequence of fixed length), e.g. "(a b c)" +struct ASTTuple : public AST, public vector<AST*> { + ASTTuple(const vector<AST*>& t=vector<AST*>()) : vector<AST*>(t) {} + ASTTuple(size_t size) : vector<AST*>(size) {} + ASTTuple(AST* ast, ...) { + push_back(ast); + va_list args; + va_start(args, ast); + for (AST* a = va_arg(args, AST*); a; a = va_arg(args, AST*)) + push_back(a); + va_end(args); + } + string str() const { + string ret = "("; + for (size_t i = 0; i != size(); ++i) + ret += at(i)->str() + ((i != size() - 1) ? " " : ""); + return ret + ")"; + } + bool operator==(const AST& rhs) const { + const ASTTuple* rt = dynamic_cast<const ASTTuple*>(&rhs); + if (!rt) return false; + if (rt->size() != size()) return false; + const_iterator l = begin(); + FOREACH(const_iterator, r, *rt) { + AST* mine = *l++; + AST* other = *r; + if (!(*mine == *other)) + return false; + } + return true; + } + void lift(CEnv& cenv) { + FOREACH(iterator, t, *this) + (*t)->lift(cenv); + } + bool isForm(const string& f) { return !empty() && at(0)->str() == f; } + bool contains(AST* child) const; + void constrain(TEnv& tenv) const; + Value* compile(CEnv& cenv) { return NULL; } +}; + +/// Type Expression, e.g. "(Int)" or "(Fn ((Int)) (Float))" +struct AType : public ASTTuple { + AType(const ASTTuple& t) : ASTTuple(t), var(false), ctype(0) {} + AType(unsigned i) : var(true), ctype(0), id(i) {} + AType(ASTSymbol* n, const Type* t) : var(false), ctype(t) { + push_back(n); + } + string str() const { + if (var) { + ostringstream s; s << "?" << id; return s.str(); + } else { + return ASTTuple::str(); + } + } + void constrain(TEnv& tenv) const {} + Value* compile(CEnv& cenv) { return NULL; } + bool concrete() const { + if (var) return false; + FOREACH(const_iterator, t, *this) { + AType* kid = dynamic_cast<AType*>(*t); + if (kid && !kid->concrete()) + return false; + } + return true; + } + bool operator==(const AST& rhs) const { + const AType* rt = dynamic_cast<const AType*>(&rhs); + if (!rt) + return false; + else if (var && rt->var) + return id == rt->id; + else if (!var && !rt->var) + return ASTTuple::operator==(rhs); + return false; + } + bool var; + const Type* ctype; + unsigned id; +}; + +/// Closure (first-class function with captured lexical bindings) +struct ASTClosure : public ASTTuple { + ASTClosure(ASTTuple* p, AST* b) : ASTTuple(0, p, b), prot(p), func(0) {} + bool operator==(const AST& rhs) const { return this == &rhs; } + string str() const { ostringstream s; s << this; return s.str(); } + void constrain(TEnv& tenv) const; + void lift(CEnv& cenv); + Value* compile(CEnv& cenv); + ASTTuple* const prot; +private: + Function* func; +}; + +/// Function call/application, e.g. "(func arg1 arg2)" +struct ASTCall : public ASTTuple { + ASTCall(const ASTTuple& t) : ASTTuple(t) {} + void constrain(TEnv& tenv) const; + void lift(CEnv& cenv); + Value* compile(CEnv& cenv); +}; + +/// Definition special form, e.g. "(def x 2)" or "(def (next y) (+ y 1))" +struct ASTDefinition : public ASTCall { + ASTDefinition(const ASTTuple& t) : ASTCall(t) {} + void constrain(TEnv& tenv) const; + void lift(CEnv& cenv); + Value* compile(CEnv& cenv); +}; + +/// Conditional special form, e.g. "(if cond thenexp elseexp)" +struct ASTIf : public ASTCall { + ASTIf(const ASTTuple& t) : ASTCall(t) {} + void constrain(TEnv& tenv) const; + Value* compile(CEnv& cenv); +}; + +/// Primitive (builtin arithmetic function), e.g. "(+ 2 3)" +struct ASTPrimitive : public ASTCall { + ASTPrimitive(const ASTTuple& t, int o, int a=0) : ASTCall(t), op(o), arg(a) {} + void constrain(TEnv& tenv) const; + Value* compile(CEnv& cenv); + unsigned op; + unsigned arg; +}; + + +/*************************************************************************** + * Parser - S-Expressions (SExp) -> AST Nodes (AST) * + ***************************************************************************/ + +/// LLVM Operation +struct Op { Op(int o=0, int a=0) : op(o), arg(a) {} int op; int arg; }; + +typedef Op UD; // User Data argument for parse functions + +// Parse Time Environment (symbol table) +struct PEnv : private map<const string, ASTSymbol*> { + typedef AST* (*PF)(PEnv&, const SExp::List&, UD); // Parse Function + struct Parser { Parser(PF f, UD d) : pf(f), ud(d) {} PF pf; UD ud; }; + map<string, Parser> parsers; + void reg(const string& s, const Parser& p) { + parsers.insert(make_pair(sym(s)->str(), p)); + } + const Parser* parser(const string& s) const { + map<string, Parser>::const_iterator i = parsers.find(s); + return (i != parsers.end()) ? &i->second : NULL; + } + ASTSymbol* sym(const string& s) { + const const_iterator i = find(s); + return ((i != end()) + ? i->second + : insert(make_pair(s, new ASTSymbol(s))).first->second); + } +}; + +/// The fundamental parser method +static AST* parseExpression(PEnv& penv, const SExp& exp); + +static ASTTuple +pmap(PEnv& penv, const SExp::List& l) +{ + ASTTuple ret(l.size()); + size_t n = 0; + FOREACH(SExp::List::const_iterator, i, l) + ret[n++] = parseExpression(penv, *i); + return ret; +} + +static AST* +parseExpression(PEnv& penv, const SExp& exp) +{ + if (exp.type == SExp::LIST) { + if (exp.list.empty()) throw SyntaxError("Call to empty list"); + if (exp.list.front().type == SExp::ATOM) { + const PEnv::Parser* handler = penv.parser(exp.list.front().atom); + if (handler) // Dispatch to parse function + return handler->pf(penv, exp.list, handler->ud); + } + return new ASTCall(pmap(penv, exp.list)); // Parse as regular call + } else if (isdigit(exp.atom[0])) { + if (exp.atom.find('.') == string::npos) + return new ASTLiteral<int32_t>(strtol(exp.atom.c_str(), NULL, 10)); + else + return new ASTLiteral<float>(strtod(exp.atom.c_str(), NULL)); + } + return penv.sym(exp.atom); +} + +// Special forms + +static AST* +parseIf(PEnv& penv, const SExp::List& c, UD) + { return new ASTIf(pmap(penv, c)); } + +static AST* +parseDef(PEnv& penv, const SExp::List& c, UD) + { return new ASTDefinition(pmap(penv, c)); } + +static AST* +parsePrim(PEnv& penv, const SExp::List& c, UD data) + { return new ASTPrimitive(pmap(penv, c), data.op, data.arg); } + +static AST* +parseFn(PEnv& penv, const SExp::List& c, UD) +{ + SExp::List::const_iterator a = c.begin(); ++a; + return new ASTClosure( + new ASTTuple(pmap(penv, (*a++).list)), + parseExpression(penv, *a++)); +} + + +/*************************************************************************** + * Generic Lexical Environment * + ***************************************************************************/ + +template<typename K, typename V> +struct Env : public list< map<K,V> > { + typedef map<K,V> Frame; + Env() : list<Frame>(1) {} + void push_front() { list<Frame>::push_front(Frame()); } + const V& def(const K& k, const V& v) { + typename Frame::iterator existing = this->front().find(k); + if (existing != this->front().end() && existing->second != v) + throw SyntaxError("Redefinition"); + return (this->front()[k] = v); + } + V* ref(const K& name) { + typename Frame::iterator s; + for (typename Env::iterator i = this->begin(); i != this->end(); ++i) + if ((s = i ->find(name)) != i->end()) + return &s->second; + return 0; + } +}; + + +/*************************************************************************** + * Typing * + ***************************************************************************/ + +struct TypeError : public Error { TypeError (const char* m) : Error(m) {} }; + +struct TSubst : public map<AType*, AType*> { + TSubst(AType* s=0, AType* t=0) { if (s && t) insert(make_pair(s, t)); } +}; + +/// Type-Time Environment +struct TEnv { + TEnv(PEnv& p) : penv(p), varID(1) {} + typedef map<const AST*, AType*> Types; + typedef list< pair<AType*, AType*> > Constraints; + AType* var() { return new AType(varID++); } + AType* type(const AST* ast) { + Types::iterator t = types.find(ast); + return (t != types.end()) ? t->second : (types[ast] = var()); + } + AType* named(const string& name) const { + Types::const_iterator i = namedTypes.find(penv.sym(name)); + if (i == namedTypes.end()) throw TypeError("Unknown named type"); + return i->second; + } + void name(const string& name, const Type* type) { + ASTSymbol* sym = penv.sym(name); + namedTypes[sym] = new AType(penv.sym(name), type); + } + void constrain(const AST* o, AType* t) { + constraints.push_back(make_pair(type(o), t)); + } + void solve() { apply(unify(constraints)); } + void apply(const TSubst& substs); + static TSubst unify(const Constraints& c); + PEnv& penv; + Types types; + Types namedTypes; + Constraints constraints; + unsigned varID; +}; + +#define OP_IS_A(o, t) ((o) >= t ## Begin && (o) < t ## End) + +void +ASTTuple::constrain(TEnv& tenv) const +{ + AType* t = new AType(ASTTuple()); + FOREACH(const_iterator, p, *this) { + (*p)->constrain(tenv); + t->push_back(tenv.type(*p)); + } + tenv.constrain(tenv.type(this), t); +} + +void +ASTClosure::constrain(TEnv& tenv) const +{ + prot->constrain(tenv); + at(2)->constrain(tenv); + AType* bodyT = tenv.type(at(2)); + tenv.constrain(this, new AType( + ASTTuple(tenv.penv.sym("Fn"), tenv.type(prot), bodyT, 0))); +} + +void +ASTCall::constrain(TEnv& tenv) const +{ + FOREACH(const_iterator, p, *this) + (*p)->constrain(tenv); + AType* retT = tenv.type(this); + tenv.constrain(at(0), new AType(ASTTuple( + tenv.penv.sym("Fn"), tenv.var(), retT, NULL))); +} + +void +ASTDefinition::constrain(TEnv& tenv) const +{ + if (size() != 3) + throw SyntaxError("\"def\" not passed 2 arguments"); + if (!dynamic_cast<const ASTSymbol*>(at(1))) + throw SyntaxError("\"def\" name is not a symbol"); + FOREACH(const_iterator, p, *this) + (*p)->constrain(tenv); + AType* tvar = tenv.type(this); + tenv.constrain(at(1), tvar); + tenv.constrain(at(2), tvar); +} + +void +ASTIf::constrain(TEnv& tenv) const +{ + FOREACH(const_iterator, p, *this) + (*p)->constrain(tenv); + AType* tvar = tenv.type(this); + tenv.constrain(at(1), tenv.named("Bool")); + tenv.constrain(at(2), tvar); + tenv.constrain(at(3), tvar); +} + +void +ASTPrimitive::constrain(TEnv& tenv) const +{ + FOREACH(const_iterator, p, *this) + (*p)->constrain(tenv); + if (OP_IS_A(op, Instruction::BinaryOps)) { + if (size() <= 1) throw SyntaxError("Primitive call with 0 args"); + AType* tvar = tenv.type(this); + for (size_t i = 1; i < size(); ++i) + tenv.constrain(at(i), tvar); + } else if (op == Instruction::ICmp) { + if (size() != 3) throw SyntaxError("Comparison call with != 2 args"); + tenv.constrain(at(1), tenv.type(at(2))); + tenv.constrain(this, tenv.named("Bool")); + } else { + throw TypeError("Unknown primitive"); + } +} + +static void +substitute(ASTTuple* tup, AST* from, AST* to) +{ + if (!tup) return; + for (size_t i = 0; i < tup->size(); ++i) + if (*tup->at(i) == *from) + tup->at(i) = to; + else + substitute(dynamic_cast<ASTTuple*>(tup->at(i)), from, to); +} + +bool +ASTTuple::contains(AST* child) const +{ + if (*this == *child) return true; + FOREACH(const_iterator, p, *this) + if (**p == *child || (*p)->contains(child)) + return true; + return false; +} + +TSubst +compose(const TSubst& delta, const TSubst& gamma) // TAPL 22.1.1 +{ + TSubst r; + for (TSubst::const_iterator g = gamma.begin(); g != gamma.end(); ++g) { + TSubst::const_iterator d = delta.find(g->second); + r.insert(make_pair(g->first, ((d != delta.end()) ? d : g)->second)); + } + for (TSubst::const_iterator d = delta.begin(); d != delta.end(); ++d) { + if (gamma.find(d->first) == gamma.end()) + r.insert(*d); + } + return r; +} + +void +substConstraints(TEnv::Constraints& constraints, AType* s, AType* t) +{ + for (TEnv::Constraints::iterator c = constraints.begin(); c != constraints.end();) { + TEnv::Constraints::iterator next = c; ++next; + if (*c->first == *s) c->first = t; + if (*c->second == *s) c->second = t; + substitute(c->first, s, t); + substitute(c->second, s, t); + c = next; + } +} + +TSubst +TEnv::unify(const Constraints& constraints) // TAPL 22.4 +{ + if (constraints.empty()) return TSubst(); + AType* s = constraints.begin()->first; + AType* t = constraints.begin()->second; + Constraints cp = constraints; + cp.erase(cp.begin()); + + if (*s == *t) { + return unify(cp); + } else if (s->var && !t->contains(s)) { + substConstraints(cp, s, t); + return compose(unify(cp), TSubst(s, t)); + } else if (t->var && !s->contains(t)) { + substConstraints(cp, t, s); + return compose(unify(cp), TSubst(t, s)); + } else if (s->isForm("Fn") && t->isForm("Fn")) { + AType* s1 = dynamic_cast<AType*>(s->at(1)); + AType* t1 = dynamic_cast<AType*>(t->at(1)); + AType* s2 = dynamic_cast<AType*>(s->at(2)); + AType* t2 = dynamic_cast<AType*>(t->at(2)); + assert(s1 && t1 && s2 && t2); + cp.push_back(make_pair(s1, t1)); + cp.push_back(make_pair(s2, t2)); + return unify(cp); + } else { + throw TypeError("Type unification failed"); + } +} + +void +TEnv::apply(const TSubst& substs) +{ + FOREACH(TSubst::const_iterator, s, substs) + FOREACH(Types::iterator, t, types) + if (*t->second == *s->first) + t->second = s->second; +} + + +/*************************************************************************** + * Code Generation * + ***************************************************************************/ + +struct CompileError : public Error { CompileError(const char* m) : Error(m) {} }; + +class PEnv; + +/// Compile-Time Environment +struct CEnv { + CEnv(PEnv& p, Module* m, const TargetData* target) + : penv(p), tenv(p), module(m), emp(module), opt(&emp), symID(0) + { + // Set up the optimizer pipeline: + opt.add(new TargetData(*target)); // Register target arch + opt.add(createInstructionCombiningPass()); // Simple optimizations + opt.add(createReassociatePass()); // Reassociate expressions + opt.add(createGVNPass()); // Eliminate Common Subexpressions + opt.add(createCFGSimplificationPass()); // Simplify control flow + } + string gensym(const char* base="_") { + ostringstream s; s << base << symID++; return s.str(); + } + void push() { code.push_front(); vals.push_front(); } + void pop() { code.pop_front(); vals.pop_front(); } + Value* compile(AST* obj) { + Value** v = vals.ref(obj); + return (v) ? *v : vals.def(obj, obj->compile(*this)); + } + void precompile(AST* obj, Value* value) { + assert(!vals.ref(obj)); + vals.def(obj, value); + } + void optimise(Function& f) { verifyFunction(f); opt.run(f); } + typedef Env<const AST*, AST*> Code; + typedef Env<const AST*, Value*> Vals; + PEnv& penv; + TEnv tenv; + IRBuilder<> builder; + Module* module; + ExistingModuleProvider emp; + FunctionPassManager opt; + unsigned symID; + Code code; + Vals vals; +}; + +#define LITERAL(CT, NAME, COMPILED) \ +template<> Value* \ +ASTLiteral<CT>::compile(CEnv& cenv) { return (COMPILED); } \ +template<> void \ +ASTLiteral<CT>::constrain(TEnv& tenv) const { tenv.constrain(this, tenv.named(NAME)); } + +/// Literal template instantiations +LITERAL(int32_t, "Int", ConstantInt::get(Type::Int32Ty, val, true)); +LITERAL(float, "Float", ConstantFP::get(Type::FloatTy, val)); +LITERAL(bool, "Bool", ConstantInt::get(Type::Int1Ty, val, false)); + +static Function* +compileFunction(CEnv& cenv, const std::string& name, ASTTuple& prot, const Type* retT) +{ + Function::LinkageTypes linkage = Function::ExternalLinkage; + + vector<const Type*> cprot; + for (size_t i = 0; i < prot.size(); ++i) { + const AType* at = cenv.tenv.type(prot.at(i)); + if (!at->ctype || at->var) throw CompileError("Parameter is untyped"); + cprot.push_back(at->ctype); + } + + if (!retT) throw CompileError("Return is untyped"); + FunctionType* fT = FunctionType::get(retT, cprot, false); + Function* f = Function::Create(fT, linkage, name, cenv.module); + + if (f->getName() != name) { + f->eraseFromParent(); + throw CompileError("Function redefined"); + } + + // Set argument names in generated code + Function::arg_iterator a = f->arg_begin(); + for (size_t i = 0; i != prot.size(); ++a, ++i) + a->setName(prot.at(i)->str()); + + return f; +} + +Value* +ASTSymbol::compile(CEnv& cenv) +{ + AST** c = cenv.code.ref(this); + if (!c) throw SyntaxError((string("Undefined symbol: ") + cppstr).c_str()); + return cenv.vals.def(this, cenv.compile(*c)); +} + +void +ASTClosure::lift(CEnv& cenv) +{ + if (cenv.tenv.type(at(2))->var) + throw CompileError("Closure with untyped body lifted"); + for (size_t i = 0; i < prot->size(); ++i) + if (cenv.tenv.type(prot->at(i))->var) + throw CompileError("Closure with untyped parameter lifted"); + + assert(!func); + cenv.push(); + + // Write function declaration + Function* f = compileFunction(cenv, cenv.gensym("_fn"), *prot, cenv.tenv.type(at(2))->ctype); + BasicBlock* bb = BasicBlock::Create("entry", f); + cenv.builder.SetInsertPoint(bb); + + // Bind argument values in CEnv + vector<Value*> args; + const_iterator p = prot->begin(); + for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p) + cenv.vals.def(dynamic_cast<ASTSymbol*>(*p), &*a); + + // Write function body + try { + cenv.precompile(this, f); // Define our value first for recursion + Value* retVal = cenv.compile(at(2)); + cenv.builder.CreateRet(retVal); // Finish function + cenv.optimise(*f); + func = f; + } catch (exception e) { + f->eraseFromParent(); // Error reading body, remove function + throw e; + } + + assert(func); + cenv.pop(); +} + +Value* +ASTClosure::compile(CEnv& cenv) +{ + assert(func); + return func; // Function was already compiled in the lifting pass +} + +void +ASTCall::lift(CEnv& cenv) +{ + ASTClosure* c = dynamic_cast<ASTClosure*>(at(0)); + if (!c) { + AST** val = cenv.code.ref(at(0)); + c = (val) ? dynamic_cast<ASTClosure*>(*val) : c; + } + + // Lift arguments + for (size_t i = 1; i < size(); ++i) + at(i)->lift(cenv); + + if (!c) return; + + // Extend environment with bound and typed parameters + cenv.push(); + if (c->prot->size() != size() - 1) + throw CompileError("Call to closure with mismatched arguments"); + + for (size_t i = 1; i < size(); ++i) + cenv.code.def(c->prot->at(i-1), at(i)); + + at(0)->lift(cenv); // Lift called closure + cenv.pop(); // Restore environment +} + +Value* +ASTCall::compile(CEnv& cenv) +{ + ASTClosure* c = dynamic_cast<ASTClosure*>(at(0)); + if (!c) { + AST** val = cenv.code.ref(at(0)); + c = (val) ? dynamic_cast<ASTClosure*>(*val) : c; + } + + assert(c); + Function* f = dynamic_cast<Function*>(cenv.compile(c)); + if (!f) throw CompileError("Callee failed to compile"); + + vector<Value*> params(size() - 1); + for (size_t i = 1; i < size(); ++i) + params[i-1] = cenv.compile(at(i)); + + return cenv.builder.CreateCall(f, params.begin(), params.end(), "calltmp"); +} + +void +ASTDefinition::lift(CEnv& cenv) +{ + cenv.code.def((ASTSymbol*)at(1), at(2)); // Define first for recursion + at(2)->lift(cenv); +} + +Value* +ASTDefinition::compile(CEnv& cenv) +{ + return cenv.compile(at(2)); +} + +Value* +ASTIf::compile(CEnv& cenv) +{ + typedef vector< pair<Value*, BasicBlock*> > Branches; + Function* parent = cenv.builder.GetInsertBlock()->getParent(); + BasicBlock* mergeBB = BasicBlock::Create("endif"); + BasicBlock* nextBB = NULL; + Branches branches; + ostringstream ss; + for (size_t i = 1; i < size() - 1; i += 2) { + Value* condV = cenv.compile(at(i)); + + ss.str(""); ss << "then" << ((i + 1) / 2); + BasicBlock* thenBB = BasicBlock::Create(ss.str()); + + ss.str(""); ss << "else" << ((i + 1) / 2); + nextBB = BasicBlock::Create(ss.str()); + + cenv.builder.CreateCondBr(condV, thenBB, nextBB); + + // Emit then block for this condition + parent->getBasicBlockList().push_back(thenBB); + cenv.builder.SetInsertPoint(thenBB); + Value* thenV = cenv.compile(at(i + 1)); + cenv.builder.CreateBr(mergeBB); + branches.push_back(make_pair(thenV, cenv.builder.GetInsertBlock())); + + parent->getBasicBlockList().push_back(nextBB); + cenv.builder.SetInsertPoint(nextBB); + } + + // Emit else block + cenv.builder.SetInsertPoint(nextBB); + Value* elseV = cenv.compile(at(size() - 1)); + cenv.builder.CreateBr(mergeBB); + branches.push_back(make_pair(elseV, cenv.builder.GetInsertBlock())); + + // Emit merge block (Phi node) + parent->getBasicBlockList().push_back(mergeBB); + cenv.builder.SetInsertPoint(mergeBB); + PHINode* pn = cenv.builder.CreatePHI(cenv.tenv.type(this)->ctype, "ifval"); + + for (Branches::iterator i = branches.begin(); i != branches.end(); ++i) + pn->addIncoming(i->first, i->second); + + return pn; +} + +Value* +ASTPrimitive::compile(CEnv& cenv) +{ + if (size() < 3) throw SyntaxError("Too few arguments"); + Value* a = cenv.compile(at(1)); + Value* b = cenv.compile(at(2)); + + if (OP_IS_A(op, Instruction::BinaryOps)) { + const Instruction::BinaryOps bo = (Instruction::BinaryOps)op; + if (size() == 2) + return cenv.compile(at(1)); + Value* val = cenv.builder.CreateBinOp(bo, a, b); + for (size_t i = 3; i < size(); ++i) + val = cenv.builder.CreateBinOp(bo, val, cenv.compile(at(i))); + return val; + } else if (op == Instruction::ICmp) { + bool isInt = cenv.tenv.type(at(1))->str() == "(Int)"; + if (isInt) { + return cenv.builder.CreateICmp((CmpInst::Predicate)arg, a, b); + } else { + // Translate to floating point operation + switch (arg) { + case CmpInst::ICMP_EQ: arg = CmpInst::FCMP_OEQ; break; + case CmpInst::ICMP_NE: arg = CmpInst::FCMP_ONE; break; + case CmpInst::ICMP_SGT: arg = CmpInst::FCMP_OGT; break; + case CmpInst::ICMP_SGE: arg = CmpInst::FCMP_OGE; break; + case CmpInst::ICMP_SLT: arg = CmpInst::FCMP_OLT; break; + case CmpInst::ICMP_SLE: arg = CmpInst::FCMP_OLE; break; + default: throw CompileError("Unknown primitive"); + } + return cenv.builder.CreateFCmp((CmpInst::Predicate)arg, a, b); + } + } + throw CompileError("Unknown primitive"); +} + + +/*************************************************************************** + * REPL * + ***************************************************************************/ + +int +main() +{ +#define PRIM(O, A) PEnv::Parser(parsePrim, Op(Instruction:: O, A)) + PEnv penv; + penv.reg("fn", PEnv::Parser(parseFn, Op())); + penv.reg("if", PEnv::Parser(parseIf, Op())); + penv.reg("def", PEnv::Parser(parseDef, Op())); + penv.reg("+", PRIM(Add, 0)); + penv.reg("-", PRIM(Sub, 0)); + penv.reg("*", PRIM(Mul, 0)); + penv.reg("/", PRIM(FDiv, 0)); + penv.reg("%", PRIM(FRem, 0)); + penv.reg("&", PRIM(And, 0)); + penv.reg("|", PRIM(Or, 0)); + penv.reg("^", PRIM(Xor, 0)); + penv.reg("=", PRIM(ICmp, CmpInst::ICMP_EQ)); + penv.reg("!=", PRIM(ICmp, CmpInst::ICMP_NE)); + penv.reg(">", PRIM(ICmp, CmpInst::ICMP_SGT)); + penv.reg(">=", PRIM(ICmp, CmpInst::ICMP_SGE)); + penv.reg("<", PRIM(ICmp, CmpInst::ICMP_SLT)); + penv.reg("<=", PRIM(ICmp, CmpInst::ICMP_SLE)); + + Module* module = new Module("repl"); + ExecutionEngine* engine = ExecutionEngine::create(module); + CEnv cenv(penv, module, engine->getTargetData()); + + cenv.tenv.name("Bool", Type::Int1Ty); + cenv.tenv.name("Int", Type::Int32Ty); + cenv.tenv.name("Float", Type::FloatTy); + cenv.code.def(penv.sym("true"), new ASTLiteral<bool>(true)); + cenv.code.def(penv.sym("false"), new ASTLiteral<bool>(false)); + + while (1) { + std::cout << "() "; + std::cout.flush(); + SExp exp = readExpression(std::cin); + if (exp.type == SExp::LIST && exp.list.empty()) + break; + + try { + AST* body = parseExpression(penv, exp); // Parse input + body->constrain(cenv.tenv); // Constrain types + cenv.tenv.solve(); // Solve and apply type constraints + + AType* bodyT = cenv.tenv.type(body); + if (!bodyT) throw TypeError("REPL call to untyped body"); + if (bodyT->var) throw TypeError("REPL call to variable typed body"); + + body->lift(cenv); + + if (bodyT->ctype) { + // Create anonymous function to insert code into. + ASTTuple* prot = new ASTTuple(); + Function* f = compileFunction(cenv, cenv.gensym("_repl"), *prot, bodyT->ctype); + BasicBlock* bb = BasicBlock::Create("entry", f); + cenv.builder.SetInsertPoint(bb); + try { + Value* retVal = cenv.compile(body); + cenv.builder.CreateRet(retVal); // Finish function + cenv.optimise(*f); + } catch (SyntaxError e) { + f->eraseFromParent(); // Error reading body, remove function + throw e; + } + void* fp = engine->getPointerToFunction(f); + if (bodyT->ctype == Type::Int32Ty) + std::cout << "; " << ((int32_t (*)())fp)(); + else if (bodyT->ctype == Type::FloatTy) + std::cout << "; " << ((float (*)())fp)(); + else if (bodyT->ctype == Type::Int1Ty) + std::cout << "; " << ((bool (*)())fp)(); + } else { + Value* val = cenv.compile(body); + std::cout << "; " << val; + } + std::cout << " : " << cenv.tenv.type(body)->str() << endl; + + } catch (Error e) { + std::cerr << "Error: " << e.what() << endl; + } + } + + std::cout << endl << "Generated code:" << endl; + module->dump(); + return 0; +} + |