/* Resp: A programming language
 * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net>
 *
 * Resp is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Resp is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Resp.  If not, see <http://www.gnu.org/licenses/>.
 */

/** @file
 * @brief Compile to LLVM IR
 */

#define __STDC_LIMIT_MACROS    1
#define __STDC_CONSTANT_MACROS 1

#include <map>
#include <sstream>
#include <boost/format.hpp>
#include "llvm/Value.h"
#include "llvm/Analysis/Verifier.h"
#include "llvm/Assembly/AsmAnnotationWriter.h"
#include "llvm/DerivedTypes.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/JIT.h"
#include "llvm/Instructions.h"
#include "llvm/LLVMContext.h"
#include "llvm/Module.h"
#include "llvm/PassManager.h"
#include "llvm/Support/raw_os_ostream.h"
#include "llvm/Support/IRBuilder.h"
#include "llvm/Target/TargetData.h"
#include "llvm/Target/TargetSelect.h"
#include "llvm/Transforms/Scalar.h"
#include "resp.hpp"

using namespace llvm;
using namespace std;
using boost::format;

/***************************************************************************
 * LLVM Engine                                                             *
 ***************************************************************************/

/** LLVM back end: lowers Resp ASTs to LLVM IR and JIT-executes the result.
 *
 * Holds the LLVM context, the "resp" module every function is emitted into,
 * a JIT execution engine for that module, an IRBuilder, and a per-function
 * optimisation pipeline (skipped when "-g" is present in cenv.args).
 */
struct LLVMEngine : public Engine {
	/** Set up the native target, module, JIT, optimiser, and the external
	 * declaration of the host allocator `resp_gc_allocate`.
	 * @throws Error if the JIT execution engine cannot be created.
	 */
	LLVMEngine()
		: builder(context)
		, labelIndex(1)
	{
		InitializeNativeTarget();
		module = new Module("resp", context);

		// create() returns NULL on failure; report LLVM's reason instead of
		// crashing on the first dereference of `engine` below.
		std::string err;
		engine = EngineBuilder(module).setErrorStr(&err).create();
		if (!engine)
			throw Error(Cursor(), string("unable to create LLVM execution engine: ") + err);

		opt = new FunctionPassManager(module);

		// Set up optimiser pipeline
		const TargetData* target = engine->getTargetData();
		opt->add(new TargetData(*target));          // Register target arch
		opt->add(createInstructionCombiningPass()); // Simple optimizations
		opt->add(createReassociatePass());          // Reassociate expressions
		opt->add(createGVNPass());                  // Eliminate Common Subexpressions
		opt->add(createCFGSimplificationPass());    // Simplify control flow

		// Declare host provided allocation primitive: i8* resp_gc_allocate(i32)
		std::vector<const Type*> argsT(1, Type::getInt32Ty(context)); // unsigned size
		FunctionType* funcT = FunctionType::get(
			PointerType::get(Type::getInt8Ty(context), 0), argsT, false);
		alloc = Function::Create(funcT, Function::ExternalLinkage,
				"resp_gc_allocate", module);
	}

	/** Release the optimiser and the JIT, in reverse order of creation.
	 * The module is not deleted here; LLVM's ExecutionEngine owns the
	 * module it was created with.
	 */
	~LLVMEngine()
	{
		delete opt;
		delete engine;
	}

	inline Value*    llVal(CVal v)   { return static_cast<Value*>(v); }
	inline Function* llFunc(CFunc f) { return static_cast<Function*>(f); }

