/* Tuplr: A programming language
 * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net>
 *
 * Tuplr is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Tuplr is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Tuplr.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <map>
#include <sstream>
#include <boost/format.hpp>
#include "llvm/Analysis/Verifier.h"
#include "llvm/Assembly/AsmAnnotationWriter.h"
#include "llvm/DerivedTypes.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/Instructions.h"
#include "llvm/Module.h"
#include "llvm/ModuleProvider.h"
#include "llvm/PassManager.h"
#include "llvm/Support/IRBuilder.h"
#include "llvm/Target/TargetData.h"
#include "llvm/Transforms/Scalar.h"
#include "tuplr.hpp"

using namespace llvm;
using namespace std;
using boost::format;

/// Convert an opaque compiled-value handle to llvm::Value* (unchecked static_cast).
inline Value*
LLVal(CValue v)
{
	return static_cast<Value*>(v);
}

/// Convert an opaque compiled-function handle to llvm::Function* (unchecked static_cast).
inline Function*
LLFunc(CFunction f)
{
	return static_cast<Function*>(f);
}

/// LLVM state behind a compilation environment: the module code is emitted
/// into, the execution engine used to obtain native function pointers, and
/// the shared IR builder.
/// NOTE: member order matters -- the constructor initialises `engine' from
/// `module', so `module' must stay declared before `engine'.
struct LLVMEngine {
	LLVMEngine();
	Module*          module;  ///< Module all generated functions are added to
	ExecutionEngine* engine;  ///< Created over `module'
	IRBuilder<>      builder; ///< Builder shared by all code generation
};

/** Return the LLVM type used to represent the Tuplr type @a t.
 *
 * - Type variables are not compilable: an Error located at @a t is thrown.
 * - Primitive types map to i1 (Bool), i32 (Int) and float (Float); any other
 *   primitive name throws an Error located at @a t.
 * - A `Pair' expression type maps to a pointer to an unpacked struct of its
 *   element types (converted recursively).
 * - Anything else yields NULL, meaning "not compilable".
 */
static const Type*
lltype(const AType* t)
{
	switch (t->kind) {
	case AType::VAR:
		throw Error((format("non-compilable type `%1%'") % t->str()).str(), t->loc);
	case AType::PRIM: {
		const string prim = t->at(0)->str();
		if (prim == "Bool")  return Type::Int1Ty;
		if (prim == "Int")   return Type::Int32Ty;
		if (prim == "Float") return Type::FloatTy;
		// Report the location, consistent with the VAR case above
		throw Error(string("Unknown primitive type `") + t->str() + "'", t->loc);
	}
	case AType::EXPR:
		if (t->at(0)->str() == "Pair") {
			vector<const Type*> types;
			for (size_t i = 1; i < t->size(); ++i)
				types.push_back(lltype(t->at(i)->to<AType*>()));
			return PointerType::get(StructType::get(types, false), 0);
		}
		break;
	}
	return NULL; // Not a compilable type
}

/// Recover the LLVMEngine behind a compilation environment's opaque engine handle.
static LLVMEngine*
llengine(CEnv& cenv)
{
	CEngine handle = cenv.engine();
	return reinterpret_cast<LLVMEngine*>(handle);
}

/// Create a new module named "tuplr" and an execution engine over it;
/// the builder is default-constructed.
LLVMEngine::LLVMEngine()
	: module(new Module("tuplr"))
	, engine(NULL)
{
	engine = ExecutionEngine::create(module);
}

/// Private LLVM-specific state of a CEnv.
struct CEnv::PImpl {
	/// Wrap engine @a e and build a function pass manager over its module.
	/// NOTE: members are initialised in declaration order; `emp' uses
	/// `module' and `opt' uses `emp', so that order must be preserved.
	/// Assumes e->engine is non-NULL (getTargetData() is called on it).
	PImpl(LLVMEngine* e) : engine(e), module(e->module), emp(module), opt(&emp)
	{
		// Set up the optimizer pipeline:
		const TargetData* target = engine->engine->getTargetData();
		opt.add(new TargetData(*target));          // Register target arch
		opt.add(createInstructionCombiningPass()); // Simple optimizations
		opt.add(createReassociatePass());          // Reassociate expressions
		opt.add(createGVNPass());                  // Eliminate Common Subexpressions
		opt.add(createCFGSimplificationPass());    // Simplify control flow
	}

