/* Resp: A programming language
 * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net>
 *
 * Resp is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Resp is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Resp.  If not, see <http://www.gnu.org/licenses/>.
 */

/** @file
 * @brief Lift functions (compilation pass 1)
 * After this pass:
 *  - All function definitions are top-level
 *  - All references to functions are replaced with references to
 *    a closure (a tuple with the function and necessary context)
 */

#include "resp.hpp"

using namespace std;

/** Lift a symbol reference.
 *
 * Inside a function that is being lifted (non-empty cenv.liftStack):
 *  - a reference to that innermost function itself becomes "me" (the
 *    closure parameter);
 *  - a symbol that has no parser handler and is not defined in the innermost
 *    scope is treated as a free variable.  It is recorded in the current lift
 *    stack frame, which yields its index.  The symbol is then replaced by
 *    (. me INDEX), which reads the variable out of the closure.
 * Every other symbol, and every symbol outside any function, is returned
 * unchanged.
 */
AST*
ASymbol::lift(CEnv& cenv, Code& code) throw()
{
	// Not inside any function: there is no closure to refer to or capture
	// into, and liftStack.top() below must never be called on an empty stack
	if (cenv.liftStack.empty())
		return this;

	if (cppstr == cenv.liftStack.top().fn->name) {
		return cenv.penv.sym("me"); // Reference to innermost function
	} else if (!cenv.penv.handler(true, cppstr)
			&& !cenv.penv.handler(false, cppstr)
			&& !cenv.code.innermost(this)) {

		const int32_t index = cenv.liftStack.top().index(this);

		// Replace symbol with code to access free variable from closure
		return tup<ADot>(loc, cenv.penv.sym("."),
				cenv.penv.sym("me"),
				new ALiteral<int32_t>(index, Cursor()),
				NULL);
	} else {
		return this;
	}
}

/* A quoted expression is left exactly as it is: nothing inside it is lifted. */
AST* AQuote::lift(CEnv& cenv, Code& code) throw() { return this; }

/** Lift a generic tuple.
 *
 * Produces a copy of this tuple in which every element has been replaced by
 * its lifted form.  The copy is given the same type as the original.
 */
AST*
ATuple::lift(CEnv& cenv, Code& code) throw()
{
	ATuple* lifted = new ATuple(*this);

	// Walk source and destination in lock-step
	const_iterator src = begin();
	for (iterator dst = lifted->begin(); src != end(); ++src, ++dst)
		*dst = (*src)->lift(cenv, code);

	cenv.setTypeSameAs(lifted, this);
	return lifted;
}

/** Lift a function expression to the top level.
 *
 * A copy of this function becomes the implementation function.  The copy is
 * renamed to "_" plus a fresh gensym.  Its body is lifted with this function
 * on top of cenv.liftStack, so free variables referenced by the body are
 * collected in that frame (and rewritten to closure accesses).  The
 * implementation gains a leading closure parameter named "me".  A
 * (def _impl ...) for it is appended to @p code, and a
 * (Closure _impl freevars...) expression is returned in place of this
 * function.
 *
 * @param cenv Compilation environment; receives types and definitions for
 *             the implementation function and the closure.
 * @param code Receives the top-level definition of the implementation.
 * @return The Closure cons expression that replaces this function.
 */