	/** Lower a Resp type to an LLVM type.
	 *
	 * Primitives map to scalar types (String and Quote to i8*). `Fn` maps to a
	 * pointer to function type. Any other expression type with an upper-case
	 * head maps to a pointer to a struct whose first field is an i8* RTTI slot.
	 * @return NULL if `t` is NULL or any component type is non-concrete.
	 * @throws Error for an unknown primitive name.
	 */
	const Type*
	llType(const AType* t)
	{
		if (t == NULL) {
			return NULL;
		} else if (t->kind == AType::PRIM) {
			if (t->head()->str() == "Nothing") return Type::getVoidTy(context);
			if (t->head()->str() == "Bool")    return Type::getInt1Ty(context);
			if (t->head()->str() == "Int")     return Type::getInt32Ty(context);
			if (t->head()->str() == "Float")   return Type::getFloatTy(context);
			if (t->head()->str() == "String")  return PointerType::get(Type::getInt8Ty(context), 0);
			if (t->head()->str() == "Quote")   return PointerType::get(Type::getInt8Ty(context), 0);
			throw Error(t->loc, string("Unknown primitive type `") + t->str() + "'");
		} else if (t->kind == AType::EXPR && t->head()->str() == "Fn") {
			AType::const_iterator ti    = t->begin();
			const ATuple*         protT = (*++ti)->to_tuple();
			const AType*          retT  = (*++ti)->as_type();
			const Type*           lretT = llType(retT);
			if (!lretT)
				return NULL;

			vector<const Type*> cprot;
			FOREACHP(ATuple::const_iterator, i, protT) {
				const Type* lt = llType((*i)->to_type());
				if (!lt)
					return NULL;
				cprot.push_back(lt);
			}

			return PointerType::get(FunctionType::get(lretT, cprot, false), 0);
		} else if (t->kind == AType::EXPR && isupper(t->head()->str()[0])) {
			vector<const Type*> ctypes;
			ctypes.push_back(PointerType::get(Type::getInt8Ty(context), 0)); // RTTI
			for (AType::const_iterator i = t->iter_at(1); i != t->end(); ++i) {
				const Type* lt = llType((*i)->to_type());
				if (!lt)
					return NULL;
				ctypes.push_back(lt);
			}

			return PointerType::get(StructType::get(context, ctypes, false), 0);
		} else if (t->kind == AType::NAME) {
			assert(false);
		}
		assert(false);
		return PointerType::get(Type::getInt8Ty(context), 0);
	}

	/** Create a new function in the module and position the builder at its
	 * entry block.
	 * @param name Function name; a fresh "_fn" symbol is generated if empty.
	 * @param args Parameter symbols, used to name the LLVM arguments.
	 * @param type Function type; its prot gives parameter types and its last
	 *             element gives the return type.
	 * @throws Error if any parameter or the return type is non-concrete.
	 */
	CFunc startFunction(
		CEnv& cenv, const std::string& name, const ATuple* args, const AType* type)
	{
		const AType* argsT = type->prot()->as_type();
		const AType* retT  = type->list_last()->as_type();

		Function::LinkageTypes linkage = Function::ExternalLinkage;

		vector<const Type*> cprot;
		FOREACHP(ATuple::const_iterator, i, argsT) {
			const AType* at  = (*i)->as_type();
			const Type*  lat = llType(at);
			THROW_IF(!lat, Cursor(), string("non-concrete parameter :: ")
				+ at->str());
			cprot.push_back(lat);
		}

		const Type* lretT = llType(retT);
		THROW_IF(!lretT, Cursor(),
				(format("return has non-concrete type `%1%'") % retT->str()).str());

		const string llName = (name == "") ? cenv.penv.gensymstr("_fn") : name;

		FunctionType* fT = FunctionType::get(lretT, cprot, false);
		Function*     f  = Function::Create(fT, linkage, llName, module);

		// Note f->getName() may be different from llName
		// however LLVM chooses to mangle is fine, we keep a pointer

		// Set argument names in generated code (bounded by both sequences so a
		// size mismatch cannot walk past the end of the argument list)
		Function::arg_iterator a = f->arg_begin();
		for (ATuple::const_iterator i = args->begin();
		     i != args->end() && a != f->arg_end(); ++a, ++i)
			a->setName((*i)->as_symbol()->cppstr);

		BasicBlock* bb = BasicBlock::Create(context, "entry", f);
		builder.SetInsertPoint(bb);

		return f;
	}

	void pushFunctionArgs(CEnv& cenv, const ATuple* prot, const AType* type, CFunc f);

