/* Tuplr: A programming language
 * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net>
 *
 * Tuplr is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Tuplr is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Tuplr.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <map>
#include <sstream>
#include <boost/format.hpp>
#include "llvm/Analysis/Verifier.h"
#include "llvm/Assembly/AsmAnnotationWriter.h"
#include "llvm/DerivedTypes.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/Instructions.h"
#include "llvm/Module.h"
#include "llvm/ModuleProvider.h"
#include "llvm/PassManager.h"
#include "llvm/Support/IRBuilder.h"
#include "llvm/Target/TargetData.h"
#include "llvm/Transforms/Scalar.h"
#include "tuplr.hpp"

using namespace llvm;
using namespace std;
using boost::format;

/// Unwrap an opaque compiler value handle into the LLVM Value it wraps.
inline Value*
llVal(CValue v)
{
	Value* const value = static_cast<Value*>(v);
	return value;
}

/// Unwrap an opaque compiled-function handle into the LLVM Function it wraps.
inline Function*
llFunc(CFunction f)
{
	Function* const fn = static_cast<Function*>(f);
	return fn;
}

/// Map a Tuplr type to its LLVM representation.
/// Returns NULL for any non-primitive type; throws Error for an unknown primitive.
static const Type*
llType(const AType* t)
{
	if (t->kind != AType::PRIM)
		return NULL; // non-primitive type

	const string prim = t->at(0)->str();
	if (prim == "Bool")
		return Type::Int1Ty;
	if (prim == "Int")
		return Type::Int32Ty;
	if (prim == "Float")
		return Type::FloatTy;

	throw Error(t->loc, string("Unknown primitive type `") + t->str() + "'");
}

/// JIT back-end state: the module being built, the execution engine created
/// from it, and the IR builder shared by all code generation in this file.
struct LLVMEngine {
	LLVMEngine()
		: module(new Module("tuplr"))
		, engine(ExecutionEngine::create(module))
	{
	}

	// `module` must stay declared before `engine`: the initialiser list builds
	// the engine from the module. Neither pointer is freed by this struct.
	// NOTE(review): ownership appears to be handed to LLVM (see CEnv::PImpl's
	// ExistingModuleProvider) — confirm before adding a destructor.
	Module*          module;
	ExecutionEngine* engine;
	IRBuilder<>      builder;
};

/// Recover the concrete LLVM back end from a compile environment's opaque engine handle.
static LLVMEngine*
llengine(CEnv& cenv)
{
	CEngine handle = cenv.engine();
	return reinterpret_cast<LLVMEngine*>(handle);
}

/// Back-end-private part of CEnv: the engine plus a per-function optimiser.
struct CEnv::PImpl {
	PImpl(LLVMEngine* e) : engine(e), module(e->module), emp(module), opt(&emp)
	{
		// Optimiser pipeline applied to each function by CEnv::optimise().
		// The first pass registers the JIT's target data layout.
		const TargetData& layout = *engine->engine->getTargetData();
		opt.add(new TargetData(layout));
		opt.add(createInstructionCombiningPass()); // peephole simplifications
		opt.add(createReassociatePass());          // reassociate expressions
		opt.add(createGVNPass());                  // common subexpression elimination
		opt.add(createCFGSimplificationPass());    // simplify control flow
	}

	// Declaration order matters: members are initialised in this order, and
	// `emp` wraps `module` while `opt` wraps `emp`.
	LLVMEngine*            engine;
	Module*                module;
	ExistingModuleProvider emp;
	FunctionPassManager    opt;
};

/// Construct a compile environment bound to the given parse/type environments
/// and LLVM back end; output and error text go to `os` and `es`.
CEnv::CEnv(PEnv& p, TEnv& t, CEngine e, ostream& os, ostream& es)
	: out(os)
	, err(es)
	, penv(p)
	, tenv(t)
	, symID(0)
	, alloc(0)
	, _pimpl(new PImpl(reinterpret_cast<LLVMEngine*>(e)))
{
}

/// Release the back-end-private state. The LLVMEngine itself is not freed here.
CEnv::~CEnv()
{
	PImpl* const impl = _pimpl;
	delete impl;
}

/// Opaque handle to the LLVM back end this environment compiles into.
CEngine
CEnv::engine()
{
	LLVMEngine* const backend = _pimpl->engine;
	return backend;
}

/// Compile `obj`, memoising the result: a non-NULL value already bound to
/// `obj` is returned as-is, otherwise the node is compiled and the result bound.
CValue
CEnv::compile(AST* obj)
{
	CValue* cached = vals.ref(obj);
	if (cached && *cached)
		return *cached;

	return vals.def(obj, obj->compile(*this));
}

/// Verify and optimise a freshly generated function, unless "-g" was given,
/// in which case the function is left exactly as generated.
void
CEnv::optimise(CFunction f)
{
	const bool keepAsIs = (args.find("-g") != args.end());
	if (keepAsIs)
		return;

	Function& fn = *static_cast<Function*>(f);
	verifyFunction(fn);
	_pimpl->opt.run(fn);
}

void
CEnv::write(std::ostream& os)
{
	AssemblyAnnotationWriter writer;
	_pimpl->engine->module->print(os, &writer);
}

/* LITERAL(CT, NAME, COMPILED) specialises ALiteral<CT> for one C++ literal type:
 *   - compile()   returns COMPILED, an LLVM constant built from the literal's `val`;
 *   - constrain() constrains the literal's type to the type named NAME.
 * (No comments inside the macro body: a `//` comment would swallow the line
 * continuation backslash.)
 */
