/* Tuplr Type Inferencing
 * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net>
 *
 * Tuplr is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Tuplr is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Tuplr.  If not, see <http://www.gnu.org/licenses/>.
 */

/** @file
 * @brief Convert AST to Continuation Passing Style
 */

#include <set>
#include "tuplr.hpp"

/** Convert an atomic expression to CPS by handing it straight to the
 * continuation: (cps x cont) => (cont x).
 *
 * @param tenv  Unused here.
 * @param cont  Continuation to apply to this expression.
 * @return A new call node (cont this).
 */
AST*
AST::cps(TEnv& tenv, AST* cont)
{
	ACall* const applied = tup<ACall>(loc, cont, this, 0);
	return applied;
}

/** Convert a function literal to CPS.
 *
 * (cps (fn (a ...) body) cont) => (cont (fn (a ... k) (cps body k)))
 *
 * The prototype is copied and extended with a fresh continuation
 * parameter k; every body expression is converted with k as its
 * continuation. The resulting function is passed to @a cont.
 */
AST*
AFn::cps(TEnv& tenv, AST* cont)
{
	// Copy the prototype and append the new continuation parameter
	ATuple*  newProt = new ATuple(prot()->loc, *prot());
	ASymbol* kSym    = tenv.penv.gensym("_k");
	newProt->push_back(kSym);

	AFn* result = tup<AFn>(loc, tenv.penv.sym("fn"), newProt, 0);

	// Skip the "fn" keyword and the prototype; what remains is the body
	const_iterator expr = begin();
	++expr;
	++expr;
	while (expr != end()) {
		result->push_back((*expr)->cps(tenv, kSym));
		++expr;
	}

	return tup<ACall>(loc, cont, result, 0);
}

/** Convert a primitive application to CPS.
 * If value() is true the node is passed directly to @a cont, as for an
 * atom; otherwise it is converted like an ordinary call.
 */
AST*
APrimitive::cps(TEnv& tenv, AST* cont)
{
	if (!value())
		return ACall::cps(tenv, cont);

	return tup<ACall>(loc, cont, this, 0);
}

/** Convert a call to CPS.
 *
 * (cps (f a b ...)) => (a (fn (x) (b (fn (y) ... (cont (f x y ...)) */
/*
 * Elements that are not tuples are treated as atomic and used as-is.
 * Each tuple (compound) element gets its own continuation (fn (aN) ...)
 * that binds a fresh symbol. These continuations are chained: each one
 * evaluates the next compound element, and the last one performs the call
 * itself (compound elements replaced by their symbols), converted with
 * @a cont.
 *
 * @param tenv  Environment whose penv supplies fresh (gensym) and
 *              named (sym) symbols.
 * @param cont  Continuation that the call's result is passed to.
 */
AST*
ACall::cps(TEnv& tenv, AST* cont)
{
	// One entry per element: (its continuation fn, or NULL if the element
	// is atomic; the expression that stands for it in the final call)
	std::vector< std::pair<AFn*, AST*> > funcs;
	AFn*     fn  = NULL; // most recently created continuation
	ASymbol* arg = NULL;

	// Make a continuation for each element (operator and arguments)
	// Argument evaluation continuations are not themselves in CPS.
	// Each makes a tail call to the next, and the last makes a tail
	// call to the continuation of this call
	ssize_t firstFn = -1; // index of the first compound element, if any
	ssize_t index   = 0;
	FOREACH(iterator, i, *this) {
		if (!(*i)->to<ATuple*>()) {
			// Atomic element: used directly, no continuation needed
			funcs.push_back(make_pair((AFn*)NULL, (*i)));
		} else {
			// Compound element: its result is bound to a fresh symbol
			arg = tenv.penv.gensym("a");

			if (firstFn == -1)
				firstFn = index;

			AFn* thisFn = tup<AFn>(loc, tenv.penv.sym("fn"),
					tup<ATuple>((*i)->loc, arg, 0),
					0);

			// Evaluate this element inside the previous continuation and
			// pass its result to thisFn. The first compound element has no
			// predecessor; it is converted by the return statement below.
			if (fn)
				fn->push_back((*i)->cps(tenv, thisFn));

			funcs.push_back(make_pair(thisFn, arg));
			fn = thisFn;
		}
		++index;
	}

	if (firstFn != -1) {
		// Call this call's callee in the last argument evaluator
		// `call` holds only the original atoms and the gensym'd symbols,
		// presumably so the recursive cps below takes the all-atomic path
		// (assumes ASymbol is not an ATuple; confirm in tuplr.hpp).
		// NOTE(review): `call` is always a plain ACall, even when this node
		// is an APrimitive, so the recursive call appends `cont` to a
		// primitive application. Verify this is intended.
		// NOTE(review): if the operator was compound, at(0) of `call` is a
		// gensym'd symbol, and the recursive path asserts at(0)->value().
		ACall* call = tup<ACall>(loc, 0);
		assert(funcs.size() == size());
		for (size_t i = 0; i < funcs.size(); ++i)
			call->push_back(funcs[i].second);

		assert(fn);
		fn->push_back(call->cps(tenv, cont));
		// Start the chain by evaluating the first compound element
		return at(firstFn)->cps(tenv, funcs[firstFn].first);
	} else {
		// All elements are atomic: emit the call directly. The continuation
		// is appended as an extra argument unless this node is a primitive.
		assert(at(0)->value());
		ACall* ret = tup<ACall>(loc, 0);
		FOREACH(iterator, i, *this)
			ret->push_back((*i));
		if (!to<APrimitive*>())
			ret->push_back(cont);
		return ret;
	}
}

/** Convert a definition to CPS.
 *
 * (cps (def x y) cont) => (def x v)
 *
 * y is converted with @a cont. The result is asserted to be a call, and
 * its element 1 (the value handed to the continuation, e.g. v in (cont v))
 * becomes the new definition's value.
 */
AST*
ADef::cps(TEnv& tenv, AST* cont)
{
	AST*   converted = at(2)->cps(tenv, cont);
	ACall* asCall    = converted->to<ACall*>();
	assert(asCall);

	AST* boundValue = asCall->at(1);
	return tup<ADef>(loc, tenv.penv.sym("def"), sym(), boundValue, 0);
}

/** Convert a conditional to CPS.
 *
 * If the condition c is a value:
 *   (cps (if c t e) cont) => (if c (cps t cont) (cps e cont))
 * Otherwise c is evaluated first, with a continuation that branches:
 *   (cps (if c t e) cont) => (cps c (fn (c' k') (if c' (cps t cont) (cps e cont))))
 *
 * Only elements 1-3 (condition, then-branch, else-branch) are used.
 * NOTE(review): any further clauses are ignored.
 */
AST*
AIf::cps(TEnv& tenv, AST* cont)
{
	ASymbol* argSym = tenv.penv.gensym("c");
	if (at(1)->value()) {
		return tup<AIf>(loc, tenv.penv.sym("if"), at(1),
			at(2)->cps(tenv, cont),
			at(3)->cps(tenv, cont), 0);
	} else {
		// Every tup<> call ends its argument list with a 0 terminator.
		// This one previously lacked it, so tup would read past its last
		// argument.
		AFn* contFn = tup<AFn>(loc, tenv.penv.sym("fn"),
				tup<ATuple>(at(1)->loc, argSym, tenv.penv.gensym("_k"), 0),
				tup<AIf>(loc, tenv.penv.sym("if"), argSym,
					at(2)->cps(tenv, cont),
					at(3)->cps(tenv, cont), 0),
				0);
		return at(1)->cps(tenv, contFn);
	}
}