AST*
AFn::lift(CEnv& cenv, Code& code) throw()
{
	AFn* impl = new AFn(this);
	// Anonymous functions use "fn" as the gensym base; impl name is "_" + base
	const string nameBase = cenv.penv.gensymstr(((name != "") ? name : "fn").c_str());
	impl->name = "_" + nameBase;

	// Free variables referenced while lifting the body are recorded in this frame
	cenv.liftStack.push(CEnv::FreeVars(this, impl->name));

	// Create a new stub environment frame for parameters
	cenv.push();
	// implProtT starts as a copy of this function's parameter type list and is
	// overwritten element by element with the (possibly rewritten) param types
	const AType*          type      = cenv.type(this);
	AType::const_iterator tp        = type->prot()->begin();
	AType*                implProtT = new AType(*type->prot()->as<const AType*>());
	ATuple::iterator      ip        = implProtT->begin();
	for (const_iterator p = prot()->begin(); p != prot()->end(); ++p) {
		const AType* paramType = (*tp++)->as<const AType*>();
		// A function-typed parameter is passed as a closure: copy its Fn type,
		// prepend a type-variable parameter (the closure), and wrap it in a Tup
		if (paramType->kind == AType::EXPR && *paramType->head() == *cenv.tenv.Fn) {
			AType* fnType = new AType(*paramType);
			fnType->prot()->push_front(const_cast<AType*>(cenv.tenv.var()));
			paramType = tup<const AType>((*p)->loc, cenv.tenv.Tup, fnType, NULL);
		}
		cenv.def((*p)->as<ASymbol*>(), *p, paramType, NULL);
		*ip++ = new AType(*paramType);
	}

	/* Add closure parameter with dummy name (undefined symbol).
	 * The name of this parameter will be changed to the name of this
	 * function after lifting the body (so recursive references correctly
	 * refer to this function by the closure parameter).
	 */
	impl->prot()->push_front(cenv.penv.sym("_"));

	// Lift body
	// NOTE(review): body expressions are taken to start at element 2 (after the
	// head and parameter list); implRetT stays NULL if the body is empty
	const AType* implRetT = NULL;
	iterator ci = impl->begin() + 2;
	for (const_iterator i = begin() + 2; i != end(); ++i, ++ci) {
		*ci = (*i)->lift(cenv, code);
		implRetT = cenv.type(*ci); // The last body expression gives the return type
	}

	cenv.pop();

	// Set name of closure parameter to "me"
	*impl->prot()->begin() = cenv.penv.sym("me");

	// Create definition for implementation fn
	ASymbol* implName = cenv.penv.sym(impl->name);
	ADef*    def       = tup<ADef>(loc, cenv.penv.sym("def"), implName, impl, NULL);
	code.push_back(def);

	// tupT:  type of the "me" parameter; a Tup whose first slot is a type
	//        variable (where consT has implT)
	// consT: type of the returned Closure value, a Tup starting with implT
	// Both get one more element per free variable below
	AType* implT = new AType(*type); // Type of the implementation function
	AType* tupT  = tup<AType>(loc, cenv.tenv.Tup, cenv.tenv.var(), NULL);
	AType* consT = tup<AType>(loc, cenv.tenv.Tup, implT, NULL);
	ACons* cons  = tup<ACons>(loc, cenv.penv.sym("Closure"), implName, NULL);

	// Element 1 of the function type is its parameter list (prot)
	*(implT->begin() + 1) = implProtT;

	// NOTE: freeVars refers into liftStack and must not be used after pop()
	const CEnv::FreeVars& freeVars = cenv.liftStack.top();
	for (CEnv::FreeVars::const_iterator i = freeVars.begin(); i != freeVars.end(); ++i) {
		cons->push_back(*i);
		tupT->push_back(const_cast<AType*>(cenv.type(*i)));
		consT->push_back(const_cast<AType*>(cenv.type(*i)));
	}
	cenv.liftStack.pop();

	// The closure parameter type goes first; element 2 is the return type
	implT->prot()->push_front(tupT);
	*(implT->begin() + 2) = const_cast<AType*>(implRetT);

	cenv.setType(impl, implT);
	cenv.setType(cons, consT);

	cenv.def(implName, impl, implT, NULL);
	// A named function's name is (re)bound with the closure's type
	if (name != "")
		cenv.def(cenv.penv.sym(name), this, consT, NULL);

	return cons;
}

/** Lift a function call.
 *
 * All children (callee and arguments) are lifted first.  The call is then
 * rewritten so that it invokes an implementation function directly:
 *  - recursive call to the innermost function:  (_impl me args...)
 *  - immediate call of a fn expression:         (_impl (Closure _impl ...) args...)
 *  - any other call through a closure value c:  ((. c 0) c args...)
 *
 * @param cenv Compilation environment; the rewritten call's type is recorded here.
 * @param code Receives definitions produced while lifting the children.
 * @return The rewritten call.
 */