#define LITERAL(CT, NAME, COMPILED) \
template<> CValue \
ALiteral<CT>::compile(CEnv& cenv) { return (COMPILED); } \
template<> void \
ALiteral<CT>::constrain(TEnv& tenv, Constraints& c) const { c.constrain(tenv, this, tenv.named(NAME)); }

/// Literal template instantiations
// Int is a signed i32, Float a 32-bit float, Bool an unsigned i1.
LITERAL(int32_t, "Int",   ConstantInt::get(Type::Int32Ty, val, true))
LITERAL(float,   "Float", ConstantFP::get(Type::FloatTy, val))
LITERAL(bool,    "Bool",  ConstantInt::get(Type::Int1Ty, val, false))

/// Declare a new function in the module and position the builder at its entry block.
///
/// @param name     Requested symbol name (LLVM may rename it on a clash).
/// @param retT     Return type; must map to an LLVM primitive.
/// @param argsT    Parameter types; each must map to an LLVM primitive.
/// @param argNames Optional parameter names for the generated IR. When
///                 non-empty it must cover every parameter.
/// @return The new function. Every input is validated before the function is
///         created, so a thrown Error never leaves an orphan declaration in the module.
CFunction
startFunction(CEnv& cenv, const std::string& name, const AType* retT, const ATuple& argsT,
		const vector<string> argNames)
{
	const Function::LinkageTypes linkage = Function::ExternalLinkage;

	vector<const Type*> cprot;
	cprot.reserve(argsT.size());
	for (size_t i = 0; i < argsT.size(); ++i) {
		AType*      at = argsT.at(i)->as<AType*>();
		const Type* lt = llType(at);
		THROW_IF(!lt, Cursor(), "function parameter is untyped");
		cprot.push_back(lt);
	}

	const Type* retLT = llType(retT);
	THROW_IF(!retLT, Cursor(), "function return is untyped");

	// Checked up front: previously a short name list threw std::out_of_range
	// after the function had already been inserted into the module.
	THROW_IF(!argNames.empty() && argNames.size() < argsT.size(), Cursor(),
			"function has fewer argument names than parameters");

	FunctionType* fT = FunctionType::get(retLT, cprot, false);
	Function*     f  = Function::Create(fT, linkage, name, llengine(cenv)->module);

	// LLVM renames a function whose name is already taken.
	// NOTE(review): debugging output; redefinition is currently not an error.
	if (f->getName() != name) {
		cenv.out << "DIFFERENT NAME: " << f->getName() << endl;
		/*f->eraseFromParent();
		throw Error(Cursor(), (format("function `%1%' redefined") % name).str());*/
	}

	// Set argument names in generated code
	if (!argNames.empty()) {
		Function::arg_iterator a = f->arg_begin();
		for (size_t i = 0; i < argsT.size(); ++a, ++i)
			a->setName(argNames.at(i));
	}

	BasicBlock* bb = BasicBlock::Create("entry", f);
	llengine(cenv)->builder.SetInsertPoint(bb);

	return f;
}

