/* Resp: A programming language
 * Copyright (C) 2008-2009 David Robillard <http://drobilla.net>
 *
 * Resp is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Resp is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Resp.  If not, see <http://www.gnu.org/licenses/>.
 */

/** @file
 * @brief Lift functions (compilation pass 1)
 * After this pass:
 *  - All function definitions are top-level
 *  - All references to functions are replaced with references to
 *    a closure (a tuple with the function and necessary context)
 */

#include <algorithm>
#include <string>

#include "resp.hpp"

using namespace std;

/** Lift a reference to a symbol.
 *
 * When no function is currently being lifted, @a sym is returned untouched.
 * Otherwise a reference to the innermost function's own name becomes "_me",
 * and a free variable becomes a field access on the "_me" closure.
 *
 * @param cenv Compilation environment (lift stack, types, symbols).
 * @param code Output code (unused here).
 * @param sym  Symbol to lift.
 * @return @a sym itself, the "_me" symbol, or a new "(. _me i)" expression.
 */
static const AST*
lift_symbol(CEnv& cenv, Code& code, const ASymbol* sym) throw()
{
	// Not inside any function being lifted, nothing to rewrite
	if (cenv.liftStack.empty())
		return sym;

	CEnv::FreeVars& freeVars = cenv.liftStack.top();

	// Reference to innermost function, replace with "_me"
	if (cenv.name(freeVars.fn) == sym->sym())
		return cenv.penv.sym("_me");

	// Symbols that cenv.code.innermost() finds (presumably those bound in
	// the innermost frame) and the special "__unreachable" symbol stay as-is
	if (cenv.code.innermost(sym) || !strcmp(sym->sym(), "__unreachable"))
		return sym;

	/* Free variable, replace with "(. _me i)" where i is the index of the
	 * free variable in the closure.  A free variable that has not been seen
	 * before is appended to the closure: the calling lift_fn uses
	 * cenv.liftStack.top() to construct the closure once the fn body has
	 * been lifted.
	 */
	auto* const slot = new ALiteral<int32_t>(
		T_INT32, freeVars.index(sym) + 1, Cursor());
	const AST* const access = tup(sym->loc,
	                              cenv.penv.sym("."),
	                              cenv.penv.sym("_me"),
	                              slot,
	                              NULL);

	// The access has the same type as the variable it replaces
	cenv.setType(access, cenv.type(sym));
	return access;
}

/** Lift a field access expression "(. tuple index)".
 *
 * The tuple expression is lifted recursively and the literal index is
 * incremented by one to skip the RTTI field.
 *
 * @param cenv Compilation environment.
 * @param code Output code, passed on to resp_lift().
 * @param dot  The "." expression; element 2 is expected to be an int32
 *             literal (this is not checked here).
 * @return A new "." expression with the same type as @a dot.
 */
static const AST*
lift_dot(CEnv& cenv, Code& code, const ATuple* dot) throw()
{
	// dot is const, so keep its index literal const too instead of
	// stripping the qualifier with a C-style cast
	const ALiteral<int32_t>* const index =
		static_cast<const ALiteral<int32_t>*>(dot->list_ref(2));

	List copy;
	copy.push_back(dot->fst());
	copy.push_back(resp_lift(cenv, code, dot->list_ref(1)));
	copy.push_back(new ALiteral<int32_t>(T_INT32, index->val + 1, Cursor())); // skip RTTI
	cenv.setTypeSameAs(copy, dot);
	return copy;
}

/** Lift a definition "(define sym body...)".
 *
 * The symbol is defined with the unlifted body first so that recursive
 * references resolve while the body is lifted, then redefined with the
 * lifted body.
 *
 * @param cenv Compilation environment.
 * @param code Output code, passed on to resp_lift().
 * @param def  The "define" expression; element 1 must be a symbol.
 * @return The lifted definition, or NULL if the lifted body is the defined
 *         symbol itself (nothing remains to be emitted).
 */