	LLVMEngine*            engine;  ///< Not deleted by PImpl (no destructor here)
	Module*                module;  ///< Same module as engine->module
	ExistingModuleProvider emp;     ///< Module provider handed to `opt'
	FunctionPassManager    opt;     ///< Per-function optimisation pipeline
};

/// Construct a compilation environment over parse/type environments @a p and
/// @a t, writing results to @a os and errors to @a es.  @a e must be an
/// LLVMEngine handle (see newCenv()).
CEnv::CEnv(PEnv& p, TEnv& t, CEngine e, ostream& os, ostream& es)
	: out(os)
	, err(es)
	, penv(p)
	, tenv(t)
	, symID(0)
	, alloc(0)
	, _pimpl(new PImpl(static_cast<LLVMEngine*>(e)))
{
}

/// Release the private implementation (the LLVMEngine itself is not deleted here).
CEnv::~CEnv()
{
	PImpl* const impl = _pimpl;
	delete impl;
}

/// Return this environment's LLVMEngine as an opaque handle (see llengine()).
CEngine
CEnv::engine()
{
	LLVMEngine* const eng = _pimpl->engine;
	return eng;
}

/** Compile @a obj, memoising the result in `vals'.
 *
 * A previously recorded non-NULL value is returned directly; otherwise
 * obj->compile() is invoked and its result recorded and returned.
 */
CValue
CEnv::compile(AST* obj)
{
	CValue* cached = vals.ref(obj);
	if (cached && *cached)
		return *cached;

	return vals.def(obj, obj->compile(*this));
}

/// Run the LLVM verifier on @a f, then the function-level optimisation passes.
void
CEnv::optimise(CFunction f)
{
	Function& fn = *LLFunc(f);
	verifyFunction(fn);
	_pimpl->opt.run(fn);
}

void
CEnv::write(std::ostream& os)
{
	AssemblyAnnotationWriter writer;
	_pimpl->engine->module->print(os, &writer);
}

/** Define ALiteral<CTYPE>::compile and ALiteral<CTYPE>::constrain.
 *
 * CTYPE   host C++ type of the literal's `val'
 * TNAME   name of the Tuplr type the literal is constrained to
 * LLEXPR  expression (in terms of `val') producing the LLVM constant
 */
#define LITERAL(CTYPE, TNAME, LLEXPR) \
template<> CValue \
ALiteral<CTYPE>::compile(CEnv& cenv) \
{ \
	return (LLEXPR); \
} \
template<> void \
ALiteral<CTYPE>::constrain(TEnv& tenv, Constraints& c) const \
{ \
	c.constrain(tenv, this, tenv.named(TNAME)); \
}

/// Literal template specialisations
LITERAL(int32_t, "Int",   ConstantInt::get(Type::Int32Ty, val, true))
LITERAL(float,   "Float", ConstantFP::get(Type::FloatTy, val))
LITERAL(bool,    "Bool",  ConstantInt::get(Type::Int1Ty, val, false))

/** Create an empty LLVM function and position the builder at its entry block.
 *
 * @param name      Symbol name; an Error is thrown if it is already taken.
 * @param retT      LLVM return type; an Error is thrown if NULL.
 * @param protT     Tuplr parameter types, each converted with lltype().
 * @param argNames  Optional names for the generated arguments.  When
 *                  non-empty it must provide exactly one name per parameter.
 * @return The new function, already added to the engine's module.
 */
