/* Resp: A programming language
 * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net>
 *
 * Resp is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Resp is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Resp.  If not, see <http://www.gnu.org/licenses/>.
 */

/** @file
 * @brief Compile to C
 */

#include <map>
#include <sstream>
#include <boost/format.hpp>
#include "resp.hpp"

using namespace std;
using boost::format;

// In this backend both C types and compiled values are represented as plain
// strings of C source text (e.g. "int", "x3", "(a + b)").
typedef std::string Type;   ///< Spelling of a C type
typedef std::string Value;  ///< A C expression or variable name

/** A C function being generated by this backend. */
struct Function {
	std::string returnType;  ///< C spelling of the return type
	std::string name;        ///< Function name as emitted
	std::string text;        ///< Emitted header text (signature and opening brace)
};

/// View a compiled value handle as the C expression string it points at
/// (assumes the handle was produced by this backend, i.e. is a Value*).
static inline Value*
llVal(CVal v)
{
	return static_cast<Value*>(v);
}

/// View a compiled function handle as the Function record it points at
/// (assumes the handle was produced by this backend's startFunction).
static inline Function*
llFunc(CFunc f)
{
	return static_cast<Function*>(f);
}

/** Map a Resp type to the spelling of the corresponding C type.
 *
 * Returns a newly allocated string owned by the caller, or NULL if @p t is
 * NULL or is not concrete enough to be expressed in C.
 *  - Primitive types map directly (Nothing -> void, Int -> int, ...).
 *  - (Fn (param types...) ret) becomes a function pointer type
 *    "ret (*)(p1, p2, ...)".
 *  - (Tup e1 e2 ...) becomes a pointer to an anonymous struct whose first
 *    member is "void* me", followed by one named member per element.
 *  - Anything else is "void*".
 *
 * @throw Error for an unrecognised primitive type.
 */
static const Type*
llType(const AType* t)
{
	if (t == NULL)
		return NULL;

	if (t->kind == AType::PRIM) {
		const string name = t->head()->str();
		if (name == "Nothing") return new Type("void");
		if (name == "Bool")    return new Type("bool");
		if (name == "Int")     return new Type("int");
		if (name == "Float")   return new Type("float");
		if (name == "String" || name == "Quote" || name == "Lexeme")
			return new Type("char*");
		throw Error(t->loc, string("Unknown primitive type `") + t->str() + "'");
	}

	if (t->kind == AType::EXPR && t->head()->str() == "Fn") {
		// Layout is (Fn prot ret): element 1 is the parameter tuple and
		// element 2 the return type, so the iterator must advance twice.
		AType::const_iterator i     = t->begin();
		const ATuple*         protT = (*++i)->to_tuple();
		const AType*          retT  = (*++i)->as_type();
		const Type*           retLT = llType(retT);
		if (!protT || !retLT) {
			delete retLT;
			return NULL;
		}

		Type sig = *retLT + " (*)(";
		delete retLT;
		for (ATuple::const_iterator p = protT->begin(); p != protT->end(); ++p) {
			const Type* lt = llType((*p)->to_type());
			if (!lt)
				return NULL;
			if (p != protT->begin())
				sig += ", ";
			sig += *lt;
			delete lt;
		}
		sig += ")";
		return new Type(sig);
	}

	if (t->kind == AType::EXPR && t->head()->str() == "Tup") {
		// C requires every struct member to have a declarator, so elements
		// are emitted as members f1, f2, ... in order.
		ostringstream def;
		def << "struct { void* me; ";
		unsigned field = 1;
		for (AType::const_iterator i = t->iter_at(1); i != t->end(); ++i, ++field) {
			const Type* lt = llType((*i)->to_type());
			if (!lt)
				return NULL;
			def << *lt << " f" << field << "; ";
			delete lt;
		}
		def << "}*";
		return new Type(def.str());
	}

	return new Type("void*");
}


/***************************************************************************
 * C Engine                                                                *
 ***************************************************************************/

struct CEngine : public Engine {
	CEngine()
		: out(
			"#include <stdint.h>\n"
			"#include <stdbool.h>\n"
			"void* resp_gc_allocate(unsigned size, uint8_t tag);\n\n")
	{
	}

