/* Resp: A programming language
 * Copyright (C) 2008-2009 David Robillard <http://drobilla.net>
 *
 * Resp is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Resp is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Resp.  If not, see <http://www.gnu.org/licenses/>.
 */

/** @file
 * @brief Compile to LLVM IR and/or execute via JIT
 */

#define __STDC_LIMIT_MACROS    1
#define __STDC_CONSTANT_MACROS 1

#include <map>
#include <sstream>
#include <string>
#include <vector>

#include <boost/format.hpp>

#include "llvm/Value.h"
#include "llvm/Analysis/Verifier.h"
#include "llvm/Assembly/AsmAnnotationWriter.h"
#include "llvm/DerivedTypes.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/JIT.h"
#include "llvm/Instructions.h"
#include "llvm/LLVMContext.h"
#include "llvm/Module.h"
#include "llvm/PassManager.h"
#include "llvm/Support/raw_os_ostream.h"
#include "llvm/Support/IRBuilder.h"
#include "llvm/Target/TargetData.h"
#include "llvm/Target/TargetSelect.h"
#include "llvm/Transforms/Scalar.h"

#include "resp.hpp"

using namespace llvm;
using namespace std;
using boost::format;

/** LLVM Engine (Compiler and JIT).
 *
 * Lowers Resp code to LLVM IR in a single module and executes it with the
 * LLVM JIT.  The engine holds raw owning pointers which are released in the
 * destructor, so it must never be copied.
 */
struct LLVMEngine : public Engine {
	LLVMEngine();
	virtual ~LLVMEngine();

	CFunc startFn(CEnv& cenv, const string& name, const ATuple* args, const ATuple* type);
	void  pushFnArgs(CEnv& cenv, const ATuple* prot, const ATuple* type, CFunc f);
	void  finishFn(CEnv& cenv, CFunc f, CVal ret);
	void  eraseFn(CEnv& cenv, CFunc f);

	CVal compileCall(CEnv& cenv, CFunc f, const ATuple* funcT, const vector<CVal>& args);
	CVal compileCons(CEnv& cenv, const ATuple* type, CVal rtti, const vector<CVal>& fields);
	CVal compileDot(CEnv& cenv, CVal tup, int32_t index);
	CVal compileGlobalSet(CEnv& cenv, const string& s, CVal v, const AST* t);
	CVal compileGlobalGet(CEnv& cenv, const string& s, CVal v);
	CVal compileIf(CEnv& cenv, const AST* cond, const AST* then, const AST* aelse);
	CVal compileLiteral(CEnv& cenv, const AST* lit);
	CVal compilePrimitive(CEnv& cenv, const ATuple* prim);
	CVal compileString(CEnv& cenv, const char* str);

	void writeModule(CEnv& cenv, std::ostream& os);

	const string call(CEnv& cenv, CFunc f, const AST* retT);

private:
	// Non-copyable: a copy would share, and later double-delete, `engine`
	// and `opt`.  Declared but intentionally never defined.
	LLVMEngine(const LLVMEngine&);
	LLVMEngine& operator=(const LLVMEngine&);

	/** Append @p block to @p function and point @p engine's builder at it */
	void appendBlock(LLVMEngine* engine, Function* function, BasicBlock* block) {
		function->getBasicBlockList().push_back(block);
		engine->builder.SetInsertPoint(block);
	}

	/** View an opaque backend value as an LLVM value */
	inline Value*    llVal(CVal v)   { return static_cast<Value*>(v); }
	/** View an opaque backend function as an LLVM function */
	inline Function* llFunc(CFunc f) { return static_cast<Function*>(f); }
	/** Lower a Resp type to an LLVM type, or NULL if it is not concrete */
	const Type*      llType(const AST* t);

	LLVMContext          context;    ///< Context all IR is created in
	Module*              module;     ///< Handed to (and owned by) `engine`
	ExecutionEngine*     engine;     ///< JIT; deleted in the destructor
	IRBuilder<>          builder;    ///< Current IR insertion point
	Function*            alloc;      ///< Declaration of resp_gc_allocate
	FunctionPassManager* opt;        ///< Per-function optimisation pipeline

	unsigned labelIndex;             ///< Suffix for unique then/else labels
};

/** Create the module, JIT, optimiser pipeline and allocator declaration.
 * @throw Error if the LLVM execution engine cannot be created.
 */
