/* Tuplr: A programming language
 * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net>
 *
 * Tuplr is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Tuplr is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Tuplr.  If not, see <http://www.gnu.org/licenses/>.
 */

/** @file
 * @brief Compile AST to LLVM IR
 *
 * Compilation pass functions (lift/compile) that require direct use of LLVM
 * specific things are implemented here.  Generic compilation pass functions
 * are implemented in compile.cpp.
 */

#include <map>
#include <sstream>
#include <boost/format.hpp>
#include "llvm/Analysis/Verifier.h"
#include "llvm/Assembly/AsmAnnotationWriter.h"
#include "llvm/DerivedTypes.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/Instructions.h"
#include "llvm/Module.h"
#include "llvm/ModuleProvider.h"
#include "llvm/PassManager.h"
#include "llvm/Support/IRBuilder.h"
#include "llvm/Target/TargetData.h"
#include "llvm/Transforms/Scalar.h"
#include "tuplr.hpp"

using namespace llvm;
using namespace std;
using boost::format;

/// Cast a generic CValue to the LLVM Value pointer it holds
static inline Value*
llVal(CValue v)
{
	return static_cast<Value*>(v);
}

/// Cast a generic CFunction to the LLVM Function pointer it holds
static inline Function*
llFunc(CFunction f)
{
	return static_cast<Function*>(f);
}

/** Translate a Tuplr type to the corresponding LLVM type.
 *
 * Primitive types map directly to LLVM primitive types.  A function type
 * `(Fn (args...) ret)` maps to a pointer to the equivalent LLVM function
 * type, provided the return type and every parameter type translate.
 *
 * @param t Type to translate (must not be NULL).
 * @return The LLVM type, or NULL if @a t (or any type it contains) has no
 *         LLVM equivalent.
 * @throw Error if @a t is a primitive type with an unrecognised name.
 */
static const Type*
llType(const AType* t)
{
	if (t->kind == AType::PRIM) {
		const string name = t->at(0)->str();
		if (name == "Nothing") return Type::VoidTy;
		if (name == "Bool")    return Type::Int1Ty;
		if (name == "Int")     return Type::Int32Ty;
		if (name == "Float")   return Type::FloatTy;
		throw Error(t->loc, string("Unknown primitive type `") + t->str() + "'");
	} else if (t->kind == AType::EXPR && t->at(0)->str() == "Fn") {
		// Translate the return type once and reuse the result below
		const AType* retT  = t->at(2)->as<const AType*>();
		const Type*  retLT = llType(retT);
		if (!retLT)
			return NULL;

		// to<>() yields NULL on a type mismatch, so a malformed prototype is
		// reported as "no LLVM type" rather than dereferenced
		const ATuple* prot = t->at(1)->to<const ATuple*>();
		if (!prot)
			return NULL;

		vector<const Type*> cprot;
		for (size_t i = 0; i < prot->size(); ++i) {
			const AType* at = prot->at(i)->to<const AType*>();
			const Type*  lt = at ? llType(at) : NULL;
			if (!lt)
				return NULL;
			cprot.push_back(lt);
		}

		FunctionType* fT = FunctionType::get(retLT, cprot, false);
		return PointerType::get(fT, 0);
	}
	return NULL; // non-primitive type
}


/***************************************************************************
 * LLVM Engine                                                             *
 ***************************************************************************/

struct LLVMEngine : public Engine {
	LLVMEngine()
		: module(new Module("tuplr"))
		, engine(ExecutionEngine::create(module))
		, emp(module)
		, opt(&emp)
	{
		// Set up optimiser pipeline
		const TargetData* target = engine->getTargetData();
		opt.add(new TargetData(*target));          // Register target arch
		opt.add(createInstructionCombiningPass()); // Simple optimizations
		opt.add(createReassociatePass());          // Reassociate expressions
		opt.add(createGVNPass());                  // Eliminate Common Subexpressions
		opt.add(createCFGSimplificationPass());    // Simplify control flow

		// Declare host provided allocation primitive
		std::vector<const Type*> argsT(1, Type::Int32Ty); // unsigned size
		argsT.push_back(Type::Int8Ty); // char tag
		FunctionType* funcT = FunctionType::get(PointerType::get(Type::Int8Ty, 0), argsT, false);
		alloc = Function::Create(funcT, Function::ExternalLinkage,
				"tuplr_gc_allocate", module);
	}

