/* Resp: A programming language
 * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net>
 *
 * Resp is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Resp is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Resp.  If not, see <http://www.gnu.org/licenses/>.
 */

/** @file
 * @brief Lift functions (compilation pass 1)
 * After this pass:
 *  - All function definitions are top-level
 *  - All references to functions are replaced with references to
 *    a closure (a tuple with the function and necessary context)
 */

#include "resp.hpp"

using namespace std;

/** Lift a symbol reference.
 *
 * Only symbols inside a function being lifted (cenv.liftStack non-empty)
 * are rewritten:
 *  - a reference to the innermost function's own name becomes "_me", the
 *    closure parameter that lift_fn prepends to every implementation;
 *  - a symbol for which cenv.code.innermost() is false (presumably: not
 *    bound in the innermost environment frame) is treated as a free
 *    variable and replaced by (. _me INDEX), where INDEX comes from the
 *    innermost FreeVars entry.
 * Every other symbol, including every symbol outside a function, is
 * returned unchanged.
 */
static const AST*
lift_symbol(CEnv& cenv, Code& code, const ASymbol* sym) throw()
{
	// Outside any function there is no closure to read free variables from.
	// Checking this first also avoids calling liftStack.top() on an empty
	// stack (undefined behaviour) when the symbol is not innermost.
	if (cenv.liftStack.empty())
		return sym;

	if (sym->cppstr == cenv.name(cenv.liftStack.top().fn))
		return cenv.penv.sym("_me"); // Reference to innermost function

	if (cenv.code.innermost(sym))
		return sym;

	// Free variable: index() is expected to record sym in the innermost
	// FreeVars (lift_fn later iterates them to build the closure) -- confirm.
	const int32_t index = cenv.liftStack.top().index(sym);

	// Replace symbol with code to access free variable from closure
	return tup<ATuple>(sym->loc, cenv.penv.sym("."),
	                   cenv.penv.sym("_me"),
	                   new ALiteral<int32_t>(T_INT32, index, Cursor()),
	                   NULL);
}

/** Lift a (fn PROT BODY...) form to a top-level implementation function.
 *
 * Appends (def _NAME (fn (_me PARAMS...) LIFTED-BODY...)) to @p code, where
 * _NAME is "_" followed by a gensym based on the function's name (or "fn"
 * when it has none), and _me is a closure parameter prepended to the
 * original parameters.  While the body is lifted, the function is on
 * cenv.liftStack so lift_symbol can collect its free variables.
 *
 * @return A (Closure _NAME FREE-VARS...) expression to replace the original
 *         (fn ...) form.  Both _NAME and (if the function is named) its
 *         original name are bound in the current environment.
 */
static const AST*
lift_fn(CEnv& cenv, Code& code, const ATuple* fn) throw()
{
	List<ATuple, const AST> impl;
	impl.push_back(fn->head());

	const string fnName      = cenv.name(fn);
	const string nameBase    = cenv.penv.gensymstr(((fnName != "") ? fnName : "fn").c_str());
	const string implNameStr = string("_") + nameBase;
	cenv.setName(impl, implNameStr);

	// Free variables referenced by the body are recorded in this entry
	cenv.liftStack.push(CEnv::FreeVars(fn, implNameStr));

	// Create a new stub environment frame for parameters
	cenv.push();
	const AType*          type      = cenv.type(fn);
	AType::const_iterator tp        = type->prot()->begin();
	List<AType,AType>     implProtT;

	List<ATuple, const AST> implProt;

	// Prepend closure parameter
	implProt.push_back(cenv.penv.sym("_me"));

	for (ATuple::const_iterator p = fn->prot()->begin(); p != fn->prot()->end(); ++p) {
		const AType* paramType = (*tp++)->as_type();
		if (paramType->kind == AType::EXPR && *paramType->head() == *cenv.tenv.Fn) {
			// Function-typed parameter: wrap its type in a Tup (compare the
			// closure type tupT below).  The location must come from
			// paramType; fnType is not yet initialized in its own initializer.
			const AType* fnType = new AType(cenv.tenv.var(), paramType, paramType->loc);
			paramType = tup<const AType>((*p)->loc, cenv.tenv.Tup, fnType, NULL);
		}
		cenv.def((*p)->as_symbol(), *p, paramType, NULL);
		implProt.push_back(*p);
		implProtT.push_back(new AType(*paramType));
	}

	impl.push_back(implProt);

	// Lift body; the implementation's return type is that of the last form
	const AType* implRetT = NULL;
	for (ATuple::const_iterator i = fn->iter_at(2); i != fn->end(); ++i) {
		const AST* lifted = resp_lift(cenv, code, *i);
		impl.push_back(lifted);
		implRetT = cenv.type(lifted);
	}

	cenv.pop();

	// Create definition for implementation fn
	ASymbol* implName = cenv.penv.sym(implNameStr);
	ATuple*  def      = tup<ATuple>(fn->loc, cenv.penv.sym("def"), implName, impl.head, NULL);
	code.push_back(def);

	TList implT; // Type of the implementation function
	TList tupT(fn->loc, cenv.tenv.Tup, cenv.tenv.var(), NULL);
	TList consT;
	List<ATuple, const AST> cons(fn->loc, cenv.penv.sym("Closure"), implName, NULL);

	// Build the closure expression and its types from the collected free vars
	const CEnv::FreeVars& freeVars = cenv.liftStack.top();
	for (CEnv::FreeVars::const_iterator i = freeVars.begin(); i != freeVars.end(); ++i) {
		cons.push_back(*i);
		tupT.push_back(cenv.type(*i));
		consT.push_back(cenv.type(*i));
	}
	cenv.liftStack.pop();

	// The implementation takes the closure as its first parameter
	implProtT.push_front(tupT);

	implT.push_back((AType*)type->head());
	implT.push_back(implProtT.head);
	implT.push_back(implRetT);

	consT.push_front(implT.head);
	consT.push_front(cenv.tenv.Tup);

	cenv.setType(impl, implT);
	cenv.setType(cons, consT);

	cenv.def(implName, impl, implT, NULL);
	if (fnName != "")
		cenv.def(cenv.penv.sym(fnName), fn, consT, NULL);

	return cons;
}

