/* Resp: A programming language
 * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net>
 *
 * Resp is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Resp is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Resp.  If not, see <http://www.gnu.org/licenses/>.
 */

/** @file
 * @brief Compile to LLVM IR and/or execute via JIT
 */

#define __STDC_LIMIT_MACROS    1
#define __STDC_CONSTANT_MACROS 1

#include <map>
#include <sstream>
#include <string>
#include <vector>

#include <boost/format.hpp>

#include "llvm/Value.h"
#include "llvm/Analysis/Verifier.h"
#include "llvm/Assembly/AsmAnnotationWriter.h"
#include "llvm/DerivedTypes.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/JIT.h"
#include "llvm/Instructions.h"
#include "llvm/LLVMContext.h"
#include "llvm/Module.h"
#include "llvm/PassManager.h"
#include "llvm/Support/raw_os_ostream.h"
#include "llvm/Support/IRBuilder.h"
#include "llvm/Target/TargetData.h"
#include "llvm/Target/TargetSelect.h"
#include "llvm/Transforms/Scalar.h"

#include "resp.hpp"

using namespace llvm;
using namespace std;
using boost::format;

/** LLVM Engine (Compiler and JIT)
 *
 * Lowers Resp code into a single LLVM module and runs compiled functions
 * through an LLVM ExecutionEngine.
 */
struct LLVMEngine : public Engine {
	LLVMEngine();
	virtual ~LLVMEngine();

	// --- Function construction -------------------------------------------
	CFunc startFn(CEnv& cenv, const string& name, const ATuple* args, const AType* type);
	void  pushFnArgs(CEnv& cenv, const ATuple* prot, const AType* type, CFunc f);
	void  finishFn(CEnv& cenv, CFunc f, CVal ret);
	void  eraseFn(CEnv& cenv, CFunc f);

	// --- Expression compilation ------------------------------------------
	CVal    compileCall(CEnv& cenv, CFunc f, const AType* funcT, const vector<CVal>& args);
	CVal    compileCons(CEnv& cenv, const AType* type, CVal rtti, const vector<CVal>& fields);
	CVal    compileDot(CEnv& cenv, CVal tup, int32_t index);
	CVal    compileGlobalSet(CEnv& cenv, const string& s, CVal v, const AType* t);
	CVal    compileGlobalGet(CEnv& cenv, const string& s, CVal v);
	IfState compileIfStart(CEnv& cenv);
	void    compileIfBranch(CEnv& cenv, IfState state, CVal condV, const AST* then);
	CVal    compileIfEnd(CEnv& cenv, IfState state, CVal elseV, const AType* type);
	CVal    compileIsA(CEnv& cenv, CVal rtti, const ASymbol* tag);
	CVal    compileLiteral(CEnv& cenv, const AST* lit);
	CVal    compilePrimitive(CEnv& cenv, const ATuple* prim);
	CVal    compileString(CEnv& cenv, const char* str);

	// --- Output and execution --------------------------------------------
	void writeModule(CEnv& cenv, std::ostream& os);

	const string call(CEnv& cenv, CFunc f, const AType* retT);

private:
	/// One incoming edge of an if-expression's phi: (value, predecessor block)
	typedef pair<Value*, BasicBlock*> IfBranch;
	typedef vector<IfBranch>          IfBranches;

	/// Bookkeeping for one if-expression while its branches are compiled
	struct LLVMIfState {
		LLVMIfState(BasicBlock* merge, Function* fn) : mergeBB(merge), parent(fn) {}
		BasicBlock* mergeBB;  ///< Block every branch jumps to when done
		Function*   parent;   ///< Function the if-expression is emitted into
		IfBranches  branches; ///< Incoming edges for the merge phi
	};

	/// Append @p block to @p fn and point @p eng's builder at it
	void appendBlock(LLVMEngine* eng, Function* fn, BasicBlock* block) {
		fn->getBasicBlockList().push_back(block);
		eng->builder.SetInsertPoint(block);
	}

	/// Unwrap the engine-neutral handles used by the Engine interface
	Value*    llVal(CVal v)   { return static_cast<Value*>(v); }
	Function* llFunc(CFunc f) { return static_cast<Function*>(f); }