	/** Append `block` to `function` and make it the builder's insert point. */
	void appendBlock(LLVMEngine* llEngine, Function* function, BasicBlock* block) {
		function->getBasicBlockList().push_back(block);
		llEngine->builder.SetInsertPoint(block);
	}

	/** Emit `ret` as the return value, verify the function and, unless "-g"
	 * was given, run the optimiser over it.
	 * @throws Error (after dumping the module) if verification fails.
	 */
	void finishFunction(CEnv& cenv, CFunc f, CVal ret) {
		builder.CreateRet(llVal(ret));
		if (verifyFunction(*llFunc(f), llvm::PrintMessageAction)) {
			module->dump();
			throw Error(Cursor(), "Broken module");
		}
		if (cenv.args.find("-g") == cenv.args.end())
			opt->run(*llFunc(f));
	}

	/** Remove `f` from the module, if non-NULL. */
	void eraseFunction(CEnv& cenv, CFunc f) {
		if (f)
			llFunc(f)->eraseFromParent();
	}

	/** Emit a call to `f`. The first argument (the closure) is bitcast to the
	 * type of the first parameter of `funcT`'s prototype.
	 */
	CVal compileCall(CEnv& cenv, CFunc f, const AType* funcT, const vector<CVal>& args) {
		assert(!args.empty()); // args[0] is the closure

		// Convert element-wise: reinterpreting vector<CVal> as vector<Value*>
		// is undefined behaviour even when the element types are compatible.
		vector<Value*> llArgs;
		llArgs.reserve(args.size());
		for (vector<CVal>::const_iterator i = args.begin(); i != args.end(); ++i)
			llArgs.push_back(llVal(*i));

		llArgs[0] = builder.CreateBitCast(llArgs[0],
				llType(funcT->prot()->head()->as_type()),
				cenv.penv.gensymstr("you"));
		return builder.CreateCall(llFunc(f), llArgs.begin(), llArgs.end());
	}

	CVal compileCons(CEnv& cenv, const AType* type, CVal rtti, const vector<CVal>& fields);
	CVal compileDot(CEnv& cenv, CVal tup, int32_t index);
	CVal compileLiteral(CEnv& cenv, const AST* lit);
	CVal compileString(CEnv& cenv, const char* str);
	CVal compilePrimitive(CEnv& cenv, const ATuple* prim);
	CVal compileIf(CEnv& cenv, const ATuple* aif);
	CVal compileMatch(CEnv& cenv, const ATuple* match);
	CVal compileGlobal(CEnv& cenv, const AType* type, const string& sym, CVal val);
	CVal compileGlobalGet(CEnv& cenv, const string& sym, CVal val);

	typedef pair<Value*, BasicBlock*> IfBranch;
	typedef vector<IfBranch>          IfBranches;

	/** Per-conditional state: the merge block, the enclosing function, and
	 * the (value, predecessor block) pairs that feed the merge PHI. */
	struct LLVMIfState {
		LLVMIfState(BasicBlock* m, Function* p) : mergeBB(m), parent(p) {}
		BasicBlock* mergeBB;
		Function*   parent;
		IfBranches  branches;
	};

	IfState
	compileIfStart(CEnv& cenv);

	void
	compileIfBranch(CEnv& cenv, IfState state, CVal condV, const AST* then);

	CVal
	compileIfEnd(CEnv& cenv, IfState state, CVal elseV, const AType* type);

	/** Print the module's IR to `os`. */
	void writeModule(CEnv& cenv, std::ostream& os) {
		AssemblyAnnotationWriter writer;
		llvm::raw_os_ostream raw_stream(os);
		module->print(raw_stream, &writer);
	}