/// Terminate the function being built with `ret ret`, then verify/optimise it.
void
finishFunction(CEnv& cenv, CFunction f, CValue ret)
{
	IRBuilder<>& builder = llengine(cenv)->builder;
	builder.CreateRet(llVal(ret));
	cenv.optimise(llFunc(f));
}

/// Remove a function from its module; a NULL handle is ignored.
void
eraseFunction(CEnv& cenv, CFunction f)
{
	if (!f)
		return;

	llFunc(f)->eraseFromParent();
}


/***************************************************************************
 * Code Generation                                                         *
 ***************************************************************************/

/// A symbol compiles to the value currently bound to it, or NULL if unbound.
CValue
ASymbol::compile(CEnv& cenv)
{
	// vals.ref() yields a pointer to the binding slot (CValue*). Returning it
	// directly silently converted the slot *address* into a CValue (void*),
	// producing a bogus Value* for stub (NULL-valued) bindings.
	CValue* binding = cenv.vals.ref(this);
	return binding ? *binding : NULL;
}

/// Lift the closure body. If the closure's own type is already concrete and
/// not yet compiled, also emit that specialisation via liftCall().
void
AClosure::lift(CEnv& cenv)
{
	// Lift the body in a scope where each parameter is bound with no type/value
	cenv.push();
	try {
		for (const_iterator p = prot()->begin(); p != prot()->end(); ++p)
			cenv.def((*p)->as<ASymbol*>(), *p, NULL, NULL);

		// Lift body
		for (size_t i = 2; i < size(); ++i)
			at(i)->lift(cenv);
	} catch (...) {
		cenv.pop(); // keep push()/pop() balanced when lifting fails
		throw;
	}
	cenv.pop();

	AType* type = cenv.type(this);
	if (funcs.find(type) || !type->concrete())
		return;

	AType* protT = type->at(1)->as<AType*>();
	liftCall(cenv, *protT);
}