	/// Map a Resp type to an LLVM type (NULL if not concrete)
	const Type* llType(const AType* t);

	// NOTE: declaration order matters; `builder` is initialised from `context`.
	LLVMContext             context;
	Module*                 module;
	ExecutionEngine*        engine;
	IRBuilder<>             builder;
	Function*               alloc;   ///< Declaration of resp_gc_allocate
	FunctionPassManager*    opt;     ///< Per-function optimiser pipeline

	unsigned labelIndex;             ///< Suffix for generated then/else labels
};

/** Create the module, the JIT execution engine, the optimiser pipeline and
 * the declaration of the host allocation primitive.
 *
 * @throws Error if LLVM fails to create an execution engine.
 */
LLVMEngine::LLVMEngine()
	: module(NULL)
	, engine(NULL)
	, builder(context)
	, alloc(NULL)
	, opt(NULL)
	, labelIndex(1)
{
	InitializeNativeTarget();
	module = new Module("resp", context);
	engine = EngineBuilder(module).create();
	if (!engine) {
		// Everything below dereferences `engine`; report instead of crashing.
		// NOTE(review): `module` is not freed on this path.
		throw Error(Cursor(), "failed to create LLVM execution engine");
	}
	opt = new FunctionPassManager(module);

	// Set up optimiser pipeline
	const TargetData* target = engine->getTargetData();
	opt->add(new TargetData(*target));          // Register target arch
	opt->add(createInstructionCombiningPass()); // Simple optimizations
	opt->add(createReassociatePass());          // Reassociate expressions
	opt->add(createGVNPass());                  // Eliminate Common Subexpressions
	opt->add(createCFGSimplificationPass());    // Simplify control flow

	// Declare host provided allocation primitive: i8* resp_gc_allocate(i32 size)
	std::vector<const Type*> argsT(1, Type::getInt32Ty(context)); // unsigned size
	FunctionType* funcT = FunctionType::get(PointerType::get(Type::getInt8Ty(context), 0), argsT, false);
	alloc = Function::Create(funcT, Function::ExternalLinkage,
	                         "resp_gc_allocate", module);
}

/** Release the optimiser and the execution engine.
 *
 * `opt` was constructed over `module`, and the module was handed to the
 * execution engine (which takes ownership of it). Destroy the pass manager
 * first so it never outlives the module it refers to.
 */
LLVMEngine::~LLVMEngine()
{
	delete opt;
	delete engine;
}

/** Map a Resp type to the corresponding LLVM type.
 *
 * @return the LLVM type, or NULL if @p t is NULL or (transitively) contains a
 *         type that cannot be lowered to a concrete LLVM type.
 * @throws Error for an unknown primitive type name.
 */
const Type*
LLVMEngine::llType(const AType* t)
{
	if (t == NULL) {
		return NULL;
	} else if (t->kind == AType::PRIM) {
		const string name = t->head()->str();
		if (name == "Nothing") return Type::getVoidTy(context);
		if (name == "Bool")    return Type::getInt1Ty(context);
		if (name == "Int")     return Type::getInt32Ty(context);
		if (name == "Float")   return Type::getFloatTy(context);
		if (name == "String")  return PointerType::get(Type::getInt8Ty(context), 0);
		if (name == "Quote")   return PointerType::get(Type::getInt8Ty(context), 0);
		throw Error(t->loc, string("Unknown primitive type `") + t->str() + "'");
	} else if (t->kind == AType::EXPR && t->head()->str() == "Fn") {
		// (Fn (params...) ret) lowers to a pointer to an LLVM function type
		AType::const_iterator i      = t->begin();
		const ATuple*         protT  = (*++i)->to_tuple();
		const AType*          retT   = (*++i)->as_type();
		const Type*           llRetT = llType(retT);
		if (!llRetT)
			return NULL;

		vector<const Type*> cprot;
		FOREACHP(ATuple::const_iterator, i, protT) {
			const Type* lt = llType((*i)->to_type());
			if (!lt)
				return NULL;
			cprot.push_back(lt);
		}

		return PointerType::get(FunctionType::get(llRetT, cprot, false), 0);
	} else if (t->kind == AType::EXPR
	           && isupper(static_cast<unsigned char>(t->head()->str()[0]))) {
		// Constructed type: pointer to { i8* RTTI, fields... }
		// (cast to unsigned char: isupper() is undefined for negative chars)
		vector<const Type*> ctypes;
		ctypes.push_back(PointerType::get(Type::getInt8Ty(context), 0)); // RTTI
		for (AType::const_iterator i = t->iter_at(1); i != t->end(); ++i) {
			const Type* lt = llType((*i)->to_type());
			if (!lt)
				return NULL;
			ctypes.push_back(lt);
		}

		return PointerType::get(StructType::get(context, ctypes, false), 0);
	} else if (t->kind == AType::NAME) {
		assert(false);
	}
	assert(false);
	return PointerType::get(Type::getInt8Ty(context), 0);
}