AST*
ACall::lift(CEnv& cenv, Code& code) throw()
{
	ACall*           copy = new ACall(this);
	ATuple::iterator ri   = copy->begin();

	// Lift all children (callee and arguments, recursively)
	for (const_iterator i = begin(); i != end(); ++i)
		*ri++ = (*i)->lift(cenv, code);

	ASymbol* sym = head()->to<ASymbol*>();
	if (sym && !cenv.liftStack.empty() && sym->cppstr == cenv.liftStack.top().fn->name) {
		/* Recursive call to innermost function, call implementation directly,
		 * reusing the current "me" closure parameter (no cons or .).
		 * The callee symbol was already lifted to "me" above, so prepending
		 * the implementation name yields (_impl me args...).
		 */
		copy->push_front(cenv.penv.sym(cenv.liftStack.top().implName));
		// The result type is unchanged by the rewrite.  Record it as the other
		// branches do, so later lookups of this node's type succeed (e.g. when
		// AFn::lift takes its return type from a tail recursive call)
		cenv.setTypeSameAs(copy, this);
	} else if (head()->to<AFn*>()) {
		/* Special case: ((fn ...) ...)
		 * Lifting (fn ...) yields: (Fn _impl ...).
		 * We don't want ((Fn _impl ...) (Fn _impl ...) ...),
		 * so call the implementation function (_impl) directly and pass the
		 * closure as the first parameter:
		 * (_impl (Fn _impl ...) ...)
		 */
		ACons*       closure = (*copy->begin())->as<ACons*>();
		ASymbol*     implSym = (*(closure->begin() + 1))->as<ASymbol*>();
		const AType* implT   = cenv.type(cenv.resolve(implSym));
		copy->push_front(implSym);
		cenv.setType(copy, (*(implT->begin() + 2))->as<const AType*>());
	} else {
		// Call to a closure, prepend code to access implementation function
		ADot* getFn = tup<ADot>(loc, cenv.penv.sym("."),
				copy->head(),
				new ALiteral<int32_t>(0, Cursor()), NULL);
		const AType* calleeT = cenv.type(copy->head());
		assert(**calleeT->begin() == *cenv.tenv.Tup);
		const AType* implT = (*(calleeT->begin() + 1))->as<const AType*>();
		copy->push_front(getFn);
		cenv.setType(getFn, implT);
		cenv.setType(copy, (*(implT->begin() + 2))->as<const AType*>());
	}

	return copy;
}

/** Lift a definition.
 *
 * The name is bound before the body is lifted so the body can refer to it.
 * A function body is also tagged with the definition's name, so recursive
 * references inside it can be recognised by name while it is lifted.
 */
AST*
ADef::lift(CEnv& cenv, Code& code) throw()
{
	// Provisional binding so that recursive references resolve
	cenv.def(sym(), body(), cenv.type(body()), NULL);

	if (AFn* fnBody = body()->to<AFn*>())
		fnBody->name = sym()->str();

	ATuple* const liftedTuple = ATuple::lift(cenv, code)->as<ATuple*>();
	ADef* const   lifted      = new ADef(liftedTuple);

	// Definition created by AFn::lift when body was lifted
	if (lifted->sym() == lifted->body())
		return NULL;

	// Rebind the name to the lifted body and its type
	cenv.def(lifted->sym(), lifted->body(), cenv.type(lifted->body()), NULL);
	cenv.setTypeSameAs(lifted, this);
	return lifted;
}

/** Shared lifting for forms whose first element is kept verbatim.
 *
 * Copies @p call, replaces every element after the first with its lifted
 * form, and gives the copy the same type as @p call.
 */
template<typename T>
AST*
lift_builtin_call(CEnv& cenv, T* call, Code& code) throw()
{
	ATuple* const lifted = new T(call);

	// Skip the head in both tuples, then lift the operands pairwise
	typename T::const_iterator src = call->begin();
	ATuple::iterator           dst = lifted->begin();
	for (++src, ++dst; src != call->end(); ++src, ++dst)
		*dst = (*src)->lift(cenv, code);

	cenv.setTypeSameAs(lifted, call);
	return lifted;
}

/* These node types keep their first element unchanged and lift only the
 * remaining elements, via lift_builtin_call. */

AST*
AIf::lift(CEnv& cenv, Code& code) throw()
{
	return lift_builtin_call(cenv, this, code);
}

AST*
ACons::lift(CEnv& cenv, Code& code) throw()
{
	return lift_builtin_call(cenv, this, code);
}

AST*
ADot::lift(CEnv& cenv, Code& code) throw()
{
	return lift_builtin_call(cenv, this, code);
}

AST*
APrimitive::lift(CEnv& cenv, Code& code) throw()
{
	return lift_builtin_call(cenv, this, code);
}