/* A Trivial LLVM LISP
 * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net>
 *
 * Parts from the Kaleidoscope tutorial <http://llvm.org/docs/tutorial/>
 * by Chris Lattner and Erick Tryzelaar
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <stdarg.h>
#include <iostream>
#include <list>
#include <map>
#include <sstream>
#include <stack>
#include <string>
#include <vector>
#include "llvm/Analysis/Verifier.h"
#include "llvm/DerivedTypes.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/Instructions.h"
#include "llvm/Module.h"
#include "llvm/ModuleProvider.h"
#include "llvm/PassManager.h"
#include "llvm/Support/IRBuilder.h"
#include "llvm/Target/TargetData.h"
#include "llvm/Transforms/Scalar.h"

#define FOREACH(IT, i, c) for (IT i = (c).begin(); i != (c).end(); ++i)

using namespace llvm;
using namespace std;

/// Base class for all errors reported by the REPL.
///
/// The message is copied into the exception, so it stays valid even when the
/// caller built it in a temporary std::string and passed its c_str().
struct Error : public std::exception {
	Error(const char* m) : msg(m ? m : "") {}
	// Explicit throw() keeps the override compatible with std::exception's
	// destructor exception specification under C++03.
	virtual ~Error() throw() {}
	const char* what() const throw() { return msg.c_str(); }
	std::string msg;
};

/// Generic S-expression node: either an atom of type A or a list of nodes.
/// Grammar: Exp ::= Atom | (Exp*)
template<typename A>
struct Exp {
	enum { ATOM, LIST } type;
	typedef std::vector< Exp<A> > List;
	A    atom; ///< Meaningful only when type == ATOM
	List list; ///< Meaningful only when type == LIST

	/// An empty list
	Exp() : type(LIST) {}
	/// An atom holding a copy of a
	Exp(const A& a) : type(ATOM), atom(a) {}
};


/***************************************************************************
 * S-Expression Lexer :: text -> S-Expressions (SExp)                      *
 ***************************************************************************/

struct SyntaxError : public Error { SyntaxError(const char* m) : Error(m) {} };
typedef Exp<string> SExp;

/// Read one S-expression from @a in.
///
/// Returns a bare token, or a "quoted" string (quotes included), as an atom,
/// or a parenthesised list.  Returns an empty list once the stream is
/// exhausted, which the REPL treats as the signal to quit.  A NUL character
/// also ends reading: a pending token or a single open list is returned, and
/// deeper unclosed nesting is rejected.
///
/// @throws SyntaxError for an unmatched ')', a string literal cut off by end
///         of input, or more than one unclosed list when a NUL is read.
static SExp
readExpression(std::istream& in)
{
#define PUSH(s, t)  { if (t != "") { s.top().list.push_back(t); t = ""; } }
#define YIELD(s, t) { if (s.empty()) return t; else PUSH(s, t) }
	stack<SExp> stk;
	string      tok;
	// in.get() returns an int so EOF is distinct from every byte value; keep it
	// as an int until EOF is ruled out (a char breaks EOF detection where char
	// is unsigned, and makes byte 0xFF look like EOF where it is signed).
	for (int c = in.get(); c != '\0'; c = in.get()) {
		if (c == EOF) {
			// A bare top-level token ending exactly at EOF is still a token;
			// the next call sees EOF again and returns the empty list.
			if (stk.empty() && tok != "")
				return tok;
			return SExp();
		}
		const char ch = static_cast<char>(c);
		switch (ch) {
		case ' ': case '\t': case '\n':
			if (tok != "") YIELD(stk, tok);
			break;
		case '"':
			// Copy the opening quote and everything up to the closing quote
			do {
				tok.push_back(static_cast<char>(c));
				if ((c = in.get()) == EOF)
					throw SyntaxError("Unterminated string");
			} while (c != '"');
			tok.push_back('"');
			YIELD(stk, tok); // PUSH clears tok so the string cannot leak into the next token
			break;
		case '(':
			stk.push(SExp());
			break;
		case ')':
			switch (stk.size()) {
			case 0:
				throw SyntaxError("Unexpected ')'");
			case 1:
				PUSH(stk, tok);
				return stk.top();
			default: {
				PUSH(stk, tok);
				SExp l = stk.top();
				stk.pop();
				stk.top().list.push_back(l);
			}
			}
			break;
		default:
			tok += ch;
		}
	}
	switch (stk.size()) {
	case 0:  return tok;
	case 1:  return stk.top();
	default: throw  SyntaxError("Missing ')'");
	}
#undef YIELD
#undef PUSH
	return SExp();
}


/***************************************************************************
 * Abstract Syntax Tree                                                    *
 ***************************************************************************/

struct TEnv;  ///< Type-Time Environment
struct CEnv;  ///< Compile-Time Environment
struct AType; ///< Abstract Type

/// Base class for all AST nodes.
///
/// Subclasses must implement equality, printing and compilation to an LLVM
/// Value.  The remaining hooks have defaults: contains() reports false, and
/// constrain()/lift() do nothing.
struct AST {
	virtual ~AST() {}
	virtual bool   operator==(const AST& o) const = 0;
	virtual bool   operator!=(const AST& o) const { return !(*this == o); }
	virtual string str()                    const = 0;
	virtual bool   contains(AST* /*child*/) const { return false; }
	virtual void   constrain(TEnv& /*tenv*/) const {}
	virtual void   lift(CEnv& /*cenv*/)            {}
	virtual Value* compile(CEnv& cenv)             = 0;
};

/// Literal constant holding a C++ value of type VT
/// (compile/constrain are specialised per VT by the LITERAL macro).
template<typename VT>
struct ASTLiteral : public AST {
	ASTLiteral(VT v) : val(v) {}
	bool operator==(const AST& rhs) const {
		const ASTLiteral<VT>* other = dynamic_cast<const ASTLiteral<VT>*>(&rhs);
		if (!other)
			return false;
		return val == other->val;
	}
	string str() const {
		ostringstream out;
		out << val;
		return out.str();
	}
	void   constrain(TEnv& tenv) const;
	Value* compile(CEnv& cenv);
	const VT val;
};