LLVMEngine::LLVMEngine()
	: builder(context)
	, labelIndex(1)
{
	InitializeNativeTarget();
	module = new Module("resp", context);

	string err;
	engine = EngineBuilder(module).setErrorStr(&err).create();
	if (!engine) {
		// No engine adopted the module, so it is still ours to free
		delete module;
		throw Error(Cursor(), string("failed to create LLVM execution engine: ") + err);
	}

	opt = new FunctionPassManager(module);

	// Set up optimiser pipeline
	const TargetData* target = engine->getTargetData();
	opt->add(new TargetData(*target));          // Register target arch
	opt->add(createInstructionCombiningPass()); // Simple optimizations
	opt->add(createReassociatePass());          // Reassociate expressions
	opt->add(createGVNPass());                  // Eliminate Common Subexpressions
	opt->add(createCFGSimplificationPass());    // Simplify control flow
	opt->doInitialization();                    // Must precede any run()

	// Declare host provided allocation primitive: i8* resp_gc_allocate(i32 size)
	std::vector<const Type*> argsT(1, Type::getInt32Ty(context)); // unsigned size
	FunctionType* funcT = FunctionType::get(PointerType::get(Type::getInt8Ty(context), 0), argsT, false);
	alloc = Function::Create(funcT, Function::ExternalLinkage,
	                         "resp_gc_allocate", module);
}

/** Release the optimiser and the JIT (which also deletes the module). */
LLVMEngine::~LLVMEngine()
{
	// The pass manager was built over `module`, which the execution engine
	// owns; release the pass manager before the engine tears the module down.
	delete opt;
	delete engine;
}

/** Lower a Resp type expression to the LLVM type used to represent it.
 *
 * Type variables and String/Symbol become i8*, Fn types become function
 * pointers, and upper-case constructor types become pointers to
 * { i8* rtti, fields... }.
 *
 * @return the LLVM type, or NULL if @p t is NULL or not fully concrete.
 */
const Type*
LLVMEngine::llType(const AST* t)
{
	if (t == NULL) {
		return NULL;
	} else if (AType::is_var(t)) {
		// Kludge for _me closure parameter, will be cast
		return PointerType::get(Type::getInt8Ty(context), 0);
	} else if (AType::is_name(t)) {
		const std::string sym(t->as_symbol()->sym());
		if (sym == "Nothing") return Type::getVoidTy(context);
		if (sym == "Bool")    return Type::getInt1Ty(context);
		if (sym == "Int")     return Type::getInt32Ty(context);
		if (sym == "Float")   return Type::getFloatTy(context);
		if (sym == "String")  return PointerType::get(Type::getInt8Ty(context), 0);
		if (sym == "Symbol")  return PointerType::get(Type::getInt8Ty(context), 0);
	} else if (is_form(t, "Fn")) {
		// (Fn (params...) ret)
		ATuple::const_iterator i     = t->as_tuple()->begin();
		const ATuple*          protT = (*++i)->to_tuple();
		const Type*            retLT = llType(*++i);
		if (!retLT)
			return NULL;

		vector<const Type*> cprot;
		FOREACHP(ATuple::const_iterator, p, protT) {
			const Type* lt = llType(*p);
			if (!lt)
				return NULL;
			cprot.push_back(lt);
		}

		return PointerType::get(FunctionType::get(retLT, cprot, false), 0);
	} else if (AType::is_expr(t)
	           && isupper(static_cast<unsigned char>(t->as_tuple()->fst()->str()[0]))) {
		// isupper() requires a value representable as unsigned char
		vector<const Type*> ctypes;
		ctypes.push_back(PointerType::get(Type::getInt8Ty(context), 0)); // RTTI
		for (ATuple::const_iterator i = t->as_tuple()->iter_at(1); i != t->as_tuple()->end(); ++i) {
			const Type* lt = llType(*i);
			if (!lt)
				return NULL;
			ctypes.push_back(lt);
		}

		return PointerType::get(StructType::get(context, ctypes, false), 0);
	}
	return NULL;
}

/** Return the number of whole bytes needed to hold @p bits bits
 * (i.e. bits / 8, rounded up when bits is not a multiple of 8).
 */
static inline size_t
bitsToBytes(size_t bits)
{
	const size_t wholeBytes = bits / 8;
	const size_t partial    = (bits % 8 != 0) ? 1 : 0;
	return wholeBytes + partial;
}