	CFunction startFunction(CEnv& cenv,
			const std::string& name, const AType* retT, const ATuple& argsT,
			const vector<string> argNames)
	{
		Function::LinkageTypes linkage = Function::ExternalLinkage;

		vector<const Type*> cprot;
		FOREACH(ATuple::const_iterator, i, argsT) {
			AType* at = (*i)->as<AType*>();
			THROW_IF(!llType(at), Cursor(), string("non-concrete parameter :: ")
				+ at->str())
			cprot.push_back(llType(at));
		}

		THROW_IF(!llType(retT), Cursor(), "return has non-concrete type");
		FunctionType* fT = FunctionType::get(llType(retT), cprot, false);
		Function*     f  = Function::Create(fT, linkage, name, module);

		// Note f->getName() may be different from name
		// however LLVM chooses to mangle is fine, we keep a pointer

		// Set argument names in generated code
		Function::arg_iterator a = f->arg_begin();
		if (!argNames.empty())
			for (size_t i = 0; i != argsT.size(); ++a, ++i)
				a->setName(argNames.at(i));

		BasicBlock* bb = BasicBlock::Create("entry", f);
		builder.SetInsertPoint(bb);

		return f;
	}

	void finishFunction(CEnv& cenv, CFunction f, const AType* retT, CValue ret) {
		if (retT->concrete())
			builder.CreateRet(llVal(ret));
		else
			builder.CreateRetVoid();

		verifyFunction(*static_cast<Function*>(f));
		if (cenv.args.find("-g") == cenv.args.end())
			opt.run(*static_cast<Function*>(f));
	}

	void eraseFunction(CEnv& cenv, CFunction f) {
		if (f)
			llFunc(f)->eraseFromParent();
	}

	CValue compileCall(CEnv& cenv, CFunction f, const vector<CValue>& args) {
		const vector<Value*>& llArgs = *reinterpret_cast<const vector<Value*>*>(&args);
		return builder.CreateCall(llFunc(f), llArgs.begin(), llArgs.end());
	}

	void writeModule(CEnv& cenv, std::ostream& os) {
		AssemblyAnnotationWriter writer;
		module->print(os, &writer);
	}

	const string call(CEnv& cenv, CFunction f, AType* retT) {
		void*       fp = engine->getPointerToFunction(llFunc(f));
		const Type* t  = llType(retT);
		THROW_IF(!fp, Cursor(), "unable to get function pointer");
		THROW_IF(!t,  Cursor(), "function with non-concrete return type called");

		std::stringstream ss;
		if (t == Type::Int32Ty)
			ss << ((int32_t (*)())fp)();
		else if (t == Type::FloatTy)
			ss << showpoint << ((float (*)())fp)();
		else if (t == Type::Int1Ty)
			ss << (((bool (*)())fp)() ? "#t" : "#f");
		else if (t != Type::VoidTy)
			ss << ((void* (*)())fp)();
		return ss.str();
	}

	Module*                module;
	ExecutionEngine*       engine;
	IRBuilder<>            builder;
	Function*              alloc;
	ExistingModuleProvider emp;
	FunctionPassManager    opt;
};

/// Shared library entry point: allocate and return a new LLVM-backed Engine
Engine*
tuplr_new_engine()
{
	Engine* const created = new LLVMEngine();
	return created;
}

/// Shared library entry point: destroy an Engine, which is deleted as an LLVMEngine
void
tuplr_free_engine(Engine* engine)
{
	LLVMEngine* const llEngine = static_cast<LLVMEngine*>(engine);
	delete llEngine;
}


/***************************************************************************
 * Code Generation                                                         *
 ***************************************************************************/

/// Define the ALiteral<CTYPE>::compile specialisation, which returns CONSTANT
#define COMPILE_LITERAL(CTYPE, CONSTANT) \
template<> CValue \
ALiteral<CTYPE>::compile(CEnv& cenv) \
{ \
	return (CONSTANT); \
}