/// Symbol, e.g. "a"
struct ASTSymbol : public AST {
	ASTSymbol(const string& s) : cppstr(s) {}
	bool   operator==(const AST& rhs) const { return this == &rhs; }
	string str()                      const { return cppstr; }
	Value* compile(CEnv& cenv);
private:
	const string cppstr;
};

typedef vector<AST*> TupV;

/// Tuple (heterogeneous sequence of known length), e.g. "(a b c)".
/// Equality is structural: same length and pairwise-equal elements.
struct ASTTuple : public AST {
	ASTTuple(const TupV& t=TupV()) : tup(t) {}
	string str() const {
		string ret = "(";
		for (size_t i = 0; i < tup.size(); ++i) {
			if (i > 0)
				ret += " ";
			ret += tup[i]->str();
		}
		return ret + ")";
	}
	bool operator==(const AST& rhs) const {
		const ASTTuple* other = dynamic_cast<const ASTTuple*>(&rhs);
		if (!other || other->tup.size() != tup.size())
			return false;
		for (size_t i = 0; i < tup.size(); ++i)
			if (!(*tup[i] == *other->tup[i]))
				return false;
		return true;
	}
	/// Lift every element in order
	void lift(CEnv& cenv) {
		for (size_t i = 0; i < tup.size(); ++i)
			tup[i]->lift(cenv);
	}
	/// True if the first element prints as f (e.g. the head symbol "Fn")
	bool isForm(const string& f) { return !tup.empty() && tup[0]->str() == f; }
	bool   contains(AST* child) const;
	void   constrain(TEnv& tenv) const;
	Value* compile(CEnv& cenv) { return NULL; }
	TupV tup;
};

/// Build a TupV from ast followed by further AST* arguments.
/// The variadic list MUST end with a null AST* (not a plain integer 0);
/// the first argument itself may be null and is always included.
static TupV
tuple(AST* ast, ...)
{
	TupV result;
	result.push_back(ast);
	va_list ap;
	va_start(ap, ast);
	while (AST* next = va_arg(ap, AST*))
		result.push_back(next);
	va_end(ap);
	return result;
}

/// Type Expression ::= (TName TExpr*) | ?Num
///
/// A type variable (var == true) prints as "?id".  Any other type is a tuple
/// whose first element names the type; types registered through TEnv::name
/// also carry the corresponding LLVM type in ctype (null otherwise).
struct AType : public ASTTuple {
	// Every constructor initialises id so no member is left indeterminate.
	AType(const TupV& t) : ASTTuple(t), var(false), ctype(0), id(0) {}
	AType(unsigned i) : var(true), ctype(0), id(i) {}
	AType(ASTSymbol* n, const Type* t) : var(false), ctype(t), id(0) {
		tup.push_back(n);
	}
	string str() const {
		if (var) {
			ostringstream s; s << "?" << id; return s.str();
		} else {
			return ASTTuple::str();
		}
	}
	void   constrain(TEnv& tenv) const {}
	Value* compile(CEnv& cenv) { return NULL; }
	/// True if neither this type nor any nested AType is a type variable
	bool   concrete() const {
		if (var) return false;
		FOREACH(TupV::const_iterator, t, tup) {
			AType* kid = dynamic_cast<AType*>(*t);
			if (kid && !kid->concrete())
				return false;
		}
		return true;
	}
	/// Variables are equal iff their ids match; non-variables compare
	/// structurally; a variable never equals a non-variable.
	bool operator==(const AST& rhs) const {
		const AType* rt = dynamic_cast<const AType*>(&rhs);
		if (!rt)
			return false;
		else if (var && rt->var)
			return id == rt->id;
		else if (!var && !rt->var)
			return ASTTuple::operator==(rhs);
		return false;
	}
	bool        var;
	const Type* ctype;
	unsigned    id;
};

/// Closure (first-class function with captured lexical bindings)
///
/// Stored as the tuple (<null> prot body): element 0 is a null placeholder,
/// element 1 the parameter tuple and element 2 the body.  The LLVM function
/// is generated by lift() and returned by compile().
struct ASTClosure : public ASTTuple {
	// tuple() reads AST* varargs until it meets a null pointer, so the list
	// must be explicitly terminated with a typed null.
	ASTClosure(ASTTuple* p, AST* b)
		: ASTTuple(tuple(0, p, b, static_cast<AST*>(0))), prot(p), func(0) {}
	bool operator==(const AST& rhs) const { return this == &rhs; }
	string str() const { ostringstream s; s << this; return s.str(); }
	void   constrain(TEnv& tenv) const;
	void   lift(CEnv& cenv);
	Value* compile(CEnv& cenv);
	ASTTuple* const prot;
private:
	Function* func;
};
	
/// Function call/application, e.g. "(func arg1 arg2)".
/// tup[0] is the callee expression; tup[1..] are the arguments.
struct ASTCall : public ASTTuple {
	ASTCall(const TupV& t) : ASTTuple(t) {}
	Value* compile(CEnv& cenv);
	void   lift(CEnv& cenv);
	void   constrain(TEnv& tenv) const;
};

/// Definition special form, e.g. "(def x 2)" or "(def (next y) (+ y 1))".
/// tup[1] is the defined name and tup[2] its value.
struct ASTDefinition : public ASTCall {
	ASTDefinition(const TupV& t) : ASTCall(t) {}
	Value* compile(CEnv& cenv);
	void   lift(CEnv& cenv);
	void   constrain(TEnv& tenv) const;
};

/// Conditional special form, e.g. "(if cond thenexp elseexp)".
/// tup[1] is the condition, tup[2] the then-branch, tup[3] the else-branch.
struct ASTIf : public ASTCall {
	ASTIf(const TupV& t) : ASTCall(t) {}
	Value* compile(CEnv& cenv);
	void   constrain(TEnv& tenv) const;
};