/// Emit the specialisation of this closure for argument types `argsT`.
///
/// A generic closure type is made concrete by substituting `argsT` for its
/// generic parameter types. The resulting function is recorded in `funcs`,
/// keyed by the concrete type; an existing specialisation is reused.
///
/// On any exception after the function is declared, the function is erased
/// from the module and `cenv` (scope and tsubst) is restored before the
/// exception is rethrown unchanged.
void
AClosure::liftCall(CEnv& cenv, const AType& argsT)
{
	TEnv::GenericTypes::const_iterator gt = cenv.tenv.genericTypes.find(this);
	assert(gt != cenv.tenv.genericTypes.end());
	AType* thisType = new AType(*gt->second);
	Subst  argsSubst;
	if (!thisType->concrete()) {
		// Build substitution to apply to generic type
		assert(argsT.size() == prot()->size());
		ATuple* genericProtT = gt->second->at(1)->as<ATuple*>();
		for (size_t i = 0; i < argsT.size(); ++i)
			argsSubst[genericProtT->at(i)->to<AType*>()] = argsT.at(i)->to<AType*>();

		// Apply substitution to get concrete type for this call
		thisType = argsSubst.apply(thisType)->as<AType*>();
		if (!thisType->concrete())
			throw Error(loc, "unable to resolve concrete type for function");
	}

	Object::pool.addRoot(thisType);
	if (funcs.find(thisType))
		return;

	ATuple* protT = thisType->at(1)->as<ATuple*>();

	vector<string> argNames;
	argNames.reserve(prot()->size());
	for (size_t i = 0; i < prot()->size(); ++i)
		argNames.push_back(prot()->at(i)->str());

	// Write function declaration
	const string name = (this->name == "") ? cenv.gensym("_fn") : this->name;
	Function* f = llFunc(startFunction(cenv, name,
			thisType->at(thisType->size()-1)->to<AType*>(),
			*protT, argNames));

	cenv.push();
	Subst oldSubst = cenv.tsubst;
	cenv.tsubst = Subst::compose(cenv.tsubst, Subst::compose(argsSubst, subst));

	// Everything from here to finishFunction() may throw. Previously only Error
	// from the body was handled, tsubst was never restored on failure, and
	// errors while binding parameters leaked both the scope and the function.
	try {
//#define EXPLICIT_STACK_FRAMES 1

#ifdef EXPLICIT_STACK_FRAMES
		vector<const Type*> types;
		types.push_back(Type::Int8Ty);
		types.push_back(Type::Int8Ty);
		size_t s = 16; // stack frame size in bits
#endif

		// Bind argument values in CEnv
		const_iterator p = prot()->begin();
		size_t i = 0;
		for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p, ++i) {
			AType* t = protT->at(i)->as<AType*>();
			const Type* lt = llType(t);
			THROW_IF(!lt, loc, "untyped parameter\n");
			cenv.def((*p)->as<ASymbol*>(), *p, t, &*a);
#ifdef EXPLICIT_STACK_FRAMES
			types.push_back(lt);
			s += std::max(lt->getPrimitiveSizeInBits(), unsigned(8));
#endif
		}

#ifdef EXPLICIT_STACK_FRAMES
		IRBuilder<> builder = llengine(cenv)->builder;

		// Scan out definitions
		for (size_t i = 0; i < size(); ++i) {
			ADefinition* def = at(i)->to<ADefinition*>();
			if (def) {
				const Type* lt = llType(cenv.type(def->at(2)));
				THROW_IF(!lt, loc, "untyped definition\n");
				types.push_back(lt);
				s += std::max(lt->getPrimitiveSizeInBits(), unsigned(8));
			}
		}

		// Create stack frame
		StructType* frameT    = StructType::get(types, false);
		Value*      tag       = ConstantInt::get(Type::Int8Ty, GC::TAG_FRAME);
		Value*      frameSize = ConstantInt::get(Type::Int32Ty, s / 8);
		Value*      frame     = builder.CreateCall2(llVal(cenv.alloc), frameSize, tag, "frame");
		Value*      framePtr  = builder.CreateBitCast(frame, PointerType::get(frameT, 0));

		// Bind parameter values in stack frame
		i = 2;
		for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++i) {
			Value* v = builder.CreateStructGEP(framePtr, i, "arg");
			builder.CreateStore(&*a, v);
		}
#endif

		// Write function body.
		// Define value first and register the specialisation for recursion.
		cenv.precompile(this, f);
		funcs.push_back(make_pair(thisType, f));
		CValue retVal = NULL;
		for (size_t i = 2; i < size(); ++i)
			retVal = cenv.compile(at(i));
		finishFunction(cenv, f, retVal);
	} catch (...) {
		// NOTE(review): the (thisType, f) entry pushed onto `funcs` above (if
		// reached) still refers to the erased function — confirm whether the
		// Funcs container offers a way to drop it here.
		f->eraseFromParent(); // Error reading body, remove function
		cenv.tsubst = oldSubst;
		cenv.pop();
		throw; // rethrow the original exception object (no copy/slicing)
	}
	cenv.tsubst = oldSubst;
	cenv.pop();
}