/** Convert a size in bits to bytes, rounding up as necessary.
 *
 * Computed as quotient plus a carry for any remainder, so it cannot overflow
 * even for @p bits close to SIZE_MAX (rounding up to a multiple of 8 before
 * dividing would wrap around to 0 there).
 */
static inline size_t
bitsToBytes(size_t bits)
{
	return bits / 8 + ((bits % 8 != 0) ? 1 : 0);
}

/** Emit a call to @p f with @p args.
 *
 * The first argument (named `closure` in earlier code; presumably the
 * callee's environment) is bitcast to the declared type of the first
 * parameter of @p funcT before the call.
 */
CVal
LLVMEngine::compileCall(CEnv& cenv, CFunc f, const AType* funcT, const vector<CVal>& args)
{
	// Unwrap each opaque CVal individually; reinterpreting the whole
	// vector<CVal> object as a vector<Value*> is undefined behaviour.
	vector<Value*> llArgs;
	llArgs.reserve(args.size());
	for (vector<CVal>::const_iterator a = args.begin(); a != args.end(); ++a)
		llArgs.push_back(llVal(*a));

	if (!llArgs.empty()) {
		llArgs[0] = builder.CreateBitCast(llArgs[0],
		                                  llType(funcT->prot()->head()->as_type()),
		                                  cenv.penv.gensymstr("you"));
	}

	return builder.CreateCall(llFunc(f), llArgs.begin(), llArgs.end());
}

/** Heap-allocate and initialise a value of constructed type @p type.
 *
 * The value is a pointer to { i8* RTTI, fields... }. Slot 0 receives @p rtti
 * (if non-NULL) and slots 1..n receive @p fields in order.
 *
 * @throws Error if @p type has no concrete LLVM representation.
 */
CVal
LLVMEngine::compileCons(CEnv& cenv, const AType* type, CVal rtti, const vector<CVal>& fields)
{
	assert(type->begin() != type->end());

	const Type* structPtrT = llType(type);
	THROW_IF(!structPtrT, Cursor(),
	         string("non-concrete constructed type :: ") + type->str());

	// Size must come from the struct's actual layout: summing field bit
	// widths ignores alignment padding (e.g. {i8*, i1, i32} needs 16 bytes,
	// not 13), which under-allocates and lets the stores below overflow.
	const Type*    structT = cast<PointerType>(structPtrT)->getElementType();
	const uint64_t size    = engine->getTargetData()->getTypeAllocSize(structT);

	// Allocate struct
	Value* structSize = ConstantInt::get(Type::getInt32Ty(context), size);
	Value* mem        = builder.CreateCall(alloc, structSize, "tupMem");
	Value* structPtr  = builder.CreateBitCast(mem, structPtrT, "tup");

	// Set struct fields
	if (rtti)
		builder.CreateStore(llVal(rtti), builder.CreateStructGEP(structPtr, 0, "rtti"));
	size_t i = 1;
	for (vector<CVal>::const_iterator f = fields.begin(); f != fields.end(); ++f, ++i) {
		builder.CreateStore(llVal(*f),
				builder.CreateStructGEP(structPtr, i, (format("tup%1%") % i).str().c_str()));
	}

	return structPtr;
}