/// Primitive (builtin arithmetic function), e.g. "(+ 2 3)".
/// op is an LLVM opcode; arg is an extra operand such as a comparison
/// predicate (0 when unused).
struct ASTPrimitive : public ASTCall {
	ASTPrimitive(const TupV& t, int o, int a=0)
		: ASTCall(t), op(o), arg(a)
	{}
	Value* compile(CEnv& cenv);
	void   constrain(TEnv& tenv) const;
	unsigned op;
	unsigned arg;
};


/***************************************************************************
 * Parser - S-Expressions (SExp) -> AST Nodes (AST)                        *
 ***************************************************************************/

/// LLVM Operation: an opcode plus an optional extra argument
/// (used for the comparison predicate of comparison primitives)
struct Op {
	Op(int o=0, int a=0) : op(o), arg(a) {}
	int op;
	int arg;
};

typedef Op UD; // User Data argument for parse functions

// Parse Time Environment: interns symbols (exactly one ASTSymbol per name)
// and maps special-form names to the functions that parse them.
struct PEnv : private map<const string, ASTSymbol*> {
	typedef AST* (*PF)(PEnv&, const SExp::List&, UD); // Parse Function
	struct Parser {
		Parser(PF f, UD d) : pf(f), ud(d) {}
		PF pf;
		UD ud;
	};
	map<string, Parser> parsers;
	/// Register p as the parser for forms headed by s (also interns s)
	void reg(const string& s, const Parser& p) {
		ASTSymbol* name = sym(s);
		parsers.insert(make_pair(name->str(), p));
	}
	/// Parser registered for s, or NULL if s is not a special form
	const Parser* parser(const string& s) const {
		map<string, Parser>::const_iterator found = parsers.find(s);
		if (found == parsers.end())
			return NULL;
		return &found->second;
	}
	/// The unique symbol named s, created on first use
	ASTSymbol* sym(const string& s) {
		const_iterator found = find(s);
		if (found != end())
			return found->second;
		return insert(make_pair(s, new ASTSymbol(s))).first->second;
	}
};

/// The fundamental parser method
static AST* parseExpression(PEnv& penv, const SExp& exp);

/// Parse every element of l, in order, into a TupV
static TupV
pmap(PEnv& penv, const SExp::List& l)
{
	TupV ret;
	ret.reserve(l.size());
	for (SExp::List::const_iterator e = l.begin(); e != l.end(); ++e)
		ret.push_back(parseExpression(penv, *e));
	return ret;
}

/// Convert an S-expression into an AST node.
///
/// Lists headed by a registered special form go to that form's parser; other
/// lists become calls.  Atoms starting with a digit become Int literals, or
/// Float literals if they contain '.'; every other atom is an interned symbol.
///
/// @throws SyntaxError for the empty list "()".
static AST*
parseExpression(PEnv& penv, const SExp& exp)
{
	if (exp.type == SExp::LIST) {
		if (exp.list.empty()) throw SyntaxError("Call to empty list");
		if (exp.list.front().type == SExp::ATOM) {
			const PEnv::Parser* handler = penv.parser(exp.list.front().atom);
			if (handler) // Dispatch to parse function
				return handler->pf(penv, exp.list, handler->ud);
		}
		return new ASTCall(pmap(penv, exp.list)); // Parse as regular call
	}
	// isdigit() is only defined for values representable as unsigned char (or
	// EOF); a plain char with its high bit set would be undefined behaviour.
	if (!exp.atom.empty() && isdigit(static_cast<unsigned char>(exp.atom[0]))) {
		if (exp.atom.find('.') == string::npos)
			return new ASTLiteral<int32_t>(strtol(exp.atom.c_str(), NULL, 10));
		else
			return new ASTLiteral<float>(strtod(exp.atom.c_str(), NULL));
	}
	return penv.sym(exp.atom);
}

// Special forms

static AST*
parseIf(PEnv& penv, const SExp::List& c, UD)
{
	// The whole form, including the leading "if" symbol, becomes the tuple
	TupV parts = pmap(penv, c);
	return new ASTIf(parts);
}

static AST*
parseDef(PEnv& penv, const SExp::List& c, UD)
{
	// Arity and name are checked later, in ASTDefinition::constrain
	TupV parts = pmap(penv, c);
	return new ASTDefinition(parts);
}

static AST*
parsePrim(PEnv& penv, const SExp::List& c, UD data)
{
	// data carries the LLVM opcode and its extra argument for this primitive
	TupV parts = pmap(penv, c);
	return new ASTPrimitive(parts, data.op, data.arg);
}

/// Parse "(fn (params...) body)" into an ASTClosure.
/// @throws SyntaxError unless there are exactly a prototype list and a body.
static AST*
parseFn(PEnv& penv, const SExp::List& c, UD)
{
	// Check the shape before indexing, so a short form such as "(fn)" cannot
	// read past the end of the list.
	if (c.size() != 3)
		throw SyntaxError("\"fn\" not passed 2 arguments");
	const SExp& prot = c[1];
	if (prot.type != SExp::LIST)
		throw SyntaxError("\"fn\" prototype is not a list");
	return new ASTClosure(
			new ASTTuple(pmap(penv, prot.list)),
			parseExpression(penv, c[2]));
}


/***************************************************************************
 * Generic Lexical Environment                                             *
 ***************************************************************************/

/// Lexically scoped environment: a list of frames, innermost at the front.
/// Starts with one (global) frame.
template<typename K, typename V>
struct Env : public list< map<K,V> > {
	typedef map<K,V> Frame;
	Env() : list<Frame>(1) {}
	/// Open a new, empty innermost scope
	void push_front() { list<Frame>::push_front(Frame()); }
	/// Bind k to v in the innermost scope; rebinding there is an error
	const V& def(const K& k, const V& v) {
		Frame& innermost = this->front();
		if (innermost.count(k))
			throw SyntaxError("Redefinition");
		return (innermost[k] = v);
	}
	/// Innermost binding of name, or 0 if it is unbound in every scope
	V* ref(const K& name) {
		for (typename Env::iterator f = this->begin(); f != this->end(); ++f) {
			typename Frame::iterator hit = f->find(name);
			if (hit != f->end())
				return &hit->second;
		}
		return 0;
	}
};