/// Closures produce no value here; their code is emitted per concrete
/// signature by liftCall(), so this always yields NULL.
CValue
AClosure::compile(CEnv& cenv)
{
	CValue none = NULL;
	return none;
}

/// Lift every argument. If the callee resolves to a closure, check the arity
/// and lift the closure's specialisation for these argument types.
void
ACall::lift(CEnv& cenv)
{
	AClosure* const callee = cenv.tenv.resolve(at(0))->to<AClosure*>();
	AType           argsT(loc, NULL);

	// Lift arguments, collecting their types for specialisation
	for (size_t i = 1; i < size(); ++i) {
		at(i)->lift(cenv);
		argsT.push_back(cenv.type(at(i)));
	}

	if (!callee)
		return; // Primitive

	const size_t nParams = callee->prot()->size();
	const size_t nArgs   = size() - 1;
	if (nArgs > nParams)
		throw Error(loc, (format("too many arguments to function `%1%'") % at(0)->str()).str());
	if (nArgs < nParams)
		throw Error(loc, (format("too few arguments to function `%1%'") % at(0)->str()).str());

	callee->liftCall(cenv, argsT); // Lift called closure
}

/// Emit a call to the callee's specialisation for this call site's argument
/// types. Returns NULL when the callee is not a closure (primitive).
CValue
ACall::compile(CEnv& cenv)
{
	AClosure* const callee = cenv.tenv.resolve(at(0))->to<AClosure*>();
	if (!callee)
		return NULL; // Primitive

	// Concrete argument types of this call site
	AType               argsT(loc, NULL);
	vector<const Type*> argLTypes;
	for (size_t i = 1; i < size(); ++i) {
		argsT.push_back(cenv.type(at(i)));
		argLTypes.push_back(llType(cenv.type(at(i))));
	}

	TEnv::GenericTypes::const_iterator gt = cenv.tenv.genericTypes.find(callee);
	assert(gt != cenv.tenv.genericTypes.end());

	// Find the specialisation registered in `funcs` for exactly this signature
	AType     fnT(loc, cenv.penv.sym("Fn"), &argsT, cenv.type(this), 0);
	Function* fn = static_cast<Function*>(callee->funcs.find(&fnT));
	THROW_IF(!fn, loc, (format("callee failed to compile for type %1%") % fnT.str()).str());

	vector<Value*> params(size() - 1);
	for (size_t i = 0; i < argLTypes.size(); ++i)
		params[i] = llVal(cenv.compile(at(i + 1)));

	return llengine(cenv)->builder.CreateCall(fn, params.begin(), params.end());
}

/// Bind the defined symbol and lift its value. A closure value takes the
/// symbol's name.
void
ADefinition::lift(CEnv& cenv)
{
	// Bind a valueless stub first so the definition can refer to itself (recursion)
	cenv.def(sym(), at(2), cenv.type(at(2)), NULL);

	if (AClosure* closure = at(2)->to<AClosure*>())
		closure->name = sym()->str();

	at(2)->lift(cenv);
}

/// Compile the definition's value, bind it to the symbol, and return it.
CValue
ADefinition::compile(CEnv& cenv)
{
	// Bind a valueless stub first so the definition can refer to itself (recursion)
	cenv.def(sym(), at(2), cenv.type(at(2)), NULL);

	const CValue compiled = cenv.compile(at(size() - 1));
	cenv.vals.def(sym(), compiled);
	return compiled;
}