static Function*
compileFunction(CEnv& cenv, const std::string& name, const Type* retT, const ATuple& protT,
		const vector<string>& argNames=vector<string>())
{
	Function::LinkageTypes linkage = Function::ExternalLinkage;

	vector<const Type*> cprot;
	for (size_t i = 0; i < protT.size(); ++i) {
		const Type* llT = lltype(protT.at(i)->as<AType*>());
		if (!llT) throw Error("function parameter is untyped");
		cprot.push_back(llT);
	}

	if (!retT) throw Error("function return is untyped");

	// Validate before touching the module so a mismatch cannot leak a function
	if (!argNames.empty() && argNames.size() != protT.size())
		throw Error("function argument name count does not match parameter count");

	FunctionType* fT = FunctionType::get(retT, cprot, false);
	Function*     f  = Function::Create(fT, linkage, name, llengine(cenv)->module);

	// LLVM renames on collision, so a different name means `name' was taken
	if (f->getName() != name) {
		f->eraseFromParent();
		throw Error("function redefined");
	}

	// Set argument names in generated code
	if (!argNames.empty()) {
		Function::arg_iterator a = f->arg_begin();
		for (size_t i = 0; i < argNames.size(); ++i, ++a)
			a->setName(argNames[i]);
	}

	BasicBlock* bb = BasicBlock::Create("entry", f);
	llengine(cenv)->builder.SetInsertPoint(bb);

	return f;
}


/***************************************************************************
 * Code Generation                                                         *
 ***************************************************************************/

/** Compile a symbol reference to the value currently bound to it in `vals'.
 *
 * Returns NULL if the symbol has no binding.  When reached via
 * CEnv::compile(), the binding is absent or NULL (a stub), so NULL results.
 */
CValue
ASymbol::compile(CEnv& cenv)
{
	// vals.ref() yields a pointer to the binding slot (cf. CEnv::compile);
	// return the bound value, not the slot's address.
	CValue* bound = cenv.vals.ref(this);
	if (bound)
		return *bound;
	return NULL;
}

/// Lift this closure at its own inferred type, unless that type is not
/// concrete or an instance for it has already been compiled.
void
AClosure::lift(CEnv& cenv)
{
	AType* const fnT = cenv.type(this);
	if (funcs.find(fnT))
		return;
	if (!fnT->concrete())
		return;

	// Element 1 of a function type is the tuple of parameter types
	ATuple* const paramTypes = fnT->at(1)->as<ATuple*>();
	vector<AType*> argTypes;
	for (size_t idx = 0; idx < paramTypes->size(); ++idx)
		argTypes.push_back(paramTypes->at(idx)->as<AType*>());

	liftCall(cenv, argTypes);
}

/** Compile an instance of this closure for argument types @a argsT.
 *
 * If the closure's generic type is not concrete, each generic parameter type
 * is substituted by the corresponding entry of @a argsT.  An Error is thrown
 * if the result is still not concrete.  Nothing is emitted if an instance of
 * the resulting type is already in `funcs'.  Otherwise an LLVM function is
 * created, parameters are bound in a new CEnv scope, and body elements 2..n
 * are compiled, with the last value returned.
 *
 * If an Error occurs while binding or compiling the body, the partial
 * function is erased, `cenv.tsubst' and the scope are restored, and the
 * original Error is rethrown.
 */
