/* Resp: A programming language
 * Copyright (C) 2008-2009 David Robillard <http://drobilla.net>
 *
 * Resp is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Resp is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Resp.  If not, see <http://www.gnu.org/licenses/>.
 */

/** @file
 * @brief Lift functions (compilation pass 1)
 * After this pass:
 *  - All function definitions are top-level
 *  - All references to functions are replaced with references to
 *    a closure (a tuple with the function and necessary context)
 */

#include <algorithm>
#include <string>

#include "resp.hpp"

using namespace std;

/** Lift a reference to a symbol.
 *
 * When no function is currently being lifted (the lift stack is empty) the
 * symbol is returned unchanged.  Otherwise:
 *  - a symbol whose name matches the name of the function on top of the
 *    lift stack is replaced with "_me"
 *  - a symbol for which cenv.code.innermost() is false, and which is not
 *    "__unreachable", is treated as a free variable and replaced with
 *    "(. _me i)"
 *  - any other symbol is returned unchanged
 *
 * @param cenv Compilation environment (lift stack, code env, symbol table).
 * @param code Not used by this function.
 * @param sym  Symbol to lift.
 * @return The replacement expression, or @a sym itself.
 */
static const AST*
lift_symbol(CEnv& cenv, Code& code, const ASymbol* sym) throw()
{
	// Not inside a function being lifted: nothing to rewrite
	if (cenv.liftStack.empty())
		return sym;

	CEnv::FreeVars& closureVars = cenv.liftStack.top();

	// The innermost function referring to itself uses its closure parameter
	if (cenv.name(closureVars.fn) == sym->sym())
		return cenv.penv.sym("_me");

	const bool isFree = (!cenv.code.innermost(sym)
	                     && strcmp(sym->sym(), "__unreachable"));
	if (!isFree)
		return sym;

	/* Free variable: access it through the closure.  closureVars.index(sym)
	 * supplies the index; per the original notes it appends the variable if
	 * it has not been seen yet, and the calling lift_fn later reads
	 * cenv.liftStack.top() to build the closure once the body is lifted.
	 */
	return tup(sym->loc, cenv.penv.sym("."),
	           cenv.penv.sym("_me"),
	           new ALiteral<int32_t>(T_INT32, closureVars.index(sym) + 1, Cursor()),
	           NULL);
}

/** Lift a field access expression "(. object index)".
 *
 * The object expression is lifted recursively and the literal index is
 * incremented by one (the original comment gives the reason as skipping
 * RTTI).  The result is given the same type as the original expression.
 *
 * @param cenv Compilation environment.
 * @param code Passed through to resp_lift for the object expression.
 * @param dot  The "." form; element 2 is read as an int32 literal.
 * @return A new tuple with the lifted object and adjusted index.
 */
static const AST*
lift_dot(CEnv& cenv, Code& code, const ATuple* dot) throw()
{
	const AST* const               object = resp_lift(cenv, code, dot->list_ref(1));
	const ALiteral<int32_t>* const field  = (ALiteral<int32_t>*)(dot->list_ref(2));

	List lifted;
	lifted.push_back(dot->fst());
	lifted.push_back(object);

	// Shift the field index by one to skip RTTI
	lifted.push_back(new ALiteral<int32_t>(T_INT32, field->val + 1, Cursor()));

	cenv.setTypeSameAs(lifted, dot);
	return lifted;
}

/** Lift a definition "(def name value...)".
 *
 * The name is bound to the unlifted value before anything is lifted, so that
 * recursive references inside the value can be resolved.  If the value is a
 * "fn" form, that fn is also given the definition's name.  The name and all
 * following elements are then lifted and the name is rebound to the lifted
 * value.
 *
 * @param cenv Compilation environment.
 * @param code Passed through to resp_lift for every lifted element.
 * @param def  The "def" form; element 1 must be a symbol.
 * @return The lifted definition, or NULL if the lifted name and the lifted
 *         value are the same node.
 */