/** Lift a call to a user-defined function or closure.
 *
 * The callee and all arguments are lifted first, then the call is rewritten
 * so that an implementation function is called with a closure as its first
 * argument:
 *  - recursive call to the innermost function:  (_impl _me ARGS...)
 *  - immediate call of a (fn ...) form:         (_impl (Closure _impl ...) ARGS...)
 *  - any other callee (a closure value):        ((. CLOSURE 0) CLOSURE ARGS...)
 */
static const AST*
lift_call(CEnv& cenv, Code& code, const ATuple* call) throw()
{
	List<ATuple, const AST> copy;

	// Lift all children (callee and arguments, recursively)
	for (ATuple::const_iterator i = call->begin(); i != call->end(); ++i)
		copy.push_back(resp_lift(cenv, code, *i));

	copy.head->loc = call->loc;

	const AType* copyT = NULL;

	const ASymbol* sym = call->head()->to_symbol();
	if (sym && !cenv.liftStack.empty() && sym->cppstr == cenv.name(cenv.liftStack.top().fn)) {
		/* Recursive call to innermost function, call implementation directly,
		 * reusing the current "_me" closure parameter (no cons or .).
		 * NOTE(review): copyT stays NULL on this path, so the lifted call is
		 * given a NULL type, unlike the other branches -- confirm intended.
		 */
		copy.push_front(cenv.penv.sym(cenv.liftStack.top().implName));
	} else if (is_form(call->head(), "fn")) {
		/* Special case: ((fn ...) ...)
		 * Lifting (fn ...) yields: (Closure _impl ...).
		 * We don't want ((. (Closure _impl ...) 0) (Closure _impl ...) ...),
		 * so call the implementation function (_impl) directly and pass the
		 * closure as the first parameter:
		 * (_impl (Closure _impl ...) ...)
		 * The test is on the callee (call->head()): resp_lift sends calls
		 * whose head is the symbol "fn" to lift_fn, never here.
		 */
		const ATuple*  closure = copy.head->list_ref(0)->as_tuple();
		const ASymbol* implSym = closure->list_ref(1)->as_symbol();
		const AType*   implT   = cenv.type(cenv.resolve(implSym));
		copy.push_front(implSym);
		copyT = implT->list_ref(2)->as_type(); // Return type of _impl
	} else {
		// Call to a closure, prepend code to access implementation function
		ATuple* getFn = tup<ATuple>(call->loc, cenv.penv.sym("."),
		                            copy.head->head(),
		                            new ALiteral<int32_t>(T_INT32, 0, Cursor()), NULL);
		const AType* calleeT = cenv.type(copy.head->head());
		assert(**calleeT->begin() == *cenv.tenv.Tup);
		const AType* implT = calleeT->list_ref(1)->as_type();
		copy.push_front(getFn);
		cenv.setType(getFn, implT);
		copyT = implT->list_ref(2)->as_type(); // Return type of the implementation
	}

	cenv.setType(copy, copyT);
	return copy;
}