	CFunc startFunction(CEnv& cenv,
			const std::string& name, const ATuple* args, const AType* type)
	{
		const AType* argsT = type->prot()->as_type();
		const AType* retT  = type->list_ref(2)->as_type();

		vector<const Type*> cprot;
		FOREACHP(ATuple::const_iterator, i, argsT) {
			const AType* at = (*i)->as_type();
			THROW_IF(!llType(at), Cursor(), string("non-concrete parameter :: ")
				+ at->str())
			cprot.push_back(llType(at));
		}

		THROW_IF(!llType(retT), Cursor(),
				(format("return has non-concrete type `%1%'") % retT->str()).str());

		Function* f = new Function();
		f->returnType = *llType(retT);
		f->name = name;
		f->text += f->returnType + "\n" + f->name + "(";
		ATuple::const_iterator ai = argsT->begin();
		ATuple::const_iterator ni = args->begin();
		for (; ai != argsT->end(); ++ai, ++ni) {
			if (ai != argsT->begin())
				f->text += ", ";
			f->text += *llType((*ai)->as_type()) + " " + (*ni)->as_symbol()->cppstr;
		}
		f->text += ")\n{\n";

		out += f->text;
		return f;
	}

	void pushFunctionArgs(CEnv& cenv, const ATuple* fn, const AType* type, CFunc f);

	void finishFunction(CEnv& cenv, CFunc f, CVal ret) {
		out += "return " + *(Value*)ret + ";\n}\n\n";
	}

	void eraseFunction(CEnv& cenv, CFunc f) {
		cenv.err << "C backend does not support JIT (eraseFunction)" << endl;
	}

	CVal compileCall(CEnv& cenv, CFunc func, const AType* funcT, const vector<CVal>& args) {
		Value* varname = new string(cenv.penv.gensymstr("x"));
		Function* f = llFunc(func);
		out += (format("const %s %s = %s(") % f->returnType % *varname % f->name).str();
		FOREACH(vector<CVal>::const_iterator, i, args)
			out += *llVal(*i);
		out += ");\n";
		return varname;
	}

	CVal compileTup(CEnv& cenv, const AType* type, CVal rtti, const vector<CVal>& fields);
	CVal compileDot(CEnv& cenv, CVal tup, int32_t index);
	CVal compileLiteral(CEnv& cenv, const AST* lit);
	CVal compileString(CEnv& cenv, const char* str);
	CVal compilePrimitive(CEnv& cenv, const ATuple* prim);
	CVal compileIf(CEnv& cenv, const ATuple* aif);
	CVal compileMatch(CEnv& cenv, const ATuple* match);
	CVal compileGlobal(CEnv& cenv, const AType* type, const string& sym, CVal val);
	CVal getGlobal(CEnv& cenv, const string& sym, CVal val);

	void writeModule(CEnv& cenv, std::ostream& os) {
		os << out;
	}

	const string call(CEnv& cenv, CFunc f, const AType* retT) {
		cenv.err << "C backend does not support JIT (call)" << endl;
		return "";
	}

	std::string out;
};

/** Create a heap-allocated backend that emits C source text. */
Engine*
resp_new_c_engine()
{
	CEngine* engine = new CEngine();
	return engine;
}

/***************************************************************************
 * Code Generation                                                         *
 ***************************************************************************/

/** Tuple construction: not implemented by the C backend; yields NULL. */
CVal
CEngine::compileTup(CEnv& /*cenv*/, const AType* /*type*/, CVal /*rtti*/,
                    const vector<CVal>& /*fields*/)
{
	return NULL;  // no C value is produced yet
}

/** Tuple field access: not implemented by the C backend; yields NULL. */
CVal
CEngine::compileDot(CEnv& /*cenv*/, CVal /*tup*/, int32_t /*index*/)
{
	return NULL;  // no C value is produced yet
}

/** A literal compiles to its own printed form, used verbatim as C text. */
CVal
CEngine::compileLiteral(CEnv& /*cenv*/, const AST* lit)
{
	const string text = lit->str();
	return new Value(text);
}

/** Compile a string constant to a C string literal.
 *
 * Double quotes, backslashes and control characters are escaped, so any
 * string yields a well-formed C literal.  Other bytes (including non-ASCII
 * UTF-8) are copied unchanged.  A NULL @p str compiles to "".
 */
CVal
CEngine::compileString(CEnv& cenv, const char* str)
{
	Value* lit = new Value("\"");
	for (const char* p = str; p && *p; ++p) {
		const unsigned char c = static_cast<unsigned char>(*p);
		switch (c) {
		case '"':  *lit += "\\\""; break;
		case '\\': *lit += "\\\\"; break;
		case '\n': *lit += "\\n";  break;
		case '\r': *lit += "\\r";  break;
		case '\t': *lit += "\\t";  break;
		default:
			if (c < 0x20 || c == 0x7F) {
				// Always three octal digits, so a following digit in the
				// string cannot be absorbed into the escape sequence.
				*lit += '\\';
				*lit += static_cast<char>('0' + ((c >> 6) & 7));
				*lit += static_cast<char>('0' + ((c >> 3) & 7));
				*lit += static_cast<char>('0' + (c & 7));
			} else {
				*lit += static_cast<char>(c);
			}
		}
	}
	*lit += '"';
	return lit;
}