static const AST*
lift_def(CEnv& cenv, Code& code, const ATuple* def) throw()
{
	const ASymbol* const name  = def->list_ref(1)->as_symbol();
	const AST*     const value = def->list_ref(2);

	// Stub binding so the value may refer to itself while being lifted
	cenv.def(name, value, cenv.type(value), NULL);
	if (is_form(value, "fn"))
		cenv.setName(value->as_tuple(), name->str());

	assert(def->list_ref(1)->to_symbol());

	// Rebuild the form: head, lifted name, then every lifted value element
	List lifted;
	lifted.push_back(def->fst());
	lifted.push_back(resp_lift(cenv, code, def->list_ref(1)));
	ATuple::const_iterator elem = def->iter_at(2);
	while (elem != def->end()) {
		lifted.push_back(resp_lift(cenv, code, *elem));
		++elem;
	}

	cenv.setTypeSameAs(lifted.head, def);

	const AST* const liftedName  = lifted.head->list_ref(1);
	const AST* const liftedValue = lifted.head->list_ref(2);

	// Definition created by lift_fn when body was lifted
	if (liftedName == liftedValue)
		return NULL;

	cenv.def(liftedName->as_symbol(), liftedValue, cenv.type(liftedValue), NULL);
	return lifted;
}

/** Lift a type definition "(def-type name type)".
 *
 * If the element returned by frst() is not a symbol, the form is returned
 * unchanged.  Otherwise every occurrence of that symbol inside the type body
 * (the element returned by frrst(), which must be a tuple) is replaced with
 * the symbol "__REC", and a new form is built from the original head, the
 * name, and the rewritten body.
 *
 * @param cenv Compilation environment (used for its symbol table only).
 * @param code Not used by this function.
 * @param def  The "def-type" form.
 * @return @a def itself, or the rewritten definition.
 */
static const AST*
lift_def_type(CEnv& cenv, Code& code, const ATuple* def) throw()
{
	const ASymbol* sym = def->frst()->to_symbol();
	if (!sym)
		return def;

	const AST* type = def->frrst()->as_tuple()->replace(sym, cenv.penv.sym("__REC"));

	/* Terminate the element list with NULL, as the other tup() calls in this
	 * file do.  A plain integer 0 is not guaranteed to have the width of a
	 * pointer if tup() reads its elements as variadic arguments.
	 */
	return tup(def->loc, def->fst(), sym, type, NULL);
}

/** Lift a function expression "(fn (params...) body...)".
 *
 * Builds an implementation function that takes an extra leading "_me"
 * closure parameter, lifts the body with a fresh free-variable record on
 * cenv.liftStack, and then:
 *  - appends a "def-type" for the closure type to @a code
 *  - appends a "def" of the implementation function to @a code
 *  - returns a "(Closure impl freevars...)" expression in place of the fn
 *
 * Names are derived from the fn's name when it has one ("__name" for the
 * implementation, "__Tname" for the closure type), otherwise generated.
 *
 * @param cenv Compilation environment.
 * @param code Receives the generated top-level definitions.
 * @param fn   The "fn" form to lift.
 * @return The closure construction expression.
 */