void
AClosure::liftCall(CEnv& cenv, const vector<AType*>& argsT)
{
	TEnv::GenericTypes::const_iterator gt = cenv.tenv.genericTypes.find(this);
	assert(gt != cenv.tenv.genericTypes.end());
	AType* genericType = new AType(*gt->second);

	AType* thisType = genericType;
	Subst  argsSubst;
	if (!thisType->concrete()) {
		// Find type and build substitution
		assert(argsT.size() == prot()->size());
		ATuple* genericProtT = genericType->at(1)->as<ATuple*>();
		for (size_t i = 0; i < argsT.size(); ++i)
			argsSubst[*genericProtT->at(i)->to<AType*>()] = argsT.at(i)->to<AType*>();
		thisType = argsSubst.apply(genericType)->as<AType*>();
		if (!thisType->concrete())
			throw Error("unable to resolve concrete type for function", loc);
	}

	if (funcs.find(thisType))
		return;

	ATuple* protT = thisType->at(1)->as<ATuple*>();

	// Write function declaration
	string name = this->name == "" ? cenv.gensym("_fn") : this->name;
	Function* f = compileFunction(cenv, name,
			lltype(thisType->at(thisType->size()-1)->to<AType*>()),
			*protT);

	cenv.push();
	Subst oldSubst = cenv.tsubst;
	cenv.tsubst = Subst::compose(cenv.tsubst, Subst::compose(argsSubst, *subst));

	try {
		// Bind argument values in CEnv
		const_iterator p = prot()->begin();
		size_t i = 0;
		for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p)
			cenv.def((*p)->as<ASymbol*>(), *p, protT->at(i++)->as<AType*>(), &*a);

		// Define value first for recursion
		cenv.precompile(this, f);
		// NOTE(review): if the body fails below, this entry refers to an
		// erased function; removal depends on the container type of `funcs'.
		funcs.push_back(make_pair(thisType, f));

		// Write function body
		CValue retVal = NULL;
		for (size_t j = 2; j < size(); ++j)
			retVal = cenv.compile(at(j));
		llengine(cenv)->builder.CreateRet(LLVal(retVal)); // Finish function
		cenv.optimise(LLFunc(f));
	} catch (Error&) {
		f->eraseFromParent(); // Error reading body, remove function
		cenv.tsubst = oldSubst;
		cenv.pop();
		throw; // Rethrow the original object (no copy/slicing)
	}
	cenv.tsubst = oldSubst;
	cenv.pop();
}

/// A closure expression compiles to no value: its code is emitted by
/// lift()/liftCall() instead.
CValue
AClosure::compile(CEnv& cenv)
{
	CValue nothing = NULL;
	return nothing;
}

/** Lift every argument of this call, then, if the callee resolves to a
 * closure, check the arity and lift the closure for the argument types.
 */
void
ACall::lift(CEnv& cenv)
{
	AClosure* const callee = cenv.tenv.resolve(at(0))->to<AClosure*>();

	// Lift arguments, recording each one's type
	vector<AType*> argTypes;
	for (size_t k = 1; k < size(); ++k) {
		AST* const arg = at(k);
		arg->lift(cenv);
		argTypes.push_back(cenv.type(arg));
	}

	if (callee == NULL)
		return; // Primitive

	const size_t nParams = callee->prot()->size();
	const size_t nArgs   = size() - 1;
	if (nParams < nArgs)
		throw Error((format("too many arguments to function `%1%'") % at(0)->str()).str(), loc);
	if (nParams > nArgs)
		throw Error((format("too few arguments to function `%1%'") % at(0)->str()).str(), loc);

	callee->liftCall(cenv, argTypes); // Lift called closure
}

/** Emit a call to the closure instance matching this call site's type.
 *
 * Returns NULL if the callee is not a closure (a primitive).  Throws an Error
 * if no instance of Fn (argTypes...) returnType was compiled.
 */
CValue
ACall::compile(CEnv& cenv)
{
	AClosure* const callee = cenv.tenv.resolve(at(0))->to<AClosure*>();
	if (!callee)
		return NULL; // Primitive

	AType* argTypes = new AType(loc, NULL);
	for (size_t k = 1; k < size(); ++k)
		argTypes->push_back(cenv.type(at(k)));

	assert(cenv.tenv.genericTypes.find(callee) != cenv.tenv.genericTypes.end());

	AType*    fnT = new AType(loc, cenv.penv.sym("Fn"), argTypes, cenv.type(this), 0);
	Function* fn  = LLFunc(callee->funcs.find(fnT));
	if (!fn)
		throw Error((format("callee failed to compile for type %1%") % fnT->str()).str(), loc);

	vector<Value*> actuals(size() - 1);
	for (size_t k = 1; k < size(); ++k)
		actuals[k - 1] = LLVal(cenv.compile(at(k)));

	return llengine(cenv)->builder.CreateCall(fn, actuals.begin(), actuals.end());
}