static const AST*
lift_def(CEnv& cenv, Code& code, const ATuple* def) throw()
{
	// Check the name really is a symbol before anything below uses it
	assert(def->list_ref(1)->to_symbol());

	// Define stub first for recursion
	const ASymbol* const sym  = def->list_ref(1)->as_symbol();
	const AST*     const body = def->list_ref(2);
	cenv.def(sym, body, cenv.type(body), NULL);
	if (is_form(body, "lambda"))
		cenv.setName(body->as_tuple(), sym->str());

	// Copy the form and name, lifting the body and anything after it
	List copy;
	copy.push_back(def->fst());
	copy.push_back(sym);
	for (ATuple::const_iterator t = def->iter_at(2); t != def->end(); ++t)
		copy.push_back(resp_lift(cenv, code, *t));

	cenv.setTypeSameAs(copy.head, def);

	// Look the lifted body up once, it is needed several times below
	const AST* const liftedBody = copy.head->list_ref(2);

	if (copy.head->list_ref(1) == liftedBody)
		return NULL; // Definition created by lift_fn when body was lifted

	// Replace the stub with the lifted body
	cenv.def(copy.head->list_ref(1)->as_symbol(),
	         liftedBody,
	         cenv.type(liftedBody),
	         NULL);
	return copy;
}

/** Lift a function ("lambda") expression.
 *
 * Emits into @a code: a "prot" forward declaration of the implementation
 * function, a "def-type" for the closure type (plus a forward declaration
 * of that type at the front of @a code), and a "define" of the
 * implementation function.  The implementation function takes the closure
 * ("_me") as an additional first parameter.
 *
 * @param cenv Compilation environment.
 * @param code Output code that the declarations and definition are added to.
 * @param fn   The "lambda" expression to lift.
 * @return A "(Closure impl free-vars...)" expression that replaces @a fn,
 *         typed with the closure type symbol.
 */
static const AST*
lift_fn(CEnv& cenv, Code& code, const ATuple* fn) throw()
{
	List impl;
	impl.push_back(fn->fst());

	// Implementation is named "__" + the function's name, or gets a name
	// from gensymstr("__fn") if the function has no name
	const string fnName      = cenv.name(fn);
	const string implNameStr = (fnName != "")
		? (string("__") + fnName)
		: cenv.penv.gensymstr("__fn");

	cenv.setName(impl, implNameStr);

	// Free variables found while lifting the body are recorded in this
	// entry (see lift_symbol) and used below to build the closure
	cenv.liftStack.push(CEnv::FreeVars(fn, implNameStr));

	// Create a new stub environment frame for parameters
	cenv.push();
	const ATuple*          type = cenv.type(fn)->as_tuple();
	ATuple::const_iterator tp   = type->prot()->begin();

	List implProt;
	List implProtT;

	// Prepend closure parameter
	implProt.push_back(cenv.penv.sym("_me"));

	// Walk the parameters and their types (tp) in step
	for (auto p : *fn->prot()) {
		const AST* paramType = (*tp++);
		if (is_form(paramType, "lambda")) {
			// Parameter has a function type, replace it with a Tup type
			// wrapping a tuple built from a fresh type variable and the
			// original function type
			const ATuple* fnType = new ATuple(cenv.tenv.var(), paramType->as_tuple(), paramType->loc);
			paramType = tup(p->loc, cenv.tenv.Tup, fnType, NULL);
		}
		cenv.def(p->as_symbol(), p, paramType, NULL);
		implProt.push_back(p);
		implProtT.push_back(paramType);
	}

	// Write function prototype first for mutual and/or nested recursion
	List declProt(fn->loc, cenv.penv.sym("lambda"), 0);
	declProt.push_back(implProt);
	List decl(fn->loc, cenv.penv.sym("prot"), cenv.penv.sym(implNameStr), 0);
	decl.push_back(declProt);
	code.push_back(decl);
	cenv.setType(decl, cenv.penv.sym("Nothing"));

	impl.push_back(implProt);

	// Symbol for closure type (defined below)
	const ASymbol* tsym = cenv.penv.sym(
		(fnName != "") ? (string("__T") + fnName) : cenv.penv.gensymstr("__Tfn"));

	// Prepend closure parameter type
	implProtT.push_front(tsym);

	// Variable to represent our return type (for recursive lifting)
	const AST* retTVar = cenv.tenv.var();

	// Create definition for implementation fn
	const ASymbol* implName = cenv.penv.sym(implNameStr);
	const ATuple*  def      = tup(fn->loc, cenv.penv.sym("define"), implName, impl.head, NULL);

	// Define types before lifting body with return type as a variable
	List implT(Cursor(), type->fst(), implProtT.head, retTVar, 0);
	List closureT(Cursor(), cenv.tenv.Tup, implT.head, NULL);
	List cons(fn->loc, cenv.penv.sym("Closure"), implName, NULL);
	cenv.tenv.def(cenv.penv.sym(fnName), closureT);
	cenv.tenv.def(implName, implT);

	// Lift body
	const AST* implRetT = NULL;
	for (ATuple::const_iterator i = fn->iter_at(2); i != fn->end(); ++i) {
		const AST* lifted = resp_lift(cenv, code, *i);
		impl.push_back(lifted);
		// Overwritten on every iteration, so the type of the last body
		// expression ends up as the return type
		implRetT = cenv.type(lifted);
	}

	// Leave the parameter frame
	cenv.pop();

	// Take a copy of the free variables and pop the entry first, so that
	// the free variables themselves are lifted with the enclosing
	// function (if any) on top of the lift stack
	const CEnv::FreeVars freeVars = cenv.liftStack.top();
	cenv.liftStack.pop();
	for (auto v : freeVars) {
		// Each free variable adds a field to the closure: its lifted
		// value to the constructor and its type to the closure type
		cons.push_back(resp_lift(cenv, code, v));
		closureT.push_back(cenv.type(v));
	}

	// Now we know our real lifted return type
	const ATuple* realImplT = implT.head->replace(retTVar, implRetT);
	cenv.setType(impl, realImplT);

	// Create type definition for closure type
	const AST* tdef = resp_lift(
		cenv, code, tup(Cursor(), cenv.penv.sym("def-type"), tsym, closureT, 0));
	code.push_back(tdef);
	cenv.tenv.def(tsym, closureT);

	// Put forward declaration for type at start of code
	List tdecl(Cursor(), cenv.penv.sym("def-type"), tsym, 0);
	code.push_front(tdecl);

	// Set type of closure to type symbol
	cenv.setType(cons, tsym);

	// Emit implementation definition
	code.push_back(def);
	cenv.def(implName, impl, realImplT, NULL);
	// Named functions are also defined under their own name, with the
	// closure type
	if (cenv.name(fn) != "")
		cenv.def(cenv.penv.sym(cenv.name(fn)), fn, closureT, NULL);

	// Replace return type variable with actual return type in type environment
	for (auto& i : cenv.tenv) {
		for (auto& j : i) {
			if (j.second->to_tuple()) {
				j.second = j.second->as_tuple()->replace(retTVar, implRetT);
			}
		}
	}

	// Replace return type variable with actual return type in code
	for (auto& i : code) {
		if (is_form(i, "def-type")) {
			i = cenv.typedReplace(i->as_tuple(), retTVar, implRetT);
		}
	}

	return cons;
}