/** Emit a call to @p f with @p args.
 *
 * The first argument is the closure; it is bitcast to the type of the
 * callee's first parameter (per @p funcT) before the call.
 */
CVal
LLVMEngine::compileCall(CEnv& cenv, CFunc f, const ATuple* funcT, const vector<CVal>& args)
{
	// Convert opaque values element by element: reinterpreting a
	// vector<CVal> as a vector<Value*> is undefined behaviour.
	vector<Value*> llArgs;
	llArgs.reserve(args.size());
	for (vector<CVal>::const_iterator a = args.begin(); a != args.end(); ++a)
		llArgs.push_back(llVal(*a));

	assert(!llArgs.empty()); // closure argument is always present
	llArgs[0] = builder.CreateBitCast(llArgs[0],
	                                  llType(funcT->prot()->fst()),
	                                  cenv.penv.gensymstr("you"));

	return builder.CreateCall(llFunc(f), llArgs.begin(), llArgs.end());
}

/** Allocate and initialise a constructed value of type @p type.
 *
 * Memory comes from resp_gc_allocate; field 0 receives @p rtti (when
 * non-NULL) and fields 1..n receive @p fields in order.
 *
 * @return pointer to the new struct.
 * @throw Error if @p type does not lower to a concrete struct pointer.
 */
CVal
LLVMEngine::compileCons(CEnv& cenv, const ATuple* type, CVal rtti, const vector<CVal>& fields)
{
	assert(type->begin() != type->end());

	// Lower the constructor type to a pointer to { i8* rtti, fields... }
	const Type* ptrT = llType(type);
	THROW_IF(!ptrT || !isa<PointerType>(ptrT), Cursor(),
	         string("non-concrete constructor type :: ") + type->str());
	const Type* structT = cast<PointerType>(ptrT)->getElementType();

	// Allocate the struct's full allocation size.  Summing field bit sizes
	// ignores alignment padding (e.g. { i8*, i1, i32 } needs 16 bytes on a
	// 64-bit target, not 13) and would let the stores below overflow.
	const uint64_t size       = engine->getTargetData()->getTypeAllocSize(structT);
	Value*         structSize = ConstantInt::get(Type::getInt32Ty(context), size);
	Value*         mem        = builder.CreateCall(alloc, structSize, "tupMem");
	Value*         structPtr  = builder.CreateBitCast(mem, ptrT, "tup");

	// Set struct fields
	if (rtti)
		builder.CreateStore(llVal(rtti), builder.CreateStructGEP(structPtr, 0, "rtti"));
	size_t i = 1;
	for (vector<CVal>::const_iterator f = fields.begin(); f != fields.end(); ++f, ++i) {
		builder.CreateStore(llVal(*f),
				builder.CreateStructGEP(structPtr, i, (format("tup%1%") % i).str().c_str()));
	}

	return structPtr;
}

/** Load field @p index of the struct pointed to by @p tup. */
CVal
LLVMEngine::compileDot(CEnv& cenv, CVal tup, int32_t index)
{
	Value* const fieldAddr = builder.CreateStructGEP(llVal(tup), index, "dotPtr");
	Value* const fieldVal  = builder.CreateLoad(fieldAddr, false, "dotVal");
	return fieldVal;
}

/** Emit an LLVM constant for a Bool, Float or Int literal.
 * @throw Error for any other literal tag.
 */
CVal
LLVMEngine::compileLiteral(CEnv& cenv, const AST* lit)
{
	switch (lit->tag()) {
	case T_BOOL: {
		const bool value = static_cast<const ALiteral<bool>*>(lit)->val;
		return ConstantInt::get(Type::getInt1Ty(context), value);
	}
	case T_FLOAT: {
		const float value = static_cast<const ALiteral<float>*>(lit)->val;
		return ConstantFP::get(Type::getFloatTy(context), value);
	}
	case T_INT32: {
		const int32_t value = static_cast<const ALiteral<int32_t>*>(lit)->val;
		return ConstantInt::get(Type::getInt32Ty(context), value, true);
	}
	default:
		throw Error(lit->loc, "Unknown literal type");
	}
}