/** Bind the defined symbol to a NULL stub (so recursive references resolve),
 * name the value after the symbol if it is a closure, then lift the value.
 */
void
ADefinition::lift(CEnv& cenv)
{
	AST* const value = at(2);
	cenv.def(sym(), value, cenv.type(value), NULL);

	AClosure* const closure = value->to<AClosure*>();
	if (closure != NULL)
		closure->name = sym()->str();

	value->lift(cenv);
}

/// Compile the defined value and bind the symbol to it, returning the value.
CValue
ADefinition::compile(CEnv& cenv)
{
	// Stub binding first, so that recursive references find the symbol
	cenv.def(sym(), at(2), cenv.type(at(2)), NULL);

	CValue compiled = cenv.compile(at(size() - 1));
	cenv.vals.def(sym(), compiled);
	return compiled;
}

/** Compile (if c1 e1 c2 e2 ... else) into a chain of conditional branches.
 *
 * Each condition branches to its "thenN" block or to the next "elseN" block.
 * Every arm branches to "endif", where a phi node ("ifval") selects the value.
 */
CValue
AIf::compile(CEnv& cenv)
{
	typedef vector< pair<Value*, BasicBlock*> > Branches;
	IRBuilder<>& b      = llengine(cenv)->builder;
	Function*    fn     = b.GetInsertBlock()->getParent();
	BasicBlock*  endBB  = BasicBlock::Create("endif");
	BasicBlock*  elseBB = NULL;
	Branches     incoming;

	for (size_t i = 1; i < size() - 1; i += 2) {
		const size_t n = (i + 1) / 2; // 1-based index of this condition

		Value*      cond   = LLVal(cenv.compile(at(i)));
		BasicBlock* thenBB = BasicBlock::Create((format("then%1%") % n).str());
		elseBB = BasicBlock::Create((format("else%1%") % n).str());
		b.CreateCondBr(cond, thenBB, elseBB);

		// Then arm: record the value and the block control leaves from
		fn->getBasicBlockList().push_back(thenBB);
		b.SetInsertPoint(thenBB);
		Value* thenVal = LLVal(cenv.compile(at(i + 1)));
		b.CreateBr(endBB);
		incoming.push_back(make_pair(thenVal, b.GetInsertBlock()));

		// Continue with the next condition in the else block
		fn->getBasicBlockList().push_back(elseBB);
		b.SetInsertPoint(elseBB);
	}

	// Final else arm
	b.SetInsertPoint(elseBB);
	Value* elseVal = LLVal(cenv.compile(at(size() - 1)));
	b.CreateBr(endBB);
	incoming.push_back(make_pair(elseVal, b.GetInsertBlock()));

	// Merge block with the phi node
	fn->getBasicBlockList().push_back(endBB);
	b.SetInsertPoint(endBB);
	PHINode* phi = b.CreatePHI(lltype(cenv.type(this)), "ifval");

	for (Branches::iterator it = incoming.begin(); it != incoming.end(); ++it)
		phi->addIncoming(it->first, it->second);

	return phi;
}

/** Compile a primitive operator application.
 *
 * Arithmetic/logical operators fold left over all operands; comparison
 * operators use the first two operands.  Float vs. integer instructions are
 * selected from the type of the first operand.  Unknown operators throw.
 */