/** Lift a call expression.
 *
 * The callee and all arguments are lifted, then the result is turned into
 * an explicit "(call impl closure args...)" expression.  How the
 * implementation function is obtained depends on the callee:
 *  - the innermost function being lifted: its implementation is named
 *    directly and "_me" (the lifted callee) is passed as the closure
 *  - a literal lambda: the implementation named in the freshly lifted
 *    closure is called directly
 *  - anything else: the implementation is read from field 1 of the closure
 *
 * @param cenv Compilation environment.
 * @param code Output code, passed on to resp_lift().
 * @param call The call expression.
 * @return The new "call" expression.
 */
static const AST*
lift_call(CEnv& cenv, Code& code, const ATuple* call) throw()
{
	// Lift all children (callee and arguments, recursively)
	List lifted;
	for (auto child : *call)
		lifted.push_back(resp_lift(cenv, code, child));

	lifted.head->loc = call->loc;

	const ASymbol* const callee = call->fst()->to_symbol();
	const bool isSelfCall = callee
		&& !cenv.liftStack.empty()
		&& callee->sym() == cenv.name(cenv.liftStack.top().fn);

	if (isSelfCall) {
		/* Recursive call to innermost function, call implementation
		 * directly, reusing the current "_me" closure parameter
		 * (no cons or .).
		 */
		lifted.push_front(cenv.penv.sym(cenv.liftStack.top().implName));
		lifted.push_front(cenv.penv.sym("call"));
		cenv.setTypeSameAs(lifted, call);
		return lifted;
	}

	if (is_form(call->fst(), "lambda")) {
		/* Special case: ((fn ...) ...)
		 * Lifting (fn ...) yields: (Closure _impl ...).
		 * Rather than generating
		 *   (call (. (Closure _impl ...) 1) (Closure _impl ...) ...)
		 * the implementation function (_impl) is called directly with the
		 * closure as its first parameter:
		 *   (call _impl (Closure _impl ...) ...)
		 */
		const ATuple*  const closure  = lifted.head->list_ref(0)->as_tuple();
		const ASymbol* const implName = closure->list_ref(1)->as_symbol();
		const ATuple*  const implType = cenv.type(cenv.resolve(implName))->as_tuple();
		lifted.push_front(implName);
		lifted.push_front(cenv.penv.sym("call"));
		cenv.setType(lifted, implType->list_ref(2));
		return lifted;
	}

	// Call to a closure, prepend code to access implementation function
	ATuple* const accessor = tup(call->loc, cenv.penv.sym("."),
	                             lifted.head->fst(),
	                             new ALiteral<int32_t>(T_INT32, 1, Cursor()), NULL);

	// The callee must have a Tup type whose element 1 is the
	// implementation function's type
	const ATuple* const closureType = cenv.type(lifted.head->fst())->as_tuple();
	assert(**closureType->begin() == *cenv.tenv.Tup);
	const ATuple* const implType = closureType->list_ref(1)->as_tuple();

	lifted.push_front(accessor);
	cenv.setType(accessor, implType);
	lifted.push_front(cenv.penv.sym("call"));
	cenv.setType(lifted, implType->list_ref(2));
	return lifted;
}