/** Load field @p index of the struct pointed to by @p tup. */
CVal
LLVMEngine::compileDot(CEnv& cenv, CVal tup, int32_t index)
{
	Value* const tuple    = llVal(tup);
	Value* const fieldPtr = builder.CreateStructGEP(tuple, index, "dotPtr");
	return builder.CreateLoad(fieldPtr, false, "dotVal"); // non-volatile load
}

/** Lower a Bool, Float or Int32 literal to an LLVM constant.
 *
 * @throws Error for any other literal tag.
 */
CVal
LLVMEngine::compileLiteral(CEnv& cenv, const AST* lit)
{
	switch (lit->tag()) {
	case T_BOOL: {
		const ALiteral<bool>* boolLit = (const ALiteral<bool>*)lit;
		return ConstantInt::get(Type::getInt1Ty(context), boolLit->val);
	}
	case T_FLOAT: {
		const ALiteral<float>* floatLit = (const ALiteral<float>*)lit;
		return ConstantFP::get(Type::getFloatTy(context), floatLit->val);
	}
	case T_INT32: {
		const ALiteral<int32_t>* intLit = (const ALiteral<int32_t>*)lit;
		return ConstantInt::get(Type::getInt32Ty(context), intLit->val, true); // signed
	}
	default:
		break;
	}
	throw Error(lit->loc, "Unknown literal type");
}

/** Emit @p str as a global constant and return a pointer to its first char. */
CVal
LLVMEngine::compileString(CEnv& cenv, const char* str)
{
	Value* const strPtr = builder.CreateGlobalStringPtr(str);
	return strPtr;
}

/** Declare a new function and position the builder at its entry block.
 *
 * @param name  LLVM name for the function; a fresh gensym is used if empty.
 * @param args  Parameter symbols, used to name the LLVM arguments.
 * @param type  Resp function type supplying parameter and return types.
 * @throws Error if any parameter type or the return type is not concrete.
 */
CFunc
LLVMEngine::startFn(
	CEnv& cenv, const std::string& name, const ATuple* args, const AType* type)
{
	const AType* argsT = type->prot()->as_type();
	const AType* retT  = type->list_last()->as_type();

	const Function::LinkageTypes linkage = Function::ExternalLinkage;

	vector<const Type*> cprot;
	FOREACHP(ATuple::const_iterator, i, argsT) {
		const AType* at = (*i)->as_type();
		const Type*  lt = llType(at);
		THROW_IF(!lt, Cursor(), string("non-concrete parameter :: ")
		         + at->str());
		cprot.push_back(lt);
	}

	const Type* llRetT = llType(retT);
	THROW_IF(!llRetT, Cursor(),
	         (format("return has non-concrete type `%1%'") % retT->str()).str());

	const string llName = (name == "") ? cenv.penv.gensymstr("_fn") : name;

	FunctionType* fT = FunctionType::get(llRetT, cprot, false);
	Function*     f  = Function::Create(fT, linkage, llName, module);

	// Note f->getName() may be different from llName
	// however LLVM chooses to mangle is fine, we keep a pointer

	// Set argument names in generated code (bounded by both sequences so a
	// length mismatch cannot walk past the LLVM argument list)
	Function::arg_iterator a = f->arg_begin();
	for (ATuple::const_iterator i = args->begin();
	     i != args->end() && a != f->arg_end(); ++a, ++i)
		a->setName((*i)->as_symbol()->sym());

	BasicBlock* bb = BasicBlock::Create(context, "entry", f);
	builder.SetInsertPoint(bb);

	return f;
}

/** Open a new scope in @p cenv and bind each parameter of @p prot to the
 * corresponding LLVM argument of @p cfunc.
 *
 * @throws Error if a parameter's resolved type is not concrete.
 */