	/** JIT-compile and run the nullary function `f`, returning its result
	 * rendered as Resp source text ("" for a void return).
	 * @throws Error if no function pointer is available or `retT` is
	 *         non-concrete.
	 */
	const string call(CEnv& cenv, CFunc f, const AType* retT) {
		void*       fp = engine->getPointerToFunction(llFunc(f));
		const Type* t  = llType(retT);
		THROW_IF(!fp, Cursor(), "unable to get function pointer");
		THROW_IF(!t,  Cursor(), "function with non-concrete return type called");

		std::stringstream ss;
		if (t == Type::getInt32Ty(context)) {
			ss << ((int32_t (*)())fp)();
		} else if (t == Type::getFloatTy(context)) {
			ss << showpoint << ((float (*)())fp)();
		} else if (t == Type::getInt1Ty(context)) {
			ss << (((bool (*)())fp)() ? "#t" : "#f");
		} else if (retT->head()->str() == "String") {
			const std::string s(((char* (*)())fp)());
			ss << "\"";
			for (std::string::const_iterator i = s.begin(); i != s.end(); ++i) {
				switch (*i) {
				case '\"':
				case '\\':
					ss << '\\';
					// fall through: emit the escaped character itself
				default:
					ss << *i;
					break;
				}
			}
			ss << "\"";
		} else if (t != Type::getVoidTy(context)) {
			ss << ((void* (*)())fp)();
		} else {
			((void (*)())fp)();
		}
		return ss.str();
	}

	LLVMContext             context;
	Module*                 module;
	ExecutionEngine*        engine;
	IRBuilder<>             builder;
	Function*               alloc;   ///< Declaration of resp_gc_allocate
	FunctionPassManager*    opt;

	unsigned labelIndex; ///< Suffix for generated then/else block names
};

/** Construct a fresh LLVM back end.
 * @return A newly heap-allocated engine; the caller is responsible for it.
 */
Engine*
resp_new_llvm_engine()
{
	LLVMEngine* const llvmEngine = new LLVMEngine();
	return llvmEngine;
}

/***************************************************************************
 * Code Generation                                                         *
 ***************************************************************************/

/** Convert a size in bits to bytes, rounding up as necessary */
/** Convert a size in bits to bytes, rounding up as necessary.
 *
 * Computed as quotient plus a carry for any remainder, so no intermediate
 * value can overflow `size_t` for any input.
 */
static inline size_t
bitsToBytes(size_t bits)
{
	return bits / 8 + ((bits % 8 != 0) ? 1 : 0);
}

/** Allocate and initialise a constructed (tuple) value on the host heap.
 *
 * Memory is obtained by calling the host's `resp_gc_allocate`. The result is
 * a pointer to the struct `llType(type)` points to: field 0 is the RTTI
 * slot, and fields 1..N receive `fields` in order.
 * @param type   Constructor type; must lower to a pointer-to-struct.
 * @param rtti   Value stored into field 0, or NULL to leave it unwritten.
 * @param fields Values stored into fields 1..N.
 * @throws Error if `type` is not concrete.
 */
CVal
LLVMEngine::compileCons(CEnv& cenv, const AType* type, CVal rtti, const vector<CVal>& fields)
{
	assert(type->begin() != type->end());

	const Type* structPtrT = llType(type);
	THROW_IF(!structPtrT, type->loc, "non-concrete constructor type");

	// Size the allocation from the target's ABI layout of the whole struct.
	// Summing the fields' bit widths ignores alignment padding between fields
	// and under-allocates (e.g. {i8*, i1, i32} needs 16 bytes, not 13).
	const Type*    structT = cast<PointerType>(structPtrT)->getElementType();
	const uint64_t size    = engine->getTargetData()->getTypeAllocSize(structT);

	// Allocate struct
	Value* structSize = ConstantInt::get(Type::getInt32Ty(context), size);
	Value* mem        = builder.CreateCall(alloc, structSize, "tupMem");
	Value* structPtr  = builder.CreateBitCast(mem, structPtrT, "tup");

	// Set struct fields
	if (rtti)
		builder.CreateStore(llVal(rtti), builder.CreateStructGEP(structPtr, 0, "rtti"));
	unsigned idx = 1;
	for (vector<CVal>::const_iterator f = fields.begin(); f != fields.end(); ++f, ++idx) {
		builder.CreateStore(llVal(*f),
				builder.CreateStructGEP(structPtr, idx, (format("tup%1%") % idx).str().c_str()));
	}

	return structPtr;
}