CValue
APrimitive::compile(CEnv& cenv)
{
	IRBuilder<>& b   = llengine(cenv)->builder;
	Value*       lhs = LLVal(cenv.compile(at(1)));
	Value*       rhs = LLVal(cenv.compile(at(2)));
	const bool   fp  = (cenv.type(at(1))->str() == "Float");
	const string sym = at(0)->to<ASymbol*>()->str();

	// Binary arithmetic operations (0 means "not one of these")
	Instruction::BinaryOps binOp = (Instruction::BinaryOps)0;
	if      (sym == "+")   binOp = Instruction::Add;
	else if (sym == "-")   binOp = Instruction::Sub;
	else if (sym == "*")   binOp = Instruction::Mul;
	else if (sym == "and") binOp = Instruction::And;
	else if (sym == "or")  binOp = Instruction::Or;
	else if (sym == "xor") binOp = Instruction::Xor;
	else if (sym == "/")   binOp = fp ? Instruction::FDiv : Instruction::SDiv;
	else if (sym == "%")   binOp = fp ? Instruction::FRem : Instruction::SRem;

	if (binOp != 0) {
		Value* acc = b.CreateBinOp(binOp, lhs, rhs);
		for (size_t k = 3; k < size(); ++k)
			acc = b.CreateBinOp(binOp, acc, LLVal(cenv.compile(at(k))));
		return acc;
	}

	// Comparison operations (0 means "not one of these")
	CmpInst::Predicate cmp = (CmpInst::Predicate)0;
	if      (sym == "=")  cmp = fp ? CmpInst::FCMP_OEQ : CmpInst::ICMP_EQ ;
	else if (sym == "!=") cmp = fp ? CmpInst::FCMP_ONE : CmpInst::ICMP_NE ;
	else if (sym == ">")  cmp = fp ? CmpInst::FCMP_OGT : CmpInst::ICMP_SGT;
	else if (sym == ">=") cmp = fp ? CmpInst::FCMP_OGE : CmpInst::ICMP_SGE;
	else if (sym == "<")  cmp = fp ? CmpInst::FCMP_OLT : CmpInst::ICMP_SLT;
	else if (sym == "<=") cmp = fp ? CmpInst::FCMP_OLE : CmpInst::ICMP_SLE;

	if (cmp != 0)
		return fp ? b.CreateFCmp(cmp, lhs, rhs) : b.CreateICmp(cmp, lhs, rhs);

	throw Error("unknown primitive", loc);
}

/// Type of the constructor for this cons call: Fn (carT cdrT) (Pair carT cdrT),
/// where carT/cdrT are the types of the two operands.
AType*
AConsCall::functionType(CEnv& cenv)
{
	ATuple* params = new ATuple(loc, cenv.type(at(1)), cenv.type(at(2)), 0);
	AType*  pairT  = new AType(loc, cenv.penv.sym("Pair"),
	                           cenv.type(at(1)), cenv.type(at(2)), 0);
	AType*  fnT    = new AType(at(0)->loc, cenv.penv.sym("Fn"), params, pairT, 0);
	return fnT;
}

/** Emit (once per operand-type combination) a constructor function that
 * mallocs a pair cell, stores its two arguments into it and returns it.
 */
void
AConsCall::lift(CEnv& cenv)
{
	AType* funcType = functionType(cenv);
	if (funcs.find(funcType))
		return;

	ACall::lift(cenv);

	ATuple* protT = new ATuple(loc, cenv.type(at(1)), cenv.type(at(2)), 0);

	vector<const Type*> types;
	for (size_t i = 1; i < size(); ++i)
		types.push_back(lltype(cenv.type(at(i))));

	llvm::IRBuilder<>& builder = llengine(cenv)->builder;

	StructType* sT = StructType::get(types, false);
	Type*       pT = PointerType::get(sT, 0);

	// Take the cell size from the target's struct layout.  Summing primitive
	// sizes under-allocates: pointer fields (nested pairs) report 0 bits and
	// alignment padding is ignored.
	const TargetData* target = llengine(cenv)->engine->getTargetData();
	const uint64_t    sz     = (target->getTypeSizeInBits(sT) + 7) / 8;

	// Write function declaration
	vector<string> argNames;
	argNames.push_back("car");
	argNames.push_back("cdr");
	Function* func = compileFunction(cenv, cenv.gensym("cons"), pT, *protT, argNames);

	Value* mem  = builder.CreateCall(LLVal(cenv.alloc), ConstantInt::get(Type::Int32Ty, sz), "mem");
	Value* cell = builder.CreateBitCast(mem, pT, "cell");
	Value* s    = builder.CreateGEP(cell, ConstantInt::get(Type::Int32Ty, 0), "pair");
	Value* carP = builder.CreateStructGEP(s, 0, "car");
	Value* cdrP = builder.CreateStructGEP(s, 1, "cdr");

	Function::arg_iterator ai = func->arg_begin();
	Value& carArg = *ai++;
	Value& cdrArg = *ai++;
	builder.CreateStore(&carArg, carP);
	builder.CreateStore(&cdrArg, cdrP);
	builder.CreateRet(cell);

	cenv.optimise(func);
	funcs.push_back(make_pair(funcType, func));
}