static const AST*
lift_fn(CEnv& cenv, Code& code, const ATuple* fn) throw()
{
	List impl;
	impl.push_back(fn->fst());

	const string fnName      = cenv.name(fn);
	const string implNameStr = (fnName != "")
		? (string("__") + fnName)
		: cenv.penv.gensymstr("__fn");

	cenv.setName(impl, implNameStr);

	// Free variables found while lifting the body are recorded here
	cenv.liftStack.push(CEnv::FreeVars(fn, implNameStr));

	// Create a new stub environment frame for parameters
	cenv.push();
	const ATuple*          type = cenv.type(fn)->as_tuple();
	ATuple::const_iterator tp   = type->prot()->begin();

	List implProt;
	List implProtT;

	// Prepend closure parameter
	implProt.push_back(cenv.penv.sym("_me"));

	for (ATuple::const_iterator p = fn->prot()->begin(); p != fn->prot()->end(); ++p) {
		const AST* paramType = (*tp++);
		if (is_form(paramType, "Fn")) {
			/* Function-typed parameters are passed as closures, so wrap the
			 * type in a Tup.  The location comes from the parameter type
			 * itself: fnType is still uninitialised inside its own
			 * initialiser and must not be read there.
			 */
			const ATuple* fnType = new ATuple(cenv.tenv.var(), paramType->as_tuple(), paramType->loc);
			paramType = tup((*p)->loc, cenv.tenv.Tup, fnType, NULL);
		}
		cenv.def((*p)->as_symbol(), *p, paramType, NULL);
		implProt.push_back(*p);
		implProtT.push_back(paramType);
	}

	impl.push_back(implProt);

	// Lift body; the type of the last lifted expression is the return type
	const AST* implRetT = NULL;
	for (ATuple::const_iterator i = fn->iter_at(2); i != fn->end(); ++i) {
		const AST* lifted = resp_lift(cenv, code, *i);
		impl.push_back(lifted);
		implRetT = cenv.type(lifted);
	}

	cenv.pop();

	// Symbol for closure type (defined below)
	const ASymbol* tsym = cenv.penv.sym(
		(fnName != "") ? (string("__T") + fnName) : cenv.penv.gensymstr("__Tfn"));

	// Create definition for implementation fn
	const ASymbol* implName = cenv.penv.sym(implNameStr);
	const ATuple*  def      = tup(fn->loc, cenv.penv.sym("def"), implName, impl.head, NULL);

	List tupT(fn->loc, cenv.tenv.Tup, cenv.tenv.var(), NULL);
	List consT;
	List cons(fn->loc, cenv.penv.sym("Closure"), implName, NULL);

	// Add every free variable found in the body to the closure and its type
	const CEnv::FreeVars& freeVars = cenv.liftStack.top();
	for (CEnv::FreeVars::const_iterator i = freeVars.begin(); i != freeVars.end(); ++i) {
		cons.push_back(*i);
		tupT.push_back(cenv.type(*i));
		consT.push_back(cenv.type(*i));
	}
	cenv.liftStack.pop();

	// Prepend closure parameter type
	implProtT.push_front(tsym);

	// NULL-terminated like the other tup() calls in this file
	const ATuple* implT = tup(Cursor(), type->fst(), implProtT.head, implRetT, NULL);

	consT.push_front(implT);
	consT.push_front(cenv.tenv.Tup);

	cenv.setType(impl, implT);

	// Create type definition for closure type
	const AST* tdef = resp_lift(
		cenv, code, tup(Cursor(), cenv.penv.sym("def-type"), tsym, consT.head, NULL));
	code.push_back(tdef);
	cenv.tenv.def(tsym, consT);

	code.push_back(def);

	// Set type of closure to type symbol
	cenv.setType(cons, tsym);

	cenv.def(implName, impl, implT, NULL);
	if (cenv.name(fn) != "")
		cenv.def(cenv.penv.sym(cenv.name(fn)), fn, consT, NULL);

	return cons;
}

/** Lift a call expression "(callee args...)".
 *
 * The callee and every argument are lifted first.  The lifted call is then
 * rewritten in one of three ways:
 *  - Recursive call to the function on top of the lift stack: the
 *    implementation function's name is prepended, so the already lifted
 *    callee becomes the first argument.
 *  - Call of a literal "(fn ...)": the implementation symbol is taken from
 *    element 1 of the lifted closure and prepended, so the closure itself
 *    becomes the first argument.
 *  - Anything else: "(. callee 1)" is prepended to fetch the implementation
 *    function out of the closure at run time.
 *
 * @param cenv Compilation environment.
 * @param code Passed through to resp_lift for every child.
 * @param call The call form to lift.
 * @return The rewritten call.
 */