/// Compile (if c1 e1 c2 e2 ... else) as a chain of conditional branches.
/// Each clause i gets blocks "then<i>" and "else<i>"; the next condition (or
/// the final else) is emitted in "else<i>". All arms branch to "endif", where
/// a phi node "ifval" selects the result.
CValue
AIf::compile(CEnv& cenv)
{
	typedef vector< pair<Value*, BasicBlock*> > Branches;

	IRBuilder<>& builder = llengine(cenv)->builder;
	Function*    fn      = builder.GetInsertBlock()->getParent();
	BasicBlock*  endBB   = BasicBlock::Create("endif");
	BasicBlock*  elseBB  = NULL;
	Branches     incoming;

	for (size_t idx = 1; idx < size() - 1; idx += 2) {
		const size_t clause = (idx + 1) / 2;
		Value*       condV  = llVal(cenv.compile(at(idx)));
		BasicBlock*  thenBB = BasicBlock::Create((format("then%1%") % clause).str());

		elseBB = BasicBlock::Create((format("else%1%") % clause).str());
		builder.CreateCondBr(condV, thenBB, elseBB);

		// Consequent of this clause. Record the block the builder ends in, which
		// may differ from thenBB if the consequent emitted its own blocks.
		fn->getBasicBlockList().push_back(thenBB);
		builder.SetInsertPoint(thenBB);
		Value* thenV = llVal(cenv.compile(at(idx + 1)));
		builder.CreateBr(endBB);
		incoming.push_back(make_pair(thenV, builder.GetInsertBlock()));

		// Continue emitting in the fall-through block
		fn->getBasicBlockList().push_back(elseBB);
		builder.SetInsertPoint(elseBB);
	}

	// Final else arm
	builder.SetInsertPoint(elseBB);
	Value* elseV = llVal(cenv.compile(at(size() - 1)));
	builder.CreateBr(endBB);
	incoming.push_back(make_pair(elseV, builder.GetInsertBlock()));

	// Merge point
	fn->getBasicBlockList().push_back(endBB);
	builder.SetInsertPoint(endBB);
	PHINode* phi = builder.CreatePHI(llType(cenv.type(this)), "ifval");

	for (Branches::iterator br = incoming.begin(); br != incoming.end(); ++br)
		phi->addIncoming(br->first, br->second);

	return phi;
}

/// Compile a primitive operation.
///
/// Arithmetic/bitwise operators (+ - * / % and or xor) fold left over all
/// operands, e.g. (+ a b c) => (a + b) + c. Comparisons (= != > >= < <=) take
/// exactly two operands. Division, remainder and comparisons use float or
/// signed-integer forms according to the type of the first operand.
///
/// @throws Error for an unknown operator or a comparison with more than two operands.
CValue
APrimitive::compile(CEnv& cenv)
{
	IRBuilder<>& builder = llengine(cenv)->builder;
	Value*       a       = llVal(cenv.compile(at(1)));
	Value*       b       = llVal(cenv.compile(at(2)));
	bool         isFloat = cenv.type(at(1))->str() == "Float";
	const string n       = at(0)->to<ASymbol*>()->str();

	// Binary arithmetic operations. An explicit flag replaces the old
	// "(BinaryOps)0 means not found" sentinel.
	bool                   isBinOp = true;
	Instruction::BinaryOps op      = Instruction::Add;
	if      (n == "+")   op = Instruction::Add;
	else if (n == "-")   op = Instruction::Sub;
	else if (n == "*")   op = Instruction::Mul;
	else if (n == "and") op = Instruction::And;
	else if (n == "or")  op = Instruction::Or;
	else if (n == "xor") op = Instruction::Xor;
	else if (n == "/")   op = isFloat ? Instruction::FDiv : Instruction::SDiv;
	else if (n == "%")   op = isFloat ? Instruction::FRem : Instruction::SRem;
	else                 isBinOp = false;

	if (isBinOp) {
		Value* val = builder.CreateBinOp(op, a, b);
		for (size_t i = 3; i < size(); ++i)
			val = builder.CreateBinOp(op, val, llVal(cenv.compile(at(i))));
		return val;
	}

	// Comparison operations. Predicate value 0 is FCMP_FALSE, a valid
	// predicate, so an explicit flag is used instead of a 0 sentinel.
	bool               isCmp = true;
	CmpInst::Predicate pred  = CmpInst::ICMP_EQ;
	if      (n == "=")  pred = isFloat ? CmpInst::FCMP_OEQ : CmpInst::ICMP_EQ ;
	else if (n == "!=") pred = isFloat ? CmpInst::FCMP_ONE : CmpInst::ICMP_NE ;
	else if (n == ">")  pred = isFloat ? CmpInst::FCMP_OGT : CmpInst::ICMP_SGT;
	else if (n == ">=") pred = isFloat ? CmpInst::FCMP_OGE : CmpInst::ICMP_SGE;
	else if (n == "<")  pred = isFloat ? CmpInst::FCMP_OLT : CmpInst::ICMP_SLT;
	else if (n == "<=") pred = isFloat ? CmpInst::FCMP_OLE : CmpInst::ICMP_SLE;
	else                isCmp = false;

	if (isCmp) {
		// Extra operands used to be silently ignored, e.g. (< 1 2 0) => #t
		THROW_IF(size() > 3, loc,
				(format("comparison `%1%' takes exactly two operands") % n).str());
		if (isFloat)
			return builder.CreateFCmp(pred, a, b);
		return builder.CreateICmp(pred, a, b);
	}

	throw Error(loc, "unknown primitive");
}