/** Lift a (def NAME BODY...) form.
 *
 * NAME is bound to BODY (with BODY's current type) in the current
 * environment frame before anything is lifted, so references to NAME
 * inside BODY resolve.  If BODY is a (fn ...) form it is also given NAME as
 * its name; lift_fn, lift_symbol and lift_call read that name back through
 * cenv.name() to identify the function and references to itself.
 *
 * @return The lifted (def NAME LIFTED...) form, after binding NAME to the
 *         lifted body; or NULL when the lifted name and lifted body are the
 *         same node.
 */
static const AST*
lift_def(CEnv& cenv, Code& code, const ATuple* def) throw()
{
	// Define stub first for recursion
	const ASymbol* const sym  = def->list_ref(1)->as_symbol();
	const AST*     const body = def->list_ref(2);
	cenv.def(sym, body, cenv.type(body), NULL);
	if (is_form(body, "fn"))
		cenv.setName(body->as_tuple(), sym->str());

	// NOTE(review): this check runs after as_symbol() has already been
	// applied to the same element above.
	assert(def->list_ref(1)->to_symbol());
	List<ATuple, const AST> copy;
	copy.push_back(def->head());
	copy.push_back(resp_lift(cenv, code, def->list_ref(1)));
	for (ATuple::const_iterator t = def->iter_at(2); t != def->end(); ++t)
		copy.push_back(resp_lift(cenv, code, *t));
	
	// The lifted def takes the type of the original def
	cenv.setTypeSameAs(copy.head, def);

	// NOTE(review): lift_fn returns a (Closure ...) tuple, not the name
	// symbol, so it is unclear when this pointer comparison can hold; confirm.
	if (copy.head->list_ref(1) == copy.head->list_ref(2))
		return NULL; // Definition created by lift_fn when body was lifted

	// Bind NAME to the lifted body (presumably replacing the stub above)
	cenv.def(copy.head->list_ref(1)->as_symbol(),
	         copy.head->list_ref(2),
	         cenv.type(copy.head->list_ref(2)),
	         NULL);
	return copy;
}

/** Lift a form whose head is kept verbatim (primitive operators and forms
 * such as if, cons, constructors and "."): the head is copied unchanged,
 * every following element is lifted, and the result is given the same type
 * as the original form.
 */
static const AST*
lift_builtin_call(CEnv& cenv, Code& code, const ATuple* call) throw()
{
	List<ATuple, const AST> lifted;

	const AST* const op = call->head();
	lifted.push_back(op);

	ATuple::const_iterator arg = call->iter_at(1);
	while (arg != call->end()) {
		const AST* const liftedArg = resp_lift(cenv, code, *arg);
		lifted.push_back(liftedArg);
		++arg;
	}

	cenv.setTypeSameAs(lifted.head, call);
	return lifted;
}

/** Lift an expression (compilation pass 1).
 *
 * Symbols go to lift_symbol.  Tuples are dispatched on their head:
 *  - primitive calls, (if ...), (cons ...), (. ...) and forms with a
 *    capitalised head only have their arguments lifted;
 *  - (fn ...) and (def ...) have dedicated handlers;
 *  - (match ...) and (def-type ...) are returned unchanged (not yet lifted);
 *  - anything else is lifted as a call to a closure.
 * Any other node is returned unchanged.  Implementation functions created
 * by lifting are appended to @p code.
 */
const AST*
resp_lift(CEnv& cenv, Code& code, const AST* ast) throw()
{
	const ASymbol* const sym = ast->to_symbol();
	if (sym)
		return lift_symbol(cenv, code, sym);

	const ATuple* const call = ast->to_tuple();
	if (!call)
		return ast;

	const ASymbol* const headSym = call->head()->to_symbol();
	const std::string    form    = headSym ? headSym->cppstr : "";

	// isupper() is undefined for negative values other than EOF, which a
	// plain (signed) char from a non-ASCII symbol can be, so cast first.
	const bool isConstructor =
		!form.empty() && isupper(static_cast<unsigned char>(form[0]));

	if (is_primitive(cenv.penv, call))
		return lift_builtin_call(cenv, code, call);
	else if (form == "fn")
		return lift_fn(cenv, code, call);
	else if (form == "def")
		return lift_def(cenv, code, call);
	else if (form == "if" || form == "cons" || form == "." || isConstructor)
		return lift_builtin_call(cenv, code, call);
	else if (form == "match" || form == "def-type")
		return call; // FIXME
	else
		return lift_call(cenv, code, call);
}