/** Emit @p str as a global string constant and return a pointer to it. */
CVal
LLVMEngine::compileString(CEnv& cenv, const char* str)
{
	Value* const strPtr = builder.CreateGlobalStringPtr(str);
	return strPtr;
}

/** Declare a new function of type @p type and position the builder in its
 * entry block.
 *
 * @param name  Function name; a fresh "_fn" gensym is used when empty.
 * @param args  Parameter symbols, used to name the LLVM arguments.
 * @throw Error if a parameter or the return type is not concrete.
 */
CFunc
LLVMEngine::startFn(
	CEnv& cenv, const std::string& name, const ATuple* args, const ATuple* type)
{
	const ATuple* argsT = type->prot();
	const AST*    retT  = type->list_last();

	Function::LinkageTypes linkage = Function::ExternalLinkage;

	// Lower every parameter type once; each must be concrete
	vector<const Type*> cprot;
	FOREACHP(ATuple::const_iterator, i, argsT) {
		const Type* lt = llType(*i);
		THROW_IF(!lt, Cursor(), string("non-concrete parameter :: ")
		         + (*i)->str());
		cprot.push_back(lt);
	}

	const Type* retLT = llType(retT);
	THROW_IF(!retLT, Cursor(),
	         (format("return has non-concrete type `%1%'") % retT->str()).str());

	const string llName = (name == "") ? cenv.penv.gensymstr("_fn") : name;

	FunctionType* fT = FunctionType::get(retLT, cprot, false);
	Function*     f  = Function::Create(fT, linkage, llName, module);

	// Note f->getName() may be different from llName
	// however LLVM chooses to mangle is fine, we keep a pointer

	// Set argument names in generated code, never stepping past either list
	Function::arg_iterator a = f->arg_begin();
	for (ATuple::const_iterator i = args->begin();
	     i != args->end() && a != f->arg_end(); ++a, ++i)
		a->setName((*i)->as_symbol()->sym());

	BasicBlock* bb = BasicBlock::Create(context, "entry", f);
	builder.SetInsertPoint(bb);

	return f;
}

/** Open a new scope in @p cenv and bind each parameter in @p prot to its
 * resolved type (from @p type) and the matching LLVM argument of @p cfunc.
 * @throw Error if a parameter type does not lower to an LLVM type.
 */
void
LLVMEngine::pushFnArgs(CEnv& cenv, const ATuple* prot, const ATuple* type, CFunc cfunc)
{
	cenv.push();

	Function* const     fn     = llFunc(cfunc);
	const ATuple* const paramT = type->prot();

	assert(prot->size() == paramT->size());
	assert(prot->size() == fn->num_args());

	// Walk names, declared types and LLVM arguments in lockstep
	ATuple::const_iterator name  = prot->begin();
	ATuple::const_iterator nameT = paramT->begin();
	Function::arg_iterator arg   = fn->arg_begin();
	for (; arg != fn->arg_end(); ++arg, ++name, ++nameT) {
		const AST* const resolved = cenv.resolveType(*nameT);
		THROW_IF(!llType(resolved), (*name)->loc, "untyped parameter\n");
		cenv.def((*name)->as_symbol(), *name, resolved, &*arg);
	}
}

/** Emit the return of @p ret, verify @p f, and optimise it unless "-g" was
 * given.
 * @throw Error (after dumping the module) if verification fails.
 */
void
LLVMEngine::finishFn(CEnv& cenv, CFunc f, CVal ret)
{
	Function* const fn = llFunc(f);
	builder.CreateRet(llVal(ret));

	const bool broken = verifyFunction(*fn, llvm::PrintMessageAction);
	if (broken) {
		module->dump();
		throw Error(Cursor(), "Broken module");
	}

	const bool debugBuild = (cenv.args.find("-g") != cenv.args.end());
	if (!debugBuild)
		opt->run(*fn);
}

/** Remove @p f from the module and delete it; a NULL @p f is ignored. */
void
LLVMEngine::eraseFn(CEnv& cenv, CFunc f)
{
	if (!f)
		return;
	llFunc(f)->eraseFromParent();
}

/** Compile (if cond then else) into a diamond of basic blocks joined by a
 * phi node.  When @p aelse is NULL the else value is the null constant of
 * the then-branch's type.
 *
 * @return the phi node holding the selected value.
 */