/// Emit a call to the lifted constructor for this cons call's operand types.
CValue
AConsCall::compile(CEnv& cenv)
{
	vector<Value*> operands;
	for (size_t k = 1; k < size(); ++k)
		operands.push_back(LLVal(cenv.compile(at(k))));

	Function* ctor = LLFunc(funcs.find(functionType(cenv)));
	return llengine(cenv)->builder.CreateCall(ctor, operands.begin(), operands.end());
}

CValue
ACarCall::compile(CEnv& cenv)
{
	AST*   arg  = cenv.tenv.resolve(at(1));
	Value* sP   = LLVal(cenv.compile(arg));
	Value* s    = llengine(cenv)->builder.CreateGEP(sP, ConstantInt::get(Type::Int32Ty, 0), "pair");
	Value* carP = llengine(cenv)->builder.CreateStructGEP(s, 0, "car");
	return llengine(cenv)->builder.CreateLoad(carP);
}

CValue
ACdrCall::compile(CEnv& cenv)
{
	AST*   arg  = cenv.tenv.resolve(at(1));
	Value* sP   = LLVal(cenv.compile(arg));
	Value* s    = llengine(cenv)->builder.CreateGEP(sP, ConstantInt::get(Type::Int32Ty, 0), "pair");
	Value* cdrP = llengine(cenv)->builder.CreateStructGEP(s, 1, "cdr");
	return llengine(cenv)->builder.CreateLoad(cdrP);
}


/***************************************************************************
 * EVAL/REPL                                                               *
 ***************************************************************************/

/** Call the compiled nullary function at @a fp and format its result.
 *
 * Int results print as integers, Float with showpoint, Bool as "#t"/"#f".
 * Anything else is treated as a returned pointer.
 */
const string
call(AType* retT, void* fp)
{
	const Type* const llT = lltype(retT);
	std::ostringstream out;
	if (llT == Type::Int32Ty) {
		out << ((int32_t (*)())fp)();
	} else if (llT == Type::FloatTy) {
		out << showpoint << ((float (*)())fp)();
	} else if (llT == Type::Int1Ty) {
		const bool b = ((bool (*)())fp)();
		out << (b ? "#t" : "#f");
	} else {
		out << ((void* (*)())fp)();
	}
	return out.str();
}

/** Read, type-check and lift every expression in @a is, then compile them
 * all into a function named "main", run it and print "result : type".
 *
 * @param name  Source name used for locations (the Cursor).
 * @return 0 on success (including empty input, which prints nothing),
 *         1 if an Error was reported to cenv.err.
 */