void
LLVMEngine::pushFnArgs(CEnv& cenv, const ATuple* prot, const AType* type, CFunc cfunc)
{
	cenv.push();

	const AType* paramTypes = type->prot()->as_type();
	Function*    fn         = llFunc(cfunc);

	assert(prot->size() == paramTypes->size());
	assert(prot->size() == fn->num_args());

	ATuple::const_iterator param  = prot->begin();
	ATuple::const_iterator paramT = paramTypes->begin();
	Function::arg_iterator arg    = fn->arg_begin();
	while (arg != fn->arg_end()) {
		const AType* resolved = cenv.resolveType((*paramT)->as_type());
		THROW_IF(!llType(resolved), (*param)->loc, "untyped parameter\n");
		cenv.def((*param)->as_symbol(), *param, resolved, &*arg);
		++arg;
		++param;
		++paramT;
	}
}

/** Terminate @p f by returning @p ret, verify it, and optimise it unless
 * "-g" is present in cenv.args.
 *
 * @throws Error if the verifier rejects the function (the module is dumped).
 */
void
LLVMEngine::finishFn(CEnv& cenv, CFunc f, CVal ret)
{
	Function* const fn = llFunc(f);

	builder.CreateRet(llVal(ret));

	// verifyFunction() returns true when the function is malformed
	const bool broken = verifyFunction(*fn, llvm::PrintMessageAction);
	if (broken) {
		module->dump();
		throw Error(Cursor(), "Broken module");
	}

	const bool skipOptimiser = (cenv.args.find("-g") != cenv.args.end());
	if (!skipOptimiser)
		opt->run(*fn);
}

/** Remove @p f from its module; a null handle is ignored. */
void
LLVMEngine::eraseFn(CEnv& cenv, CFunc f)
{
	if (!f)
		return;
	llFunc(f)->eraseFromParent();
}

/** Begin an if-expression: create its (not yet inserted) merge block and
 * remember the function currently being emitted into.
 *
 * @return an opaque state object owned by the engine; compileIfEnd frees it.
 */
IfState
LLVMEngine::compileIfStart(CEnv& cenv)
{
	// static_cast is the correct base->derived conversion; the local is not
	// named `engine` so it does not shadow the ExecutionEngine member.
	LLVMEngine* const eng     = static_cast<LLVMEngine*>(cenv.engine());
	BasicBlock* const mergeBB = BasicBlock::Create(context, "endif");
	return new LLVMIfState(mergeBB, eng->builder.GetInsertBlock()->getParent());
}

/** Emit one `cond -> then` arm of an if-expression.
 *
 * Branches on @p condV to a fresh then-block, compiles @p then there, jumps to
 * the merge block and records the arm's value for the merge phi. The builder
 * is left at the start of the else-block, where the next arm (or the final
 * else value) is emitted.
 */
void
LLVMEngine::compileIfBranch(CEnv& cenv, IfState s, CVal condV, const AST* then)
{
	LLVMIfState* state  = static_cast<LLVMIfState*>(s);
	LLVMEngine*  eng    = static_cast<LLVMEngine*>(cenv.engine());
	BasicBlock*  thenBB = BasicBlock::Create(context, (format("then%1%") % labelIndex).str());
	BasicBlock*  nextBB = BasicBlock::Create(context, (format("else%1%") % labelIndex).str());

	++labelIndex;

	eng->builder.CreateCondBr(llVal(condV), thenBB, nextBB);

	// Emit then block for this condition
	appendBlock(eng, state->parent, thenBB);
	Value* thenV = llVal(resp_compile(cenv, then));

	// Compiling `then` may itself emit blocks (e.g. a nested if), so the block
	// that actually branches to the merge block is the current insert block,
	// not necessarily thenBB. The phi must name the real predecessor.
	BasicBlock* thenEndBB = eng->builder.GetInsertBlock();
	eng->builder.CreateBr(state->mergeBB);
	state->branches.push_back(make_pair(thenV, thenEndBB));

	appendBlock(eng, state->parent, nextBB);
}

/** Finish an if-expression: close the final else arm with @p elseV (or a
 * null constant of @p type if absent) and emit the merge phi.
 *
 * Frees the state object created by compileIfStart; @p s must not be used
 * afterwards.
 *
 * @return the phi node holding the if-expression's value.
 */