CVal
LLVMEngine::compileIf(CEnv& cenv, const AST* cond, const AST* then, const AST* aelse)
{
	LLVMEngine* engine  = static_cast<LLVMEngine*>(cenv.engine());
	BasicBlock* mergeBB = BasicBlock::Create(context, "endif");
	Function*   parent  = engine->builder.GetInsertBlock()->getParent();
	BasicBlock* thenBB  = BasicBlock::Create(context, (format("then%1%") % labelIndex).str());
	BasicBlock* nextBB  = BasicBlock::Create(context, (format("else%1%") % labelIndex).str());

	const AST* type = cenv.type(then);

	++labelIndex;

	engine->builder.CreateCondBr(llVal(resp_compile(cenv, cond)), thenBB, nextBB);

	// Emit then block for this condition.  Compiling the branch may append
	// further blocks (e.g. a nested if), so the phi's predecessor is the
	// block we finish in, which is not necessarily thenBB.
	appendBlock(engine, parent, thenBB);
	Value* thenV = llVal(resp_compile(cenv, then));
	engine->builder.CreateBr(mergeBB);
	BasicBlock* thenEndBB = engine->builder.GetInsertBlock();

	appendBlock(engine, parent, nextBB);

	Value* elseV = NULL;
	if (aelse)
		elseV = llVal(resp_compile(cenv, aelse));
	else
		elseV = Constant::getNullValue(llType(type));

	// Emit end of final else block
	engine->builder.CreateBr(mergeBB);
	BasicBlock* elseEndBB = engine->builder.GetInsertBlock();

	// Emit merge block (Phi node)
	appendBlock(engine, parent, mergeBB);
	PHINode* pn = engine->builder.CreatePHI(llType(type), "ifval");
	pn->addIncoming(thenV, thenEndBB);
	pn->addIncoming(elseV, elseEndBB);

	return pn;
}

/** Compile a primitive operation (arithmetic, bitwise or comparison).
 *
 * Arithmetic and bitwise ops are folded left over all operands;
 * comparisons use the first two operands.  Float or integer instructions
 * are selected from the type of the first operand.
 *
 * @throw Error if the operator is not a known primitive.
 */
CVal
LLVMEngine::compilePrimitive(CEnv& cenv, const ATuple* prim)
{
	ATuple::const_iterator i = prim->iter_at(1);

	LLVMEngine*  engine = static_cast<LLVMEngine*>(cenv.engine());
	Value*       a      = llVal(resp_compile(cenv, *i++));
	Value*       b      = llVal(resp_compile(cenv, *i++));
	const string n      = prim->fst()->to_symbol()->str();

	// Decide from the operand type, not the result type: a comparison of
	// Floats has result type Bool but still needs fcmp.
	const bool isFloat = (a->getType() == Type::getFloatTy(context));

	// Binary arithmetic operations
	Instruction::BinaryOps op     = Instruction::Add;
	bool                   isBinOp = true;
	if      (n == "+")   op = isFloat ? Instruction::FAdd : Instruction::Add;
	else if (n == "-")   op = isFloat ? Instruction::FSub : Instruction::Sub;
	else if (n == "*")   op = isFloat ? Instruction::FMul : Instruction::Mul;
	else if (n == "/")   op = isFloat ? Instruction::FDiv : Instruction::SDiv;
	else if (n == "%")   op = isFloat ? Instruction::FRem : Instruction::SRem;
	else if (n == "and") op = Instruction::And;
	else if (n == "or")  op = Instruction::Or;
	else if (n == "xor") op = Instruction::Xor;
	else                 isBinOp = false;

	if (isBinOp) {
		Value* val = engine->builder.CreateBinOp(op, a, b);
		while (i != prim->end())
			val = engine->builder.CreateBinOp(op, val, llVal(resp_compile(cenv, *i++)));
		return val;
	}

	// Comparison operations
	CmpInst::Predicate pred  = CmpInst::ICMP_EQ;
	bool               isCmp = true;
	if      (n == "=")  pred = isFloat ? CmpInst::FCMP_OEQ : CmpInst::ICMP_EQ;
	else if (n == "!=") pred = isFloat ? CmpInst::FCMP_ONE : CmpInst::ICMP_NE;
	else if (n == ">")  pred = isFloat ? CmpInst::FCMP_OGT : CmpInst::ICMP_SGT;
	else if (n == ">=") pred = isFloat ? CmpInst::FCMP_OGE : CmpInst::ICMP_SGE;
	else if (n == "<")  pred = isFloat ? CmpInst::FCMP_OLT : CmpInst::ICMP_SLT;
	else if (n == "<=") pred = isFloat ? CmpInst::FCMP_OLE : CmpInst::ICMP_SLE;
	else                isCmp = false;

	if (isCmp) {
		if (isFloat)
			return engine->builder.CreateFCmp(pred, a, b);
		else
			return engine->builder.CreateICmp(pred, a, b);
	}

	throw Error(prim->loc, "unknown primitive");
}