/***************************************************************************
 * Typing                                                                  *
 ***************************************************************************/

struct TypeError : public Error { TypeError (const char* m) : Error(m) {} };

/// Type substitution (type -> replacement type); optionally seeded with s -> t
struct TSubst : public map<AType*, AType*> {
	TSubst(AType* s=0, AType* t=0) {
		if (s != 0 && t != 0)
			(*this)[s] = t;
	}
};

/// Type-Time Environment: the inferred type of each AST node, the named
/// (built-in) types, and the pending constraints between types.
struct TEnv {
	TEnv(PEnv& p) : penv(p), varID(1) {}
	typedef map<const AST*, AType*>      Types;
	typedef list< pair<AType*, AType*> > Constraints;
	/// A fresh type variable ?N (N counts up from 1)
	AType* var() { return new AType(varID++); }
	/// Type of ast, allocating a fresh type variable on first request
	AType* type(const AST* ast) {
		Types::iterator found = types.find(ast);
		if (found != types.end())
			return found->second;
		AType* fresh = var();
		types[ast] = fresh;
		return fresh;
	}
	/// Type registered under name
	AType* named(const string& name) const {
		Types::const_iterator found = namedTypes.find(penv.sym(name));
		if (found == namedTypes.end())
			throw TypeError("Unknown named type");
		return found->second;
	}
	/// Register a named type backed by the LLVM type `type`
	void name(const string& name, const Type* type) {
		ASTSymbol* sym = penv.sym(name);
		namedTypes[sym] = new AType(sym, type);
	}
	/// Record that o's type must equal t
	void constrain(const AST* o, AType* t) {
		constraints.push_back(make_pair(type(o), t));
	}
	void          solve() { apply(unify(constraints)); }
	void          apply(const TSubst& substs);
	static TSubst unify(const Constraints& c);
	PEnv&       penv;
	Types       types;
	Types       namedTypes;
	Constraints constraints;
	unsigned    varID;
};

/// True iff opcode o lies in the half-open range [t##Begin, t##End),
/// e.g. OP_IS_A(op, Instruction::BinaryOps) tests for a binary operator.
#define OP_IS_A(o, t) (t ## Begin <= (o) && (o) < t ## End)

void
ASTTuple::constrain(TEnv& tenv) const
{
	// A tuple's type is the (non-variable) tuple of its elements' types
	TupV elemTypes;
	for (size_t i = 0; i < tup.size(); ++i) {
		tup[i]->constrain(tenv);
		elemTypes.push_back(tenv.type(tup[i]));
	}
	AType* self = tenv.type(this);
	self->var = false;
	self->tup = elemTypes;
}

void
ASTClosure::constrain(TEnv& tenv) const
{
	prot->constrain(tenv);
	tup[2]->constrain(tenv);
	AType* bodyT = tenv.type(tup[2]);
	// A closure has type (Fn <parameter tuple type> <body type>).  tuple()'s
	// vararg list ends at a null AST*; a bare 0 would be passed as an int,
	// which va_arg(AST*) cannot portably read back as a null pointer.
	tenv.constrain(this, new AType(tuple(
			tenv.penv.sym("Fn"), tenv.type(prot), bodyT, static_cast<AST*>(0))));
}

void
ASTCall::constrain(TEnv& tenv) const
{
	FOREACH(TupV::const_iterator, p, tup)
		(*p)->constrain(tenv);
	AType* retT = tenv.type(this);
	// The callee must be (Fn ?params <type of this call>).  tuple() reads AST*
	// varargs until a null pointer; pass a typed null because NULL may be a
	// plain integer 0 on some implementations.
	// NOTE(review): the parameter type is a fresh variable, so argument types
	// are not tied to the callee's parameter types here.
	TupV texp = tuple(tenv.penv.sym("Fn"), tenv.var(), retT, static_cast<AST*>(0));
	tenv.constrain(tup[0], new AType(texp));
}

void
ASTDefinition::constrain(TEnv& tenv) const
{
	if (tup.size() != 3)
		throw SyntaxError("\"def\" not passed 2 arguments");
	if (dynamic_cast<const ASTSymbol*>(tup[1]) == NULL)
		throw SyntaxError("\"def\" name is not a symbol");

	for (size_t i = 0; i < tup.size(); ++i)
		tup[i]->constrain(tenv);

	// The definition, its name and its value all share a single type
	AType* defT = tenv.type(this);
	tenv.constrain(tup[1], defT);
	tenv.constrain(tup[2], defT);
}

void
ASTIf::constrain(TEnv& tenv) const
{
	// Validate arity first: the constraints below index tup[1..3]
	if (tup.size() != 4)
		throw SyntaxError("\"if\" not passed 3 arguments");
	FOREACH(TupV::const_iterator, p, tup)
		(*p)->constrain(tenv);
	// The condition is Bool; both branches share the type of the whole if
	AType* tvar = tenv.type(this);
	tenv.constrain(tup[1], tenv.named("Bool"));
	tenv.constrain(tup[2], tvar);
	tenv.constrain(tup[3], tvar);
}

void
ASTPrimitive::constrain(TEnv& tenv) const
{
	for (size_t i = 0; i < tup.size(); ++i)
		tup[i]->constrain(tenv);

	// Arithmetic: every operand has the type of the result
	if (OP_IS_A(op, Instruction::BinaryOps)) {
		if (tup.size() <= 1) throw SyntaxError("Primitive call with 0 args");
		AType* resultT = tenv.type(this);
		for (size_t i = 1; i < tup.size(); ++i)
			tenv.constrain(tup[i], resultT);
		return;
	}

	if (op != Instruction::ICmp)
		throw TypeError("Unknown primitive");

	// Comparison: both operands share a type and the result is Bool
	if (tup.size() != 3) throw SyntaxError("Comparison call with != 2 args");
	tenv.constrain(tup[1], tenv.type(tup[2]));
	tenv.constrain(this, tenv.named("Bool"));
}