// One specialisation per literal type; each builds an LLVM constant from val
// (the final ConstantInt::get argument selects signed or unsigned)
COMPILE_LITERAL(int32_t, ConstantInt::get(Type::Int32Ty, val, true))
COMPILE_LITERAL(float,   ConstantFP::get(Type::FloatTy, val))
COMPILE_LITERAL(bool,    ConstantInt::get(Type::Int1Ty, val, false))

/** Generate an implementation of this function for a specific call.
 *
 * The generic type recorded for this function is specialised with the
 * argument types @a argsT to obtain a concrete type.  If an implementation
 * for that type already exists in impls nothing more is done.  Otherwise a
 * new function is declared, its parameters are bound in a fresh CEnv scope,
 * and the body expressions (elements 2 onwards) are compiled into it, with
 * the value of the last one used as the return value.
 *
 * @param argsT Argument types of the call being lifted.
 * @throw Error if the specialised type is not concrete, or if compiling the
 *        body fails.  In the latter case the new function is erased and the
 *        type substitution and scope of @a cenv are restored first.
 */
void
AFn::liftCall(CEnv& cenv, const AType& argsT)
{
	TEnv::GenericTypes::const_iterator gt = cenv.tenv.genericTypes.find(this);
	assert(gt != cenv.tenv.genericTypes.end());
	AType* genericType = new AType(*gt->second);
	AType* thisType    = genericType;
	Subst  argsSubst;

	// Build and apply substitution to get concrete type for this call
	if (!genericType->concrete()) {
		argsSubst = cenv.tenv.buildSubst(genericType, argsT);
		thisType = argsSubst.apply(genericType)->as<AType*>();
	}

	THROW_IF(!thisType->concrete(), loc,
			(format("call has non-concrete type %1%\n") % thisType->str()).str());

	Object::pool.addRoot(thisType);
	if (impls.find(thisType))
		return;

	ATuple* protT = thisType->at(1)->as<ATuple*>();

	vector<string> argNames;
	for (size_t i = 0; i < prot()->size(); ++i)
		argNames.push_back(prot()->at(i)->str());

	// Write function declaration
	const string name = (this->name == "") ? cenv.penv.gensymstr("_fn") : this->name;
	Function* f = llFunc(cenv.engine()->startFunction(cenv, name,
			thisType->at(thisType->size()-1)->to<AType*>(),
			*protT, argNames));

	cenv.push();
	Subst oldSubst = cenv.tsubst;
	cenv.tsubst = Subst::compose(cenv.tsubst, Subst::compose(argsSubst, subst));

//#define EXPLICIT_STACK_FRAMES 1

#ifdef EXPLICIT_STACK_FRAMES
	LLVMEngine* engine = reinterpret_cast<LLVMEngine*>(cenv.engine());
	vector<const Type*> types;
	types.push_back(Type::Int8Ty);
	types.push_back(Type::Int8Ty);
	size_t s = 16; // stack frame size in bits
#endif

	// Bind argument values in CEnv
	const_iterator p = prot()->begin();
	size_t i = 0;
	for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p, ++i) {
		AType* t = protT->at(i)->as<AType*>();
		const Type* lt = llType(t);
		THROW_IF(!lt, loc, "untyped parameter\n");
		cenv.def((*p)->as<ASymbol*>(), *p, t, &*a);
#ifdef EXPLICIT_STACK_FRAMES
		types.push_back(lt);
		s += std::max(lt->getPrimitiveSizeInBits(), unsigned(8));
#endif
	}