/** Define a global named @p sym of type @p type (null-initialised), and
 * emit a store of @p val into it.
 * @return the new global variable.
 */
CVal
LLVMEngine::compileGlobalSet(CEnv& cenv, const string& sym, CVal val, const AST* type)
{
	LLVMEngine* const owner = reinterpret_cast<LLVMEngine*>(cenv.engine());
	const Type* const ty    = llType(type);

	GlobalVariable* const global = new GlobalVariable(
		*module, ty, false, GlobalValue::ExternalLinkage,
		Constant::getNullValue(ty), sym);

	Value* const stored = builder.CreateBitCast(llVal(val), ty, "globalPtr");
	owner->builder.CreateStore(stored, global);

	return global;
}

/** Emit a load from the global variable @p val (named @p sym). */
CVal
LLVMEngine::compileGlobalGet(CEnv& cenv, const string& sym, CVal val)
{
	LLVMEngine* const owner    = reinterpret_cast<LLVMEngine*>(cenv.engine());
	Value* const      slot     = llVal(val);
	const string      loadName = sym + "Ptr";
	return owner->builder.CreateLoad(slot, loadName);
}

/** Print the whole module as textual LLVM IR to @p os. */
void
LLVMEngine::writeModule(CEnv& cenv, std::ostream& os)
{
	AssemblyAnnotationWriter annotator;
	llvm::raw_os_ostream     out(os);
	module->print(out, &annotator);
}

/** JIT-compile and run the nullary function @p f, returning its result
 * rendered as Resp source text (Bools as #t/#f, Strings quoted with " and
 * \ escaped, Symbols bare, other pointers as addresses, Nothing as "").
 *
 * @throw Error if no function pointer is available or @p retT is not
 * concrete.
 */
const string
LLVMEngine::call(CEnv& cenv, CFunc f, const AST* retT)
{
	typedef int32_t (*IntFn)();
	typedef float   (*FloatFn)();
	typedef bool    (*BoolFn)();
	typedef char*   (*StrFn)();
	typedef void*   (*PtrFn)();
	typedef void    (*VoidFn)();

	void* const       fnPtr = engine->getPointerToFunction(llFunc(f));
	const Type* const retLT = llType(retT);
	THROW_IF(!fnPtr, Cursor(), "unable to get function pointer");
	THROW_IF(!retLT, Cursor(), "function with non-concrete return type called");

	std::stringstream out;
	if (retLT == Type::getInt32Ty(context)) {
		out << reinterpret_cast<IntFn>(fnPtr)();
	} else if (retLT == Type::getFloatTy(context)) {
		out << showpoint << reinterpret_cast<FloatFn>(fnPtr)();
	} else if (retLT == Type::getInt1Ty(context)) {
		const bool result = reinterpret_cast<BoolFn>(fnPtr)();
		out << (result ? "#t" : "#f");
	} else if (retT->str() == "String") {
		const std::string text(reinterpret_cast<StrFn>(fnPtr)());
		out << '"';
		for (std::string::size_type k = 0; k < text.size(); ++k) {
			const char c = text[k];
			if (c == '"' || c == '\\')
				out << '\\';  // escape quote and backslash
			out << c;
		}
		out << '"';
	} else if (retT->str() == "Symbol") {
		const std::string text(reinterpret_cast<StrFn>(fnPtr)());
		out << text;
	} else if (retLT != Type::getVoidTy(context)) {
		out << reinterpret_cast<PtrFn>(fnPtr)();
	} else {
		reinterpret_cast<VoidFn>(fnPtr)();
	}
	return out.str();
}

/** Factory for the LLVM backend; the caller owns the returned engine. */
Engine*
resp_new_llvm_engine()
{
	LLVMEngine* const llvmEngine = new LLVMEngine();
	return llvmEngine;
}