/// Replace, in place, every element of tup (recursively through nested
/// tuples) that equals `from` with `to`.  Replaced elements are not descended
/// into.  A null tup is ignored.
static void
substitute(ASTTuple* tup, AST* from, AST* to)
{
	if (tup == NULL)
		return;
	for (TupV::iterator it = tup->tup.begin(); it != tup->tup.end(); ++it) {
		if (**it == *from) {
			*it = to;
			continue;
		}
		substitute(dynamic_cast<ASTTuple*>(*it), from, to);
	}
}

/// True if this tuple equals child, or child occurs anywhere inside it
bool
ASTTuple::contains(AST* child) const
{
	if (*this == *child)
		return true;
	for (size_t i = 0; i < tup.size(); ++i) {
		const AST* elem = tup[i];
		if (*elem == *child || elem->contains(child))
			return true;
	}
	return false;
}

/// Compose substitutions: delta after gamma (TAPL 22.1.1)
TSubst
compose(const TSubst& delta, const TSubst& gamma) // TAPL 22.1.1
{
	TSubst result;
	// Every gamma binding, with delta applied to its right-hand side
	for (TSubst::const_iterator g = gamma.begin(); g != gamma.end(); ++g) {
		TSubst::const_iterator d = delta.find(g->second);
		AType* image = (d == delta.end()) ? g->second : d->second;
		result.insert(make_pair(g->first, image));
	}
	// Plus delta's bindings for types that gamma does not bind
	for (TSubst::const_iterator d = delta.begin(); d != delta.end(); ++d)
		if (!gamma.count(d->first))
			result.insert(*d);
	return result;
}

/// Replace s with t throughout every constraint: whole sides equal to s are
/// swapped for t, and occurrences nested inside a side are substituted in place.
void
substConstraints(TEnv::Constraints& constraints, AType* s, AType* t)
{
	FOREACH(TEnv::Constraints::iterator, c, constraints) {
		if (*c->first == *s)  c->first  = t;
		if (*c->second == *s) c->second = t;
		substitute(c->first,  s, t);
		substitute(c->second, s, t);
	}
}

/// Unify a list of type constraints into a substitution (TAPL 22.4).
///
/// Takes the first constraint (s, t) and:
///  - drops it if s and t are already equal;
///  - binds a type variable to the other side when the occurs check
///    (contains) passes, substituting that binding into the remaining
///    constraints and composing it with the rest of the solution;
///  - for two "Fn" types, adds constraints equating their parameter types
///    (element 1) and their return types (element 2);
///  - otherwise fails.
/// Recurses once per processed constraint.
///
/// NOTE: substConstraints rewrites the element lists of the AType objects
/// referenced by cp in place; those objects may also be referenced from
/// TEnv::types.
///
/// @throws TypeError if the constraints cannot be unified.
TSubst
TEnv::unify(const Constraints& constraints) // TAPL 22.4
{
	if (constraints.empty()) return TSubst();
	AType*      s  = constraints.begin()->first;
	AType*      t  = constraints.begin()->second;
	Constraints cp = constraints; // Remaining constraints, to be rewritten
	cp.erase(cp.begin());

	if (*s == *t) {
		return unify(cp);
	} else if (s->var && !t->contains(s)) {
		substConstraints(cp, s, t);
		return compose(unify(cp), TSubst(s, t));
	} else if (t->var && !s->contains(t)) {
		substConstraints(cp, t, s);
		return compose(unify(cp), TSubst(t, s));
	} else if (s->isForm("Fn") && t->isForm("Fn")) {
		// (Fn P1 R1) = (Fn P2 R2)  =>  P1 = P2 and R1 = R2
		AType* s1 = dynamic_cast<AType*>(s->tup[1]);
		AType* t1 = dynamic_cast<AType*>(t->tup[1]);
		AType* s2 = dynamic_cast<AType*>(s->tup[2]);
		AType* t2 = dynamic_cast<AType*>(t->tup[2]);
		assert(s1 && t1 && s2 && t2);
		cp.push_back(make_pair(s1, t1));
		cp.push_back(make_pair(s2, t2));
		return unify(cp);
	} else {
		throw TypeError("Type unification failed");
	}
}

void
TEnv::apply(const TSubst& substs)
{
	FOREACH(TSubst::const_iterator, s, substs)
		FOREACH(Types::iterator, t, types)
			if (*t->second == *s->first)
				t->second = s->second;
}


/***************************************************************************
 * Code Generation                                                         *
 ***************************************************************************/

/// Error raised while generating LLVM code
// (The stray "class PEnv;" forward declaration that followed was removed: PEnv
// is already fully defined above as a struct, and redeclaring it with the
// "class" key only triggers mismatched-tag warnings.)
struct CompileError : public Error { CompileError(const char* m) : Error(m) {} };

/// Compile-Time Environment
///
/// Bundles what code generation needs: the parse and type environments, an
/// IRBuilder, the target module and a per-function optimisation pipeline.
/// `code` maps AST nodes (symbols) to the AST bound to them, and `vals`
/// memoises the LLVM Value compiled for each node; both are scoped
/// environments that push()/pop() open and close together.
struct CEnv {
	// NOTE: emp is initialised from the `module` member and fpm from &emp,
	// which relies on the member declaration order below.
	CEnv(PEnv& p, Module* m, const TargetData* target)
		: penv(p), tenv(p), module(m), emp(module), fpm(&emp), symID(0)
	{
		// Set up the optimizer pipeline:
		fpm.add(new TargetData(*target)); // Register target arch
		fpm.add(createInstructionCombiningPass()); // Simple optimizations
		fpm.add(createReassociatePass()); // Reassociate expressions
		fpm.add(createGVNPass()); // Eliminate common subexpressions (global value numbering)
		fpm.add(createCFGSimplificationPass()); // Simplify control flow
	}
	/// A fresh name: base followed by a per-environment counter
	string gensym(const char* base="_") {
		ostringstream s; s << base << symID++; return s.str();
	}
	/// Open / close a scope in both `code` and `vals`
	void push() { code.push_front(); vals.push_front(); }
	void pop()  { code.pop_front();  vals.pop_front(); }
	/// Value for obj: the memoised one if visible, otherwise compile obj and
	/// record the result in the innermost scope
	Value* compile(AST* obj) {
		Value** v = vals.ref(obj);
		return (v) ? *v : vals.def(obj, obj->compile(*this));
	}
	/// Record value as obj's compiled Value before obj itself is compiled
	/// (lets a function body refer to its own function)
	void precompile(AST* obj, Value* value) {
		assert(!vals.ref(obj));
		vals.def(obj, value);
	}
	typedef Env<const AST*, AST*>   Code;
	typedef Env<const AST*, Value*> Vals;
	PEnv&                  penv;
	TEnv                   tenv;
	IRBuilder<>            builder;
	Module*                module;
	ExistingModuleProvider emp;
	FunctionPassManager    fpm;
	unsigned               symID;
	Code                   code;
	Vals                   vals;
};