/** Load field `index` from the struct that `tup` points to.
 * @return The loaded field value (non-volatile load).
 */
CVal
LLVMEngine::compileDot(CEnv& cenv, CVal tup, int32_t index)
{
	const bool isVolatile = false;
	return builder.CreateLoad(
		builder.CreateStructGEP(llVal(tup), index, "dotPtr"),
		isVolatile,
		"dotVal");
}

/** Lower a Bool, Int or Float literal AST node to an LLVM constant.
 * @throws Error for any other literal tag.
 */
CVal
LLVMEngine::compileLiteral(CEnv& cenv, const AST* lit)
{
	switch (lit->tag()) {
	case T_INT32: {
		const int32_t v = ((const ALiteral<int32_t>*)lit)->val;
		return ConstantInt::get(Type::getInt32Ty(context), v, true);
	}
	case T_FLOAT: {
		const float v = ((const ALiteral<float>*)lit)->val;
		return ConstantFP::get(Type::getFloatTy(context), v);
	}
	case T_BOOL: {
		const bool v = ((const ALiteral<bool>*)lit)->val;
		return ConstantInt::get(Type::getInt1Ty(context), v);
	}
	default:
		break;
	}
	throw Error(lit->loc, "Unknown literal type");
}

/** Emit `str` as a global string constant and return an i8* pointer to it. */
CVal
LLVMEngine::compileString(CEnv& cenv, const char* str)
{
	Value* const strPtr = builder.CreateGlobalStringPtr(str);
	return strPtr;
}

/** Open a new scope in `cenv` and bind each parameter symbol in `prot` to the
 * corresponding LLVM argument of `cfunc`, typed by the resolved parameter
 * types from `type`'s prototype.
 * @throws Error if a parameter type is not concrete.
 */
void
LLVMEngine::pushFunctionArgs(CEnv& cenv, const ATuple* prot, const AType* type, CFunc cfunc)
{
	cenv.push();

	Function*    fn      = llFunc(cfunc);
	const AType* paramTs = type->prot()->as_type();

	assert(prot->size() == paramTs->size());
	assert(prot->size() == fn->num_args());

	ATuple::const_iterator name  = prot->begin();
	ATuple::const_iterator nameT = paramTs->begin();
	Function::arg_iterator arg   = fn->arg_begin();
	while (arg != fn->arg_end()) {
		const AType* t = cenv.resolveType((*nameT)->as_type());
		THROW_IF(!llType(t), (*name)->loc, "untyped parameter\n");
		cenv.def((*name)->as_symbol(), *name, t, &*arg);
		++arg;
		++name;
		++nameT;
	}
}

/** Begin a conditional: create the (not yet inserted) "endif" merge block and
 * remember the function currently being built.
 * @return Heap-allocated LLVMIfState, passed to compileIfBranch/compileIfEnd.
 */
IfState
LLVMEngine::compileIfStart(CEnv& cenv)
{
	LLVMEngine* eng    = reinterpret_cast<LLVMEngine*>(cenv.engine());
	Function*   parent = eng->builder.GetInsertBlock()->getParent();
	BasicBlock* merge  = BasicBlock::Create(context, "endif");
	return new LLVMIfState(merge, parent);
}

/** Emit one `condV ? then : ...` arm of a conditional.
 *
 * Branches on `condV` to a fresh "thenN" block, compiles `then` there and
 * jumps to the merge block, then leaves the builder in a fresh "elseN" block
 * for the next condition (or the final else value).
 */