CVal
LLVMEngine::compileIfEnd(CEnv& cenv, IfState s, CVal elseV, const AType* type)
{
	LLVMIfState* state = static_cast<LLVMIfState*>(s);
	LLVMEngine*  eng   = static_cast<LLVMEngine*>(cenv.engine());

	if (!elseV)
		elseV = Constant::getNullValue(llType(type));

	// Emit end of final else block
	eng->builder.CreateBr(state->mergeBB);
	state->branches.push_back(make_pair(llVal(elseV), eng->builder.GetInsertBlock()));

	// Emit merge block (Phi node)
	appendBlock(eng, state->parent, state->mergeBB);
	PHINode* pn = eng->builder.CreatePHI(llType(type), "ifval");
	FOREACH(IfBranches::iterator, i, state->branches)
		pn->addIncoming(i->first, i->second);

	// IfState is opaque to callers, so only the engine can release it.
	delete state;
	return pn;
}

/** Compare @p rtti against the compiled RTTI value of type @p tag.
 *
 * @return an i1 that is true when the two values are equal (ICMP_EQ).
 */
CVal
LLVMEngine::compileIsA(CEnv& cenv, CVal rtti, const ASymbol* tag)
{
	LLVMEngine*  eng   = static_cast<LLVMEngine*>(cenv.engine());
	const AType* patT  = new AType(tag, 0, Cursor());
	Value*       typeV = llVal(resp_compile(cenv, patT));

	return eng->builder.CreateICmp(CmpInst::ICMP_EQ, llVal(rtti), typeV);
}

/** Compile a primitive arithmetic, bitwise or comparison application.
 *
 * Whether the operation is floating point is decided by the type of the
 * first operand. Arithmetic/bitwise operators fold left over all operands;
 * comparisons use only the first two.
 *
 * @throws Error for an unknown primitive name.
 */
CVal
LLVMEngine::compilePrimitive(CEnv& cenv, const ATuple* prim)
{
	ATuple::const_iterator i = prim->begin();

	LLVMEngine*  eng     = static_cast<LLVMEngine*>(cenv.engine());
	bool         isFloat = cenv.type(*++i)->str() == "Float";
	Value*       a       = llVal(resp_compile(cenv, *i++));
	Value*       b       = llVal(resp_compile(cenv, *i++));
	const string n       = prim->head()->to_symbol()->str();

	// Binary arithmetic operations. Floating-point add/sub/mul have their
	// own opcodes (FAdd/FSub/FMul); the integer opcodes are invalid on floats.
	Instruction::BinaryOps op = (Instruction::BinaryOps)0;
	if (n == "+")   op = isFloat ? Instruction::FAdd : Instruction::Add;
	if (n == "-")   op = isFloat ? Instruction::FSub : Instruction::Sub;
	if (n == "*")   op = isFloat ? Instruction::FMul : Instruction::Mul;
	if (n == "and") op = Instruction::And;
	if (n == "or")  op = Instruction::Or;
	if (n == "xor") op = Instruction::Xor;
	if (n == "/")   op = isFloat ? Instruction::FDiv : Instruction::SDiv;
	if (n == "%")   op = isFloat ? Instruction::FRem : Instruction::SRem;
	if (op != 0) {
		Value* val = eng->builder.CreateBinOp(op, a, b);
		while (i != prim->end())
			val = eng->builder.CreateBinOp(op, val, llVal(resp_compile(cenv, *i++)));
		return val;
	}

	// Comparison operations (none of the chosen predicates is 0 / FCMP_FALSE)
	CmpInst::Predicate pred = (CmpInst::Predicate)0;
	if (n == "=")  pred = isFloat ? CmpInst::FCMP_OEQ : CmpInst::ICMP_EQ ;
	if (n == "!=") pred = isFloat ? CmpInst::FCMP_ONE : CmpInst::ICMP_NE ;
	if (n == ">")  pred = isFloat ? CmpInst::FCMP_OGT : CmpInst::ICMP_SGT;
	if (n == ">=") pred = isFloat ? CmpInst::FCMP_OGE : CmpInst::ICMP_SGE;
	if (n == "<")  pred = isFloat ? CmpInst::FCMP_OLT : CmpInst::ICMP_SLT;
	if (n == "<=") pred = isFloat ? CmpInst::FCMP_OLE : CmpInst::ICMP_SLE;
	if (pred != 0) {
		if (isFloat)
			return eng->builder.CreateFCmp(pred, a, b);
		else
			return eng->builder.CreateICmp(pred, a, b);
	}

	throw Error(prim->loc, "unknown primitive");
}