// Specialise ASTLiteral<CT>: compile() yields the constant expression
// COMPILED, and constrain() gives the literal the named type NAME.
#define LITERAL(CT, NAME, COMPILED) \
template<> Value* \
ASTLiteral<CT>::compile(CEnv& cenv) { return (COMPILED); } \
template<> void \
ASTLiteral<CT>::constrain(TEnv& tenv) const { tenv.constrain(this, tenv.named(NAME)); }

/// Literal template instantiations (Int is signed 32-bit, Bool is i1)
LITERAL(int32_t, "Int",   ConstantInt::get(Type::Int32Ty, val, true));
LITERAL(float,   "Float", ConstantFP::get(Type::FloatTy, val));
LITERAL(bool,    "Bool",  ConstantInt::get(Type::Int1Ty, val, false));

/// Declare an externally-linked function `name` in the module whose parameter
/// types come from prot's inferred tuple type and whose return type is retT.
/// Arguments are named after prot's elements.
/// @throws CompileError if any parameter or the return has no concrete LLVM
///         type, or if `name` is already taken.
static Function*
compileFunction(CEnv& cenv, const std::string& name, ASTTuple& prot, const Type* retT)
{
	// Concrete LLVM type of each parameter, in order
	const TupV& paramTypes = cenv.tenv.type(&prot)->tup;
	vector<const Type*> cprot;
	for (TupV::const_iterator p = paramTypes.begin(); p != paramTypes.end(); ++p) {
		const AType* at = dynamic_cast<AType*>(*p);
		assert(at);
		if (!at->ctype)
			throw CompileError("Parameter is untyped");
		cprot.push_back(at->ctype);
	}

	if (!retT)
		throw CompileError("Return is untyped");

	FunctionType* fT = FunctionType::get(retT, cprot, false);
	Function*     f  = Function::Create(fT, Function::ExternalLinkage, name, cenv.module);

	// A different resulting name means the requested one was already in use
	if (f->getName() != name) {
		f->eraseFromParent();
		throw CompileError("Function redefined");
	}

	// Set argument names in generated code
	size_t i = 0;
	for (Function::arg_iterator a = f->arg_begin(); i != prot.tup.size(); ++a, ++i)
		a->setName(prot.tup[i]->str());

	return f;
}

/// Value of this symbol: a memoised Value if one is visible, otherwise the
/// compiled code bound to the symbol (which is then memoised).
/// @throws SyntaxError if the symbol is unbound.
Value*
ASTSymbol::compile(CEnv& cenv)
{
	if (Value** cached = cenv.vals.ref(this))
		return *cached;

	AST** bound = cenv.code.ref(this);
	if (!bound)
		throw SyntaxError((string("Undefined symbol '") + cppstr + "'").c_str());

	Value* compiled = cenv.compile(*bound);
	cenv.vals.def(this, compiled);
	return compiled;
}

/// Emit this closure as an LLVM function, if its types are fully known.
///
/// Opens a scope binding each parameter symbol to its LLVM argument and the
/// closure itself to the new function (so the body can recurse), compiles,
/// verifies and optimises the body, then closes the scope.  On any failure the
/// partially built function is removed, the scope is closed, and the original
/// exception propagates unchanged.
void
ASTClosure::lift(CEnv& cenv)
{
	// Can't lift a closure with variable types (lift later when called)
	if (cenv.tenv.type(tup[2])->var) return;
	for (size_t i = 0; i < prot->tup.size(); ++i)
		if (cenv.tenv.type(prot->tup[i])->var)
			return;

	assert(!func);
	cenv.push();

	Function* f = NULL;
	try {
		// Write function declaration
		f = compileFunction(cenv, cenv.gensym("_fn"), *prot, cenv.tenv.type(tup[2])->ctype);
		BasicBlock* bb = BasicBlock::Create("entry", f);
		cenv.builder.SetInsertPoint(bb);

		// Bind argument values in CEnv
		TupV::const_iterator p = prot->tup.begin();
		for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p)
			cenv.vals.def(dynamic_cast<ASTSymbol*>(*p), &*a);

		// Write function body
		cenv.precompile(this, f); // Define our value first for recursion
		Value* retVal = cenv.compile(tup[2]);
		cenv.builder.CreateRet(retVal); // Finish function
		verifyFunction(*f); // Validate generated code
		cenv.fpm.run(*f); // Optimize function
	} catch (...) {
		// Rethrow with "throw;" so Error subclasses are not sliced to
		// std::exception (which the REPL would fail to catch).
		if (f)
			f->eraseFromParent(); // Error reading body, remove function
		cenv.pop();
		throw;
	}

	func = f;
	cenv.pop();
}

/// The LLVM function generated by lift().
/// @throws CompileError if lift() did not generate it, e.g. because the
///         parameter or body types were still type variables.
Value*
ASTClosure::compile(CEnv& /*cenv*/)
{
	// Report an error instead of asserting, so one unresolved closure does not
	// abort the whole REPL.
	if (!func)
		throw CompileError("Closure could not be compiled (untyped parameters or body)");
	return func; // Function was already compiled in the lifting pass
}