/** Lift the arguments of a form, leaving its head untouched.
 *
 * @param cenv Compilation environment.
 * @param code Output code, passed on to resp_lift().
 * @param call The form whose elements after the first are lifted.
 * @return A copy of @a call with lifted arguments and the same type.
 */
static const AST*
lift_args(CEnv& cenv, Code& code, const ATuple* call) throw()
{
	// The head of the form is kept verbatim
	List lifted;
	lifted.push_back(call->fst());

	// Everything after the head is lifted, in order
	ATuple::const_iterator arg = call->iter_at(1);
	while (arg != call->end()) {
		lifted.push_back(resp_lift(cenv, code, *arg));
		++arg;
	}

	cenv.setTypeSameAs(lifted.head, call);

	return lifted;
}

/** Lift an expression (entry point of the lifting pass).
 *
 * Symbols are handled by lift_symbol().  Tuples are dispatched on their
 * head: primitives, "cons", forms whose name starts with an upper case
 * letter, "do" and "if" only have their arguments lifted; ".", "define" and
 * "lambda" have dedicated handlers; "def-type" and "quote" are returned
 * unchanged; everything else is treated as a call.  Any other kind of
 * expression is returned unchanged.
 *
 * @param cenv Compilation environment.
 * @param code Output code that lifted definitions are added to.
 * @param ast  Expression to lift.
 * @return The lifted expression; may be NULL (see lift_def()).
 */
const AST*
resp_lift(CEnv& cenv, Code& code, const AST* ast) throw()
{
	if (const ASymbol* const sym = ast->to_symbol())
		return lift_symbol(cenv, code, sym);

	const ATuple* const call = ast->to_tuple();
	if (!call)
		return ast;

	// Name of the form, or "" if the head is not a symbol
	const ASymbol* const head = call->fst()->to_symbol();
	const std::string    form = head ? head->sym() : "";

	// isupper() is undefined for negative values (other than EOF), which a
	// plain char can hold for non-ASCII bytes, so convert to unsigned char
	const bool startsUpper = !form.empty()
		&& isupper(static_cast<unsigned char>(form[0]));

	// Forms where only the arguments need lifting
	if (is_primitive(cenv.penv, call)
	    || form == "cons"
	    || startsUpper
	    || form == "do"
	    || form == "if")
		return lift_args(cenv, code, call);

	if (form == ".")
		return lift_dot(cenv, code, call);

	if (form == "define")
		return lift_def(cenv, code, call);

	if (form == "lambda")
		return lift_fn(cenv, code, call);

	// Nothing to lift inside these forms
	if (form == "def-type" || form == "quote")
		return call;

	return lift_call(cenv, code, call);
}