#ifdef EXPLICIT_STACK_FRAMES
	IRBuilder<> builder = engine->builder;

	// Scan out definitions
	for (size_t i = 0; i < size(); ++i) {
		ADef* def = at(i)->to<ADef*>();
		if (def) {
			const Type* lt = llType(cenv.type(def->at(2)));
			THROW_IF(!lt, loc, "untyped definition\n");
			types.push_back(lt);
			s += std::max(lt->getPrimitiveSizeInBits(), unsigned(8));
		}
	}

	// Create stack frame
	StructType* frameT    = StructType::get(types, false);
	Value*      tag       = ConstantInt::get(Type::Int8Ty, GC::TAG_FRAME);
	Value*      frameSize = ConstantInt::get(Type::Int32Ty, s / 8);
	Value*      frame     = builder.CreateCall2(engine->alloc, frameSize, tag, "frame");
	Value*      framePtr  = builder.CreateBitCast(frame, PointerType::get(frameT, 0));

	// Bind parameter values in stack frame
	i = 2;
	for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++i) {
		Value* v = builder.CreateStructGEP(framePtr, i, "arg");
		builder.CreateStore(&*a, v);
	}
#endif

	// Write function body
	try {
		// Define value first for recursion
		cenv.precompile(this, f);
		impls.push_back(make_pair(thisType, f));
		CValue retVal = NULL;
		for (size_t i = 2; i < size(); ++i)
			retVal = cenv.compile(at(i));
		cenv.engine()->finishFunction(cenv, f, cenv.type(at(size()-1)), retVal);
	} catch (Error&) {
		// NOTE(review): the (thisType, f) entry pushed to impls above is not
		// removed here, so impls may still refer to the erased function;
		// confirm whether it should be dropped as well.
		f->eraseFromParent(); // Error reading body, remove function
		cenv.tsubst = oldSubst; // Undo the substitution installed above
		cenv.pop();
		throw; // Rethrow the original exception object without copying it
	}
	cenv.tsubst = oldSubst;
	cenv.pop();
}

/** Return the implementation of this function for its type in @a cenv.
 *
 * The result is whatever impls holds for that type, or NULL if the type
 * does not translate to an LLVM type.
 */
CValue
AFn::compile(CEnv& cenv)
{
	AType* const selfT = cenv.type(this);

	Function* impl = NULL;
	if (llType(selfT))
		impl = static_cast<Function*>(impls.find(selfT));

	return impl;

	// Disabled sketch that allocates a tuple for the function value instead:
	/*vector<const Type*> types;
	types.push_back(PointerType::get(fnT, 0));
	types.push_back(PointerType::get(Type::VoidTy, 0));
	LLVMEngine* engine  = reinterpret_cast<LLVMEngine*>(cenv.engine());
	IRBuilder<> builder = engine->builder;
	Value*      tag     = ConstantInt::get(Type::Int8Ty, GC::TAG_FRAME);
	StructType* tupT    = StructType::get(types, false);
	Value*      tupSize = ConstantInt::get(Type::Int32Ty, sizeof(void*) * 2);
	Value*      tup     = builder.CreateCall2(engine->alloc, tupSize, tag, "fn");
	Value*      tupPtr  = builder.CreateBitCast(tup, PointerType::get(tupT, 0));
	return tupPtr;*/
}

/** Compile a conditional expression.
 *
 * Elements are read as (condition, consequent) pairs starting at index 1,
 * followed by a final else expression in the last position.  Each pair
 * becomes a conditional branch to a "then" block or on to the next "else"
 * block; every branch jumps to a shared "endif" block, where a phi node
 * selects the resulting value.
 *
 * @return The phi node holding the value of the whole expression.
 */