/// Lift the arguments and, when the callee resolves to a closure, lift it in
/// a scope where its parameters are bound to this call's argument ASTs.
/// @throws CompileError if the argument count does not match the closure.
void
ASTCall::lift(CEnv& cenv)
{
	// Resolve the callee: a literal closure, or a symbol bound to one
	ASTClosure* c = dynamic_cast<ASTClosure*>(tup[0]);
	if (!c) {
		AST** val = cenv.code.ref(tup[0]);
		c = (val) ? dynamic_cast<ASTClosure*>(*val) : c;
	}

	// Lift arguments
	for (size_t i = 1; i < tup.size(); ++i)
		tup[i]->lift(cenv);

	if (!c) return;

	// Check arity before opening a scope, so a mismatch cannot leave an
	// extra frame on the environment
	if (c->prot->tup.size() != tup.size() - 1)
		throw CompileError("Call to closure with mismatched arguments");

	// Extend environment with bound and typed parameters
	cenv.push();
	try {
		for (size_t i = 1; i < tup.size(); ++i)
			cenv.code.def(c->prot->tup[i-1], tup[i]);
		// NOTE(review): when the callee is a symbol this calls the no-op
		// AST::lift, so the bound closure itself is not lifted here.
		tup[0]->lift(cenv); // Lift called closure
	} catch (...) {
		cenv.pop(); // Restore environment before propagating the error
		throw;
	}
	cenv.pop(); // Restore environment
}

/// Emit a call to the closure named by tup[0] with the compiled arguments.
/// @throws CompileError if the callee is not a closure or did not compile to
///         an LLVM function.
Value*
ASTCall::compile(CEnv& cenv)
{
	// Resolve the callee: a literal closure, or a symbol bound to one
	ASTClosure* closure = dynamic_cast<ASTClosure*>(tup[0]);
	if (!closure) {
		AST** bound = cenv.code.ref(tup[0]);
		if (bound)
			closure = dynamic_cast<ASTClosure*>(*bound);
	}
	if (!closure)
		throw CompileError("Call to non-closure");

	Value* callee = cenv.compile(closure);
	if (!callee)
		throw CompileError("Callee failed to compile");
	Function* fn = dynamic_cast<Function*>(callee);
	if (!fn)
		throw CompileError("Callee compiled to non-function");

	vector<Value*> args;
	for (size_t i = 1; i < tup.size(); ++i)
		args.push_back(cenv.compile(tup[i]));

	return cenv.builder.CreateCall(fn, args.begin(), args.end(), "calltmp");
}

/// Bind the defined name to its value AST, then lift the value.
void
ASTDefinition::lift(CEnv& cenv)
{
	AST* name  = tup[1];
	AST* value = tup[2];
	cenv.code.def(name, value); // Bind before lifting so the value can refer to itself
	value->lift(cenv);
}

/// A definition evaluates to its value.  The name was already bound to the
/// value's AST in ASTDefinition::lift, so only the value is compiled here.
Value*
ASTDefinition::compile(CEnv& cenv)
{
	return cenv.compile(tup[2]);
}

/// Compile "(if cond then else)": branch on the condition into "then" and
/// "else" blocks that both jump to "ifcont", where a phi node of the if's
/// inferred LLVM type selects the result.
Value*
ASTIf::compile(CEnv& cenv)
{
	Value*    condV  = cenv.compile(tup[1]);
	Function* parent = cenv.builder.GetInsertBlock()->getParent();

	// Create blocks for the then and else cases.
	// Insert the 'then' block at the end of the function.
	BasicBlock* thenBB  = BasicBlock::Create("then", parent);
	BasicBlock* elseBB  = BasicBlock::Create("else");
	BasicBlock* mergeBB = BasicBlock::Create("ifcont");

	cenv.builder.CreateCondBr(condV, thenBB, elseBB);

	// Emit then block
	cenv.builder.SetInsertPoint(thenBB);
	Value* thenV = cenv.compile(tup[2]); // Can change current block, so...
	cenv.builder.CreateBr(mergeBB);
	thenBB = cenv.builder.GetInsertBlock(); // ... update thenBB afterwards

	// Emit else block
	parent->getBasicBlockList().push_back(elseBB);
	cenv.builder.SetInsertPoint(elseBB);
	Value* elseV = cenv.compile(tup[3]); // Can change current block, so...
	cenv.builder.CreateBr(mergeBB);
	elseBB = cenv.builder.GetInsertBlock(); // ... update elseBB afterwards

	// Emit merge block
	parent->getBasicBlockList().push_back(mergeBB);
	cenv.builder.SetInsertPoint(mergeBB);
	PHINode* pn = cenv.builder.CreatePHI(cenv.tenv.type(this)->ctype, "iftmp");

	// Incoming edges come from the blocks where each branch actually ended
	pn->addIncoming(thenV, thenBB);
	pn->addIncoming(elseV, elseBB);
	return pn;
}

/// Compile a primitive call.
///
/// Binary operators fold left over their operands: (op a b c) => ((a op b) op c).
/// Comparisons use icmp for integer-typed operands (Int, Bool) and the ordered
/// fcmp equivalent of the predicate for Float operands.
///
/// @throws SyntaxError with fewer than two operands.
/// @throws CompileError for an unknown primitive or comparison predicate.
Value*
ASTPrimitive::compile(CEnv& cenv)
{
	// NOTE: constrain() accepts a single-operand arithmetic call such as
	// "(+ 1)", but code generation requires at least two operands.
	if (tup.size() < 3) throw SyntaxError("Too few arguments");
	Value* a = cenv.compile(tup[1]);
	Value* b = cenv.compile(tup[2]);

	if (OP_IS_A(op, Instruction::BinaryOps)) {
		const Instruction::BinaryOps bo = (Instruction::BinaryOps)op;
		Value* val = cenv.builder.CreateBinOp(bo, a, b);
		for (size_t i = 3; i < tup.size(); ++i)
			val = cenv.builder.CreateBinOp(bo, val, cenv.compile(tup[i]));
		return val;
	} else if (op == Instruction::ICmp) {
		// Only Float operands need a floating-point compare; Bool (i1) is an
		// integer like Int and must use icmp as well.
		const bool isFloat = cenv.tenv.type(tup[1])->ctype == Type::FloatTy;
		if (!isFloat)
			return cenv.builder.CreateICmp((CmpInst::Predicate)arg, a, b);

		// Translate to the floating point predicate in a local, leaving this
		// node's integer predicate (arg) intact so compiling it again works.
		CmpInst::Predicate pred;
		switch (arg) {
		case CmpInst::ICMP_EQ:  pred = CmpInst::FCMP_OEQ; break;
		case CmpInst::ICMP_NE:  pred = CmpInst::FCMP_ONE; break;
		case CmpInst::ICMP_SGT: pred = CmpInst::FCMP_OGT; break;
		case CmpInst::ICMP_SGE: pred = CmpInst::FCMP_OGE; break;
		case CmpInst::ICMP_SLT: pred = CmpInst::FCMP_OLT; break;
		case CmpInst::ICMP_SLE: pred = CmpInst::FCMP_OLE; break;
		default: throw CompileError("Unknown primitive");
		}
		return cenv.builder.CreateFCmp(pred, a, b);
	}
	throw CompileError("Unknown primitive");
}