/***************************************************************************
 * EVAL/REPL                                                               *
 ***************************************************************************/

/// JIT-compile and invoke a nullary function, returning its result as text.
/// Int results are printed as integers, Float with `showpoint`, Bool as "#t"/"#f".
/// @throws Error if no function pointer is available or the return type is not primitive.
const string
call(CEnv& cenv, CFunction f, AType* retT)
{
	typedef int32_t (*IntFn)();
	typedef float   (*FloatFn)();
	typedef bool    (*BoolFn)();
	typedef void*   (*PtrFn)();

	void* const       fp    = llengine(cenv)->engine->getPointerToFunction(llFunc(f));
	const Type* const retLT = llType(retT);
	THROW_IF(!fp,    Cursor(), "unable to get function pointer");
	THROW_IF(!retLT, Cursor(), "function with non-primitive return type called");

	std::stringstream result;
	if (retLT == Type::Int32Ty) {
		result << reinterpret_cast<IntFn>(fp)();
	} else if (retLT == Type::FloatTy) {
		result << showpoint << reinterpret_cast<FloatFn>(fp)();
	} else if (retLT == Type::Int1Ty) {
		const bool truth = reinterpret_cast<BoolFn>(fp)();
		result << (truth ? "#t" : "#f");
	} else {
		result << reinterpret_cast<PtrFn>(fp)();
	}
	return result.str();
}

/// Create a fresh LLVM back end and a compile environment bound to it. The
/// host allocator `i8* tuplr_gc_allocate(i32 size, i8 tag)` is declared in the module.
CEnv*
newCenv(PEnv& penv, TEnv& tenv)
{
	LLVMEngine* const engine = new LLVMEngine();
	CEnv* const       cenv   = new CEnv(penv, tenv, engine);

	// Host provided allocation primitive prototype
	std::vector<const Type*> allocParams;
	allocParams.push_back(Type::Int32Ty); // size
	allocParams.push_back(Type::Int8Ty);  // tag

	const Type*   bytePtrT = PointerType::get(Type::Int8Ty, 0);
	FunctionType* allocT   = FunctionType::get(bytePtrT, allocParams, false);
	cenv->alloc = Function::Create(allocT, Function::ExternalLinkage,
			"tuplr_gc_allocate", engine->module);

	return cenv;
}

/// Tear down a compile environment created by newCenv().
void
freeCenv(CEnv* cenv)
{
	// Run a collection with a default-constructed (presumably empty) root set
	Object::pool.collect(GC::Roots());

	// The engine is not owned by CEnv (~CEnv deletes only its PImpl), so free
	// it explicitly, then the environment itself.
	LLVMEngine* const engine = reinterpret_cast<LLVMEngine*>(cenv->engine());
	delete engine;
	delete cenv;
}