CValue
AIf::compile(CEnv& cenv)
{
	typedef pair<Value*, BasicBlock*> Incoming;
	typedef vector<Incoming>          IncomingList;

	LLVMEngine*  llEngine = reinterpret_cast<LLVMEngine*>(cenv.engine());
	IRBuilder<>& builder  = llEngine->builder;
	Function*    func     = builder.GetInsertBlock()->getParent();
	BasicBlock*  endBB    = BasicBlock::Create("endif");
	BasicBlock*  elseBB   = NULL;
	IncomingList incoming;

	for (size_t c = 1; c < size() - 1; c += 2) {
		// Test this condition in the current block
		Value*      test   = llVal(cenv.compile(at(c)));
		BasicBlock* thenBB = BasicBlock::Create((format("then%1%") % ((c+1)/2)).str());

		elseBB = BasicBlock::Create((format("else%1%") % ((c+1)/2)).str());

		builder.CreateCondBr(test, thenBB, elseBB);

		// Consequent: compile it, then jump to the merge block.  The block
		// recorded is wherever the builder ended up after compiling.
		func->getBasicBlockList().push_back(thenBB);
		builder.SetInsertPoint(thenBB);
		Value* thenVal = llVal(cenv.compile(at(c+1)));
		builder.CreateBr(endBB);
		incoming.push_back(make_pair(thenVal, builder.GetInsertBlock()));

		// Continue with the next condition (or the final else) here
		func->getBasicBlockList().push_back(elseBB);
		builder.SetInsertPoint(elseBB);
	}

	// Final else expression
	builder.SetInsertPoint(elseBB);
	Value* elseVal = llVal(cenv.compile(at(size() - 1)));
	builder.CreateBr(endBB);
	incoming.push_back(make_pair(elseVal, builder.GetInsertBlock()));

	// Merge block: a phi node with one incoming value per branch
	func->getBasicBlockList().push_back(endBB);
	builder.SetInsertPoint(endBB);
	PHINode* phi = builder.CreatePHI(llType(cenv.type(this)), "ifval");

	for (IncomingList::iterator in = incoming.begin(); in != incoming.end(); ++in)
		phi->addIncoming(in->first, in->second);

	return phi;
}

/** Compile a primitive operation.
 *
 * The first two operands are always compiled.  Arithmetic and logical
 * operators are then folded left to right over any further operands, while
 * comparison operators use only the first two.  Float or integer variants
 * are chosen according to whether the first operand has type "Float".
 *
 * @return The value of the resulting instruction.
 * @throw Error if the operator name is not recognised.
 */
CValue
APrimitive::compile(CEnv& cenv)
{
	LLVMEngine*  llEngine = reinterpret_cast<LLVMEngine*>(cenv.engine());
	IRBuilder<>& builder  = llEngine->builder;
	Value*       lhs      = llVal(cenv.compile(at(1)));
	Value*       rhs      = llVal(cenv.compile(at(2)));
	const bool   floating = cenv.type(at(1))->str() == "Float";
	const string name     = at(0)->to<ASymbol*>()->str();

	// Binary arithmetic operations (0 means "not one of these")
	Instruction::BinaryOps binOp = (Instruction::BinaryOps)0;
	if      (name == "+")   binOp = Instruction::Add;
	else if (name == "-")   binOp = Instruction::Sub;
	else if (name == "*")   binOp = Instruction::Mul;
	else if (name == "and") binOp = Instruction::And;
	else if (name == "or")  binOp = Instruction::Or;
	else if (name == "xor") binOp = Instruction::Xor;
	else if (name == "/")   binOp = floating ? Instruction::FDiv : Instruction::SDiv;
	else if (name == "%")   binOp = floating ? Instruction::FRem : Instruction::SRem;

	if (binOp != 0) {
		Value* acc = builder.CreateBinOp(binOp, lhs, rhs);
		for (size_t i = 3; i < size(); ++i) {
			Value* operand = llVal(cenv.compile(at(i)));
			acc = builder.CreateBinOp(binOp, acc, operand);
		}
		return acc;
	}

	// Comparison operations (0 means "not one of these")
	CmpInst::Predicate cmp = (CmpInst::Predicate)0;
	if      (name == "=")  cmp = floating ? CmpInst::FCMP_OEQ : CmpInst::ICMP_EQ ;
	else if (name == "!=") cmp = floating ? CmpInst::FCMP_ONE : CmpInst::ICMP_NE ;
	else if (name == ">")  cmp = floating ? CmpInst::FCMP_OGT : CmpInst::ICMP_SGT;
	else if (name == ">=") cmp = floating ? CmpInst::FCMP_OGE : CmpInst::ICMP_SGE;
	else if (name == "<")  cmp = floating ? CmpInst::FCMP_OLT : CmpInst::ICMP_SLT;
	else if (name == "<=") cmp = floating ? CmpInst::FCMP_OLE : CmpInst::ICMP_SLE;

	if (cmp == 0)
		throw Error(loc, "unknown primitive");

	if (floating)
		return builder.CreateFCmp(cmp, lhs, rhs);

	return builder.CreateICmp(cmp, lhs, rhs);
}