/** Define global @p sym of type @p type (zero-initialised) and emit a store
 * of @p val into it.
 *
 * @return the GlobalVariable.
 * @throws Error if @p type has no concrete LLVM representation.
 */
CVal
LLVMEngine::compileGlobalSet(CEnv& cenv, const string& sym, CVal val, const AType* type)
{
	LLVMEngine* eng = static_cast<LLVMEngine*>(cenv.engine());
	const Type* lt  = llType(type);
	THROW_IF(!lt, Cursor(), string("global `") + sym + "' has non-concrete type");

	GlobalVariable* global = new GlobalVariable(*module, lt, false,
			GlobalValue::ExternalLinkage, Constant::getNullValue(lt), sym);

	Value* valPtr = builder.CreateBitCast(llVal(val), lt, "globalPtr");

	eng->builder.CreateStore(valPtr, global);
	return global;
}

/** Emit a load from the global pointed to by @p val (named after @p sym). */
CVal
LLVMEngine::compileGlobalGet(CEnv& cenv, const string& sym, CVal val)
{
	LLVMEngine* eng = static_cast<LLVMEngine*>(cenv.engine());
	return eng->builder.CreateLoad(llVal(val), sym + "Ptr");
}

/** Print the whole module as textual LLVM IR to @p os. */
void
LLVMEngine::writeModule(CEnv& cenv, std::ostream& os)
{
	llvm::raw_os_ostream     out(os);
	AssemblyAnnotationWriter annotator;
	module->print(out, &annotator);
}

/** JIT-compile and run the nullary function @p f, rendering its result.
 *
 * Int and Float results are printed, Float with a forced decimal point. Bool
 * results print as #t / #f. String results are quoted, with `"` and `\`
 * backslash-escaped. Other non-void results print as a raw pointer. A void
 * result yields an empty string.
 *
 * @throws Error if no function pointer is available or @p retT is not concrete.
 */
const string
LLVMEngine::call(CEnv& cenv, CFunc f, const AType* retT)
{
	typedef int32_t (*IntFn)();
	typedef float   (*FloatFn)();
	typedef bool    (*BoolFn)();
	typedef char*   (*StrFn)();
	typedef void*   (*PtrFn)();
	typedef void    (*VoidFn)();

	void* const       fnPtr  = engine->getPointerToFunction(llFunc(f));
	const Type* const llRetT = llType(retT);
	THROW_IF(!fnPtr,  Cursor(), "unable to get function pointer");
	THROW_IF(!llRetT, Cursor(), "function with non-concrete return type called");

	std::ostringstream out;
	if (llRetT == Type::getInt32Ty(context)) {
		out << ((IntFn)fnPtr)();
	} else if (llRetT == Type::getFloatTy(context)) {
		out << showpoint << ((FloatFn)fnPtr)();
	} else if (llRetT == Type::getInt1Ty(context)) {
		const bool result = ((BoolFn)fnPtr)();
		out << (result ? "#t" : "#f");
	} else if (retT->head()->str() == "String") {
		const std::string str(((StrFn)fnPtr)());
		out << '"';
		for (std::string::size_type k = 0; k < str.size(); ++k) {
			const char c = str[k];
			if (c == '"' || c == '\\')
				out << '\\';
			out << c;
		}
		out << '"';
	} else if (llRetT != Type::getVoidTy(context)) {
		out << ((PtrFn)fnPtr)();
	} else {
		((VoidFn)fnPtr)();
	}
	return out.str();
}

/** Factory for the LLVM backend; the caller owns the returned engine. */
Engine*
resp_new_llvm_engine()
{
	LLVMEngine* const llvmEngine = new LLVMEngine();
	return llvmEngine;
}