void
LLVMEngine::compileIfBranch(CEnv& cenv, IfState s, CVal condV, const AST* then)
{
	LLVMIfState* state  = (LLVMIfState*)s;
	LLVMEngine*  engine = reinterpret_cast<LLVMEngine*>(cenv.engine());
	BasicBlock*  thenBB = BasicBlock::Create(context, (format("then%1%") % labelIndex).str());
	BasicBlock*  nextBB = BasicBlock::Create(context, (format("else%1%") % labelIndex).str());

	++labelIndex;

	engine->builder.CreateCondBr(llVal(condV), thenBB, nextBB);

	// Emit then block for this condition
	appendBlock(engine, state->parent, thenBB);
	Value* thenV = llVal(resp_compile(cenv, then));

	// Compiling `then` may itself emit blocks (e.g. a nested if), so the block
	// that actually branches to the merge block, and hence the PHI's incoming
	// block, is the builder's current block, not necessarily thenBB.
	BasicBlock* thenEndBB = engine->builder.GetInsertBlock();
	engine->builder.CreateBr(state->mergeBB);
	state->branches.push_back(make_pair(thenV, thenEndBB));

	appendBlock(engine, state->parent, nextBB);
}

/** Finish a conditional: branch the final else block (whose value is
 * `elseV`) to the merge block, then emit a PHI of type `type` there that
 * merges every recorded branch value.
 *
 * Frees the LLVMIfState allocated by compileIfStart; `s` must not be used
 * after this call.
 * @return The PHI node holding the conditional's value.
 */
CVal
LLVMEngine::compileIfEnd(CEnv& cenv, IfState s, CVal elseV, const AType* type)
{
	LLVMIfState* state  = (LLVMIfState*)s;
	LLVMEngine*  engine = reinterpret_cast<LLVMEngine*>(cenv.engine());

	// Emit end of final else block
	engine->builder.CreateBr(state->mergeBB);
	state->branches.push_back(make_pair(llVal(elseV), engine->builder.GetInsertBlock()));

	// Emit merge block (Phi node)
	appendBlock(engine, state->parent, state->mergeBB);
	PHINode* pn = engine->builder.CreatePHI(llType(type), "ifval");
	FOREACH(IfBranches::iterator, i, state->branches)
		pn->addIncoming(i->first, i->second);

	// The state was heap-allocated by compileIfStart and is finished with now
	delete state;

	return pn;
}

/** Compile a `match` expression as a chain of conditional branches.
 *
 * Element 1 of `match` is the matchee. Its field 0 (RTTI) is compared with
 * ICMP_EQ against the value produced by compiling a type named after each
 * pattern's head symbol; the first equal clause selects its body. If no
 * clause matches, the result is the null value of the match's type.
 * Clauses follow the matchee as alternating pattern, body elements.
 * @throws Error if a pattern is not followed by a body.
 */
CVal
LLVMEngine::compileMatch(CEnv& cenv, const ATuple* match)
{
	LLVMEngine* engine  = reinterpret_cast<LLVMEngine*>(cenv.engine());
	IfState     state   = compileIfStart(cenv);
	CVal        matchee = resp_compile(cenv, match->list_ref(1));
	Value*      rtti    = llVal(compileDot(cenv, matchee, 0));

	for (ATuple::const_iterator i = match->iter_at(2); i != match->end();) {
		const AST* pat = *i++;
		// Guard against an unpaired trailing pattern, which would otherwise
		// dereference end()
		THROW_IF(i == match->end(), pat->loc, "match pattern has no body");
		const AST*     body = *i++;
		const ASymbol* sym  = pat->as_tuple()->head()->as_symbol();
		const AType*   patT = new AType(sym, 0, Cursor());

		Value* typeV = llVal(resp_compile(cenv, patT));
		Value* condV = engine->builder.CreateICmp(CmpInst::ICMP_EQ, rtti, typeV);

		compileIfBranch(cenv, state, condV, body);
	}

	const AType* type  = cenv.type(match);
	CVal         elseV = Constant::getNullValue(llType(type));
	return compileIfEnd(cenv, state, elseV, type);
}

/** Compile an application of a primitive operator.
 *
 * Arithmetic and bitwise operators fold left over all operands; comparisons
 * use only the first two. Floating-point or integer instructions are chosen
 * from the type of the first operand.
 * @throws Error if the operator name is not a known primitive.
 */