/***************************************************************************
 * REPL                                                                    *
 ***************************************************************************/

/// Read-eval-print loop: reads expressions from stdin until end of input,
/// type-checks, compiles and (for Int/Float/Bool results) runs each one,
/// then dumps the generated module.
int
main()
{
#define PRIM(O, A) PEnv::Parser(parsePrim, Op(Instruction:: O, A))
	PEnv penv;
	penv.reg("fn",  PEnv::Parser(parseFn,   Op()));
	penv.reg("if",  PEnv::Parser(parseIf,   Op()));
	penv.reg("def", PEnv::Parser(parseDef,  Op()));
	penv.reg("+",   PRIM(Add,  0));
	penv.reg("-",   PRIM(Sub,  0));
	penv.reg("*",   PRIM(Mul,  0));
	penv.reg("/",   PRIM(FDiv, 0));
	penv.reg("%",   PRIM(FRem, 0));
	penv.reg("&",   PRIM(And,  0));
	penv.reg("|",   PRIM(Or,   0));
	penv.reg("^",   PRIM(Xor,  0));
	penv.reg("=",   PRIM(ICmp, CmpInst::ICMP_EQ));
	penv.reg("!=",  PRIM(ICmp, CmpInst::ICMP_NE));
	penv.reg(">",   PRIM(ICmp, CmpInst::ICMP_SGT));
	penv.reg(">=",  PRIM(ICmp, CmpInst::ICMP_SGE));
	penv.reg("<",   PRIM(ICmp, CmpInst::ICMP_SLT));
	penv.reg("<=",  PRIM(ICmp, CmpInst::ICMP_SLE));

	Module*          module = new Module("repl");
	ExecutionEngine* engine = ExecutionEngine::create(module);
	CEnv             cenv(penv, module, engine->getTargetData());

	cenv.tenv.name("Bool",  Type::Int1Ty);
	cenv.tenv.name("Int",   Type::Int32Ty);
	cenv.tenv.name("Float", Type::FloatTy);
	cenv.code.def(penv.sym("true"),  new ASTLiteral<bool>(true));
	cenv.code.def(penv.sym("false"), new ASTLiteral<bool>(false));

	while (1) {
		std::cout << "() ";
		std::cout.flush();

		// Reader errors (e.g. an unmatched ')') are reported like any other
		// error instead of escaping main and terminating the REPL.
		SExp exp;
		try {
			exp = readExpression(std::cin);
		} catch (const Error& e) {
			std::cerr << "Error: " << e.what() << endl;
			continue;
		}
		if (exp.type == SExp::LIST && exp.list.empty())
			break; // End of input

		// Constraints accumulate across inputs and are all re-solved on every
		// solve(), so constraints added by an input that fails to type-check
		// must be discarded, or every later input would fail too.
		const size_t nConstraints = cenv.tenv.constraints.size();
		bool         solved       = false;

		try {
			AST* body = parseExpression(penv, exp); // Parse input
			body->constrain(cenv.tenv); // Constrain types
			cenv.tenv.solve(); // Solve and apply type constraints
			solved = true;

			AType* bodyT = cenv.tenv.type(body);
			if (!bodyT)     throw TypeError("REPL call to untyped body");
			if (bodyT->var) throw TypeError("REPL call to variable typed body");

			body->lift(cenv);

			if (bodyT->ctype) {
				// Create anonymous function to insert code into.
				ASTTuple* prot = new ASTTuple();
				Function* f = compileFunction(cenv, cenv.gensym("_repl"), *prot, bodyT->ctype);
				BasicBlock* bb = BasicBlock::Create("entry", f);
				cenv.builder.SetInsertPoint(bb);
				try {
					Value* retVal = cenv.compile(body);
					cenv.builder.CreateRet(retVal); // Finish function
					verifyFunction(*f); // Validate generated code
					cenv.fpm.run(*f); // Optimize function
				} catch (...) {
					// Remove the partial function whatever failed, then rethrow
					// the original exception rather than a copy
					f->eraseFromParent();
					throw;
				}
				void* fp = engine->getPointerToFunction(f);
				if (bodyT->ctype == Type::Int32Ty)
					std::cout << ";  " << ((int32_t (*)())fp)();
				else if (bodyT->ctype == Type::FloatTy)
					std::cout << ";  " << ((float (*)())fp)();
				else if (bodyT->ctype == Type::Int1Ty)
					std::cout << ";  " << ((bool (*)())fp)();
			} else {
				Value* val = cenv.compile(body);
				std::cout << ";   " << val;
			}
			std::cout << " : " << cenv.tenv.type(body)->str() << endl;

		} catch (const Error& e) {
			if (!solved)
				while (cenv.tenv.constraints.size() > nConstraints)
					cenv.tenv.constraints.pop_back();
			std::cerr << "Error: " << e.what() << endl;
		}
	}

	std::cout << endl << "Generated code:" << endl;
	module->dump();
	return 0;
}