int
eval(CEnv& cenv, const string& name, istream& is)
{
	AType* resultType = NULL;
	list< pair<SExp, AST*> > exprs;
	Cursor cursor(name);
	try {
		while (true) {
			SExp exp = readExpression(cursor, is);
			if (exp.type == SExp::LIST && exp.empty())
				break;

			AST* result = cenv.penv.parse(exp); // Parse input
			Constraints c;
			result->constrain(cenv.tenv, c); // Constrain types
			cenv.tsubst = Subst::compose(cenv.tsubst, TEnv::unify(c)); // Solve type constraints
			resultType = cenv.type(result);
			result->lift(cenv); // Lift functions
			exprs.push_back(make_pair(exp, result));
		}

		// Empty input: nothing to run (and lltype(NULL) would crash)
		if (exprs.empty())
			return 0;

		const Type* ctype = resultType ? lltype(resultType) : NULL;
		if (!ctype) throw Error("body has non-compilable type", cursor);

		// Create function for top-level of program
		Function* f = compileFunction(cenv, "main", ctype, ATuple(cursor));
		try {
			// Compile all expressions into it
			Value* val = NULL;
			for (list< pair<SExp, AST*> >::const_iterator i = exprs.begin(); i != exprs.end(); ++i)
				val = LLVal(cenv.compile(i->second));

			// Finish function
			llengine(cenv)->builder.CreateRet(val);
			cenv.optimise(f);
		} catch (Error&) {
			f->eraseFromParent(); // Don't leave a half-built main in the module
			throw;
		}

		cenv.out << call(resultType, llengine(cenv)->engine->getPointerToFunction(f))
		    << " : " << resultType << endl;
	} catch (Error& e) {
		cenv.err << e.what() << endl;
		return 1;
	}
	return 0;
}

/** Interactive read-eval-print loop on std::cin.
 *
 * Each input is parsed, type-checked, lifted, and (if it has a compilable
 * type) compiled into a fresh "_repl" function that is run.  The result and
 * its type are then printed.  The type substitution is restored after every
 * input, whether it succeeded or failed.  Errors are printed to cenv.err and
 * the loop continues.  Returns 0 when readExpression yields an empty list.
 */
int
repl(CEnv& cenv)
{
	while (1) {
		cenv.out << "() ";
		cenv.out.flush();
		Cursor cursor("(stdin)");

		// Snapshot taken outside the try so the catch can restore it too
		Subst oldSubst = cenv.tsubst;
		try {
			SExp exp = readExpression(cursor, std::cin);
			if (exp.type == SExp::LIST && exp.empty())
				break;

			AST* body = cenv.penv.parse(exp); // Parse input
			Constraints c;
			body->constrain(cenv.tenv, c); // Constrain types

			cenv.tsubst = Subst::compose(cenv.tsubst, TEnv::unify(c)); // Solve type constraints

			AType* bodyT = cenv.type(body);
			if (!bodyT) throw Error("call to untyped body", cursor);

			body->lift(cenv);

			if (lltype(bodyT)) {
				// Create anonymous function to insert code into
				Function* f = compileFunction(cenv, cenv.gensym("_repl"), lltype(bodyT), ATuple(cursor));
				try {
					Value* retVal = LLVal(cenv.compile(body));
					llengine(cenv)->builder.CreateRet(retVal); // Finish function
					cenv.optimise(f);
				} catch (Error&) {
					f->eraseFromParent(); // Error reading body, remove function
					throw; // Rethrow the original object (no copy/slicing)
				}
				cenv.out << call(bodyT, llengine(cenv)->engine->getPointerToFunction(f));
			} else {
				cenv.out << ";   " << cenv.compile(body);
			}
			cenv.out << " : " << cenv.type(body) << endl;
			cenv.tsubst = oldSubst;
		} catch (Error& e) {
			cenv.tsubst = oldSubst; // Discard substitutions from the failed input
			cenv.err << e.what() << endl;
		}
	}
	return 0;
}

/** Create a compilation environment backed by a new LLVMEngine.
 *
 * Also declares the host allocator `malloc' (i8* (i32)) in the module and
 * stores it as cenv->alloc.
 */
CEnv*
newCenv(PEnv& penv, TEnv& tenv)
{
	LLVMEngine* const eng = new LLVMEngine();
	CEnv* const       env = new CEnv(penv, tenv, eng);

	// Host provided allocation primitive prototype
	std::vector<const Type*> mallocArgs;
	mallocArgs.push_back(Type::Int32Ty);
	const Type*   bytePtrT = PointerType::get(Type::Int8Ty, 0);
	FunctionType* mallocT  = FunctionType::get(bytePtrT, mallocArgs, false);
	env->alloc = Function::Create(mallocT, Function::ExternalLinkage, "malloc", eng->module);

	return env;
}