CVal
LLVMEngine::compilePrimitive(CEnv& cenv, const ATuple* prim)
{
	ATuple::const_iterator i = prim->begin();

	LLVMEngine*  engine  = reinterpret_cast<LLVMEngine*>(cenv.engine());
	bool         isFloat = cenv.type(*++i)->str() == "Float";
	Value*       a       = llVal(resp_compile(cenv, *i++));
	Value*       b       = llVal(resp_compile(cenv, *i++));
	const string n       = prim->head()->to_symbol()->str();

	// Binary arithmetic operations.  Since LLVM 2.6 add/sub/mul accept only
	// integer operands, so floats need the dedicated FAdd/FSub/FMul opcodes
	// (just as division and remainder use FDiv/FRem).
	bool                   isBinOp = true;
	Instruction::BinaryOps op      = Instruction::Add;
	if      (n == "+")   op = isFloat ? Instruction::FAdd : Instruction::Add;
	else if (n == "-")   op = isFloat ? Instruction::FSub : Instruction::Sub;
	else if (n == "*")   op = isFloat ? Instruction::FMul : Instruction::Mul;
	else if (n == "and") op = Instruction::And;
	else if (n == "or")  op = Instruction::Or;
	else if (n == "xor") op = Instruction::Xor;
	else if (n == "/")   op = isFloat ? Instruction::FDiv : Instruction::SDiv;
	else if (n == "%")   op = isFloat ? Instruction::FRem : Instruction::SRem;
	else                 isBinOp = false;

	if (isBinOp) {
		Value* val = engine->builder.CreateBinOp(op, a, b);
		while (i != prim->end())
			val = engine->builder.CreateBinOp(op, val, llVal(resp_compile(cenv, *i++)));
		return val;
	}

	// Comparison operations (ordered for floats, signed for integers)
	bool               isCmp = true;
	CmpInst::Predicate pred  = CmpInst::ICMP_EQ;
	if      (n == "=")  pred = isFloat ? CmpInst::FCMP_OEQ : CmpInst::ICMP_EQ ;
	else if (n == "!=") pred = isFloat ? CmpInst::FCMP_ONE : CmpInst::ICMP_NE ;
	else if (n == ">")  pred = isFloat ? CmpInst::FCMP_OGT : CmpInst::ICMP_SGT;
	else if (n == ">=") pred = isFloat ? CmpInst::FCMP_OGE : CmpInst::ICMP_SGE;
	else if (n == "<")  pred = isFloat ? CmpInst::FCMP_OLT : CmpInst::ICMP_SLT;
	else if (n == "<=") pred = isFloat ? CmpInst::FCMP_OLE : CmpInst::ICMP_SLE;
	else                isCmp = false;

	if (isCmp) {
		if (isFloat)
			return engine->builder.CreateFCmp(pred, a, b);
		else
			return engine->builder.CreateICmp(pred, a, b);
	}

	throw Error(prim->loc, "unknown primitive");
}

/** Define an externally-visible global `sym` of `type`, null-initialised,
 * and emit a store of `val` (bitcast to the global's type) into it.
 * @return The GlobalVariable.
 */
CVal
LLVMEngine::compileGlobal(CEnv& cenv, const AType* type, const string& sym, CVal val)
{
	LLVMEngine* eng = reinterpret_cast<LLVMEngine*>(cenv.engine());
	const Type* lt  = llType(type);

	GlobalVariable* global = new GlobalVariable(
		*module, lt, false, GlobalValue::ExternalLinkage,
		Constant::getNullValue(lt), sym);

	Value* valPtr = builder.CreateBitCast(llVal(val), lt, "globalPtr");
	eng->builder.CreateStore(valPtr, global);
	return global;
}

/** Emit a load of the global `val`; the loaded value is named `sym` + "Ptr". */
CVal
LLVMEngine::compileGlobalGet(CEnv& cenv, const string& sym, CVal val)
{
	LLVMEngine*  eng      = reinterpret_cast<LLVMEngine*>(cenv.engine());
	Value*       global   = llVal(val);
	const string loadName = sym + "Ptr";
	return eng->builder.CreateLoad(global, loadName);
}