/** Open a new scope and bind each parameter of @p fn in @p cenv.
 *
 * Each parameter symbol is bound to a Value holding the symbol's text, with
 * the corresponding type taken from @p type's prototype.
 *
 * @throw Error if a parameter type has no C equivalent.
 */
void
CEngine::pushFunctionArgs(CEnv& cenv, const ATuple* fn, const AType* type, CFunc f)
{
	cenv.push();

	const AType* argsT = type->prot()->as_type();

	ATuple::const_iterator p  = fn->prot()->begin();
	ATuple::const_iterator pT = argsT->begin();
	for (; p != fn->prot()->end(); ++p, ++pT) {
		const AType* t  = (*pT)->as_type();
		const Type*  lt = llType(t);
		THROW_IF(!lt, fn->loc, "untyped parameter\n");
		delete lt;  // only needed to check that the type is concrete
		cenv.def((*p)->as_symbol(), *p, t, new string((*p)->str()));
	}
}

/** Compile an `if` expression to C statements.
 *
 * @p aif is (if c1 t1 [c2 t2 ...] else).  A fresh variable is declared with
 * the expression's C type.  An if/else chain then assigns the selected
 * branch's value to it, with later clauses nested inside `else { ... }`
 * blocks.  The variable's name is returned as the expression's value.
 *
 * @throw Error if the expression's type has no C equivalent.
 */
CVal
CEngine::compileIf(CEnv& cenv, const ATuple* aif)
{
	Value*      varname = new string(cenv.penv.gensymstr("if"));
	const Type* lt      = llType(cenv.type(aif));
	THROW_IF(!lt, aif->loc, "if expression has non-concrete type");
	out += (format("%s %s;\n") % *lt % *varname).str();
	delete lt;

	// Walk (condition, consequent) pairs.  An element with no successor is
	// the trailing else expression, handled after the loop.
	size_t                 opened = 0;  // "else {" scopes left open
	bool                   first  = true;
	ATuple::const_iterator cond   = aif->iter_at(1);
	for (;;) {
		ATuple::const_iterator then = cond;
		if (++then == aif->end())
			break;

		if (!first) {
			out += "else {\n";
			++opened;
		}
		first = false;

		Value* condV = llVal(resp_compile(cenv, *cond));
		out += (format("if (%s) {\n") % *condV).str();

		Value* thenV = llVal(resp_compile(cenv, *then));
		out += (format("%s = %s;\n}\n") % *varname % *thenV).str();

		// Advance past both the condition and its consequent.
		cond = ++then;
	}

	// Emit final else block
	out += "else {\n";
	Value* elseV = llVal(resp_compile(cenv, aif->list_last()));
	out += (format("%s = %s;\n}\n") % *varname % *elseV).str();

	for (size_t k = 0; k < opened; ++k)
		out += "}\n";

	return varname;
}

/** Pattern matching: not implemented by the C backend; yields NULL. */
CVal
CEngine::compileMatch(CEnv& /*cenv*/, const ATuple* /*match*/)
{
	return NULL;  // no C value is produced yet
}

/** Compile a primitive operator application such as (+ a b c).
 *
 * Operands are compiled in order and joined with the C spelling of the
 * operator inside one parenthesised expression.  The result is stored in a
 * fresh const local, whose name is returned.
 */
CVal
CEngine::compilePrimitive(CEnv& cenv, const ATuple* prim)
{
	// Compile every operand first, in source order.
	vector<Value*> operands;
	for (ATuple::const_iterator it = prim->iter_at(1); it != prim->end(); ++it)
		operands.push_back(llVal(resp_compile(cenv, *it)));

	// Translate Resp operator names that differ from their C spelling.
	const string name = prim->head()->to_symbol()->str();
	string cop;
	if (name == "=")
		cop = "==";
	else if (name == "and")
		cop = "&";
	else if (name == "or")
		cop = "|";
	else if (name == "xor")
		cop = "^";
	else
		cop = name;
	const string sep = string(" ") + cop + " ";

	string expr("(");
	for (size_t k = 0; k < operands.size(); ++k) {
		if (k != 0)
			expr += sep;
		expr += *operands[k];
	}
	expr += ")";

	Value* result = new string(cenv.penv.gensymstr("x"));
	out += (format("const %s %s = %s;\n")
	        % *llType(cenv.type(prim)) % *result % expr).str();
	return result;
}

/** Global definitions: not implemented by the C backend; yields NULL. */
CVal
CEngine::compileGlobal(CEnv& /*cenv*/, const AType* /*type*/,
                       const string& /*sym*/, CVal /*val*/)
{
	return NULL;  // no C value is produced yet
}

/** Global lookup: not implemented by the C backend; yields NULL. */
CVal
CEngine::getGlobal(CEnv& /*cenv*/, const string& /*sym*/, CVal /*val*/)
{
	return NULL;  // no C value is produced yet
}