static const AST*
lift_call(CEnv& cenv, Code& code, const ATuple* call) throw()
{
	// Lift callee and arguments, recursively
	List lifted;
	ATuple::const_iterator child = call->begin();
	while (child != call->end()) {
		lifted.push_back(resp_lift(cenv, code, *child));
		++child;
	}

	lifted.head->loc = call->loc;

	const ASymbol* const callee = call->fst()->to_symbol();
	const bool isSelfCall = (callee
	                         && !cenv.liftStack.empty()
	                         && callee->sym() == cenv.name(cenv.liftStack.top().fn));

	if (isSelfCall) {
		/* The innermost function calling itself: call the implementation
		 * directly and reuse the current "_me" closure parameter, so no
		 * closure is built and no "." access is needed.
		 */
		lifted.push_front(cenv.penv.sym(cenv.liftStack.top().implName));
		cenv.setTypeSameAs(lifted, call);
		return lifted;
	}

	if (is_form(call->fst(), "fn")) {
		/* ((fn ...) ...): the lifted callee is a closure expression
		 * (Closure _impl ...).  Rather than fetching _impl back out of it,
		 * call _impl directly with the closure as first argument:
		 * (_impl (Closure _impl ...) ...)
		 */
		const ATuple*  closure  = lifted.head->list_ref(0)->as_tuple();
		const ASymbol* implName = closure->list_ref(1)->as_symbol();
		const ATuple*  implType = cenv.type(cenv.resolve(implName))->as_tuple();
		lifted.push_front(implName);
		cenv.setType(lifted, implType->list_ref(2));
		return lifted;
	}

	// Call through a closure value: fetch the implementation with (. callee 1)
	ATuple* const getImpl = tup(call->loc, cenv.penv.sym("."),
	                            lifted.head->fst(),
	                            new ALiteral<int32_t>(T_INT32, 1, Cursor()), NULL);

	const ATuple* const closureType = cenv.type(lifted.head->fst())->as_tuple();
	assert(**closureType->begin() == *cenv.tenv.Tup);
	const ATuple* const implType = closureType->list_ref(1)->as_tuple();

	lifted.push_front(getImpl);
	cenv.setType(getImpl, implType);
	cenv.setType(lifted, implType->list_ref(2));
	return lifted;
}

/** Lift the arguments of a form while leaving its head untouched.
 *
 * Element 0 is copied as is; every following element is lifted recursively.
 * The result is given the same type as the original form.
 *
 * @param cenv Compilation environment.
 * @param code Passed through to resp_lift for every argument.
 * @param call The form whose arguments are lifted.
 * @return A new tuple with the original head and lifted arguments.
 */
static const AST*
lift_args(CEnv& cenv, Code& code, const ATuple* call) throw()
{
	List lifted;
	lifted.push_back(call->fst());

	// Everything after the head is lifted
	ATuple::const_iterator arg = call->iter_at(1);
	while (arg != call->end()) {
		lifted.push_back(resp_lift(cenv, code, *arg));
		++arg;
	}

	cenv.setTypeSameAs(lifted.head, call);
	return lifted;
}

/** Lift an expression (entry point of the lifting pass).
 *
 * Symbols go to lift_symbol.  Tuples are dispatched on their head:
 *  - primitives, "cons", heads starting with an upper case letter, "do" and
 *    "if" have only their arguments lifted
 *  - ".", "def", "def-type" and "fn" have dedicated handlers
 *  - "quote" and "cast" are returned unchanged
 *  - everything else is treated as a call
 * Any other node is returned unchanged.
 *
 * @param cenv Compilation environment.
 * @param code Receives top-level definitions generated while lifting.
 * @param ast  Expression to lift.
 * @return The lifted expression; may be NULL, since lift_def can return NULL.
 */
const AST*
resp_lift(CEnv& cenv, Code& code, const AST* ast) throw()
{
	const ASymbol* const sym = ast->to_symbol();
	if (sym)
		return lift_symbol(cenv, code, sym);

	const ATuple* const call = ast->to_tuple();
	if (!call)
		return ast;

	const ASymbol* const head = call->fst()->to_symbol();
	const std::string    form = head ? head->sym() : "";

	/* isupper() requires a value representable as unsigned char (or EOF);
	 * passing a negative plain char is undefined, so convert first.  For an
	 * empty form this reads the terminating '\0', which is not upper case.
	 */
	const bool isUpperHead = isupper(static_cast<unsigned char>(form[0]));

	// Forms whose head is kept and whose arguments are lifted
	if (is_primitive(cenv.penv, call)
	    || form == "cons" || isUpperHead
	    || form == "do" || form == "if")
		return lift_args(cenv, code, call);

	if (form == ".")
		return lift_dot(cenv, code, call);
	if (form == "def")
		return lift_def(cenv, code, call);
	if (form == "def-type")
		return lift_def_type(cenv, code, call);
	if (form == "fn")
		return lift_fn(cenv, code, call);
	if (form == "quote" || form == "cast")
		return call;

	return lift_call(cenv, code, call);
}