/* Resp: A programming language
 * Copyright (C) 2008-2009 David Robillard <http://drobilla.net>
 *
 * Resp is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Resp is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Resp.  If not, see <http://www.gnu.org/licenses/>.
 */

/** @file
 * @brief Convert AST to Continuation Passing Style
 */

#include <set>
#include <utility>
#include <vector>

#include "resp.hpp"

/** Return true iff @p exp can be passed to a continuation as-is.
 *
 * Non-tuple ASTs (atoms) are always values.  A tuple is a value only if it
 * is a primitive application whose arguments are, recursively, all values;
 * any other tuple (a call to a non-primitive) must be CPS converted.
 */
static bool
is_value(CEnv& cenv, const AST* exp)
{
	const ATuple* const tuple = exp->to_tuple();
	if (!tuple)
		return true;  // Atom

	if (!is_primitive(cenv.penv, exp))
		return false;  // Call to a non-primitive

	// Primitive application: a value only if every argument is a value
	bool all_values = true;
	for (ATuple::const_iterator arg = tuple->iter_at(1);
	     all_values && arg != tuple->end(); ++arg)
		all_values = is_value(cenv, *arg);

	return all_values;
}

/** Pass an expression that needs no conversion to its continuation.
 *
 * [v]k => (k v)
 *
 * The resulting application carries the location of @p v.
 */
static const AST*
cps_value(CEnv& cenv, const AST* v, const AST* k)
{
	const AST* const application = tup(v->loc, k, v, 0);
	return application;
}

/** Convert a function literal.
 *
 * [(fn (a ...) r ...)]k => (fn (a ... __k) [r]__k ...)
 *
 * A freshly generated continuation parameter is appended to the prototype,
 * and every body expression after the prototype is converted against it.
 *
 * NOTE(review): @p cont is not used.  The converted fn is returned directly
 * instead of being applied to the continuation as (cont (fn ...)).  Callers
 * such as cps_def currently rely on getting the bare fn back, so confirm
 * before changing this.
 */
static const AST*
cps_fn(CEnv& cenv, const ATuple* fn, const AST* cont)
{
	const ASymbol* const contParam = cenv.penv.gensym("__k");

	// New prototype: original parameters followed by the continuation
	List params;
	FOREACHP(ATuple::const_iterator, p, fn->prot())
		params.push_back(*p);
	params.push_back(contParam);

	assert(fn->fst());
	assert(params.head);

	List result;
	result.push_back(cenv.penv.sym("fn"));
	result.push_back(params);

	// Body: each expression following the prototype, converted against __k
	ATuple::const_iterator expr = fn->iter_at(2);
	while (expr != fn->end()) {
		result.push_back(resp_cps(cenv, *expr, contParam));
		++expr;
	}

	return result;
}

/** Convert a function application.
 *
 * [(f a b ...)]k => [a](fn (__a) [b](fn (__b) ... (f __a __b ... k)))
 *
 * Only elements (including the head) that are not already values are
 * replaced by fresh `__a` symbols; value elements appear in the innermost
 * application unchanged.  If the head names a primitive, the innermost form
 * is (k (f ...)) instead of (f ... k), since primitives do not take a
 * continuation argument.
 */
static const AST*
cps_call(CEnv& cenv, const ATuple* call, const AST* k)
{
	typedef std::pair<const AST*, const AST*> Binding;  // (fresh sym, expr)
	typedef std::vector<Binding>              Bindings;

	// Innermost application, with each non-value replaced by a fresh symbol
	List     app;
	Bindings pending;
	FOREACHP(ATuple::const_iterator, i, call) {
		if (is_value(cenv, *i)) {
			app.push_back(*i);
		} else {
			const AST* const var = cenv.penv.gensym("__a");
			app.push_back(var);
			pending.push_back(Binding(var, *i));
		}
	}

	const AST* result;
	if (cenv.penv.primitives.find(call->fst()->str()) != cenv.penv.primitives.end()) {
		result = tup(Cursor(), k, app.head, 0);
	} else {
		app.push_back(k);
		result = app;
	}

	// Wrap inside-out (last non-value first), so that at run time the
	// non-value elements are evaluated left to right
	for (Bindings::const_reverse_iterator b = pending.rbegin();
	     b != pending.rend(); ++b) {
		const AST* const binder = tup(Cursor(), cenv.penv.sym("fn"),
		                              tup(Cursor(), b->first, 0),
		                              result,
		                              0);
		result = resp_cps(cenv, b->second, binder);
	}

	return result;
}

/** Convert a definition by converting only its value expression.
 *
 * [(def x y)]k => (def x [y]k)
 *
 * The defined name and the `def` head are kept as-is, and the result keeps
 * the location of the original form.
 */
static const AST*
cps_def(CEnv& cenv, const ATuple* def, const AST* k)
{
	const AST* const name = def->frst();
	List result(def->loc, def->fst(), name, 0);
	result.push_back(resp_cps(cenv, def->list_ref(2), k));
	return result;
}

/** Convert a conditional.
 *
 * [(if c t e)]k => [c](fn (__c) (if __c [t]k [e]k))
 *
 * If the condition @c c is already a value it is used directly:
 * [(if c t e)]k => (if c [t]k [e]k)
 *
 * Both branches are converted against the same continuation @p k, so
 * whichever arm runs passes its result on to k.
 */
static const AST*
cps_if(CEnv& cenv, const ATuple* aif, const AST* k)
{
	ATuple::const_iterator i = aif->begin();
	const AST* const c = *++i;
	const AST* const t = *++i;
	const AST* const e = *++i;
	if (is_value(cenv, c)) {
		return tup(aif->loc, cenv.penv.sym("if"), c,
		           resp_cps(cenv, t, k),
		           resp_cps(cenv, e, k), 0);
	}

	// Non-value condition: it must be evaluated first.  Bind its result to a
	// fresh symbol in a one-argument continuation (the same scheme cps_call
	// uses for non-value arguments), then branch on that symbol.  Returning
	// aif unconverted here would drop k and leave c, t and e in direct style.
	const AST* const condSym = cenv.penv.gensym("__c");
	const AST* const branch  = tup(aif->loc, cenv.penv.sym("if"), condSym,
	                               resp_cps(cenv, t, k),
	                               resp_cps(cenv, e, k), 0);
	const AST* const condK   = tup(c->loc, cenv.penv.sym("fn"),
	                               tup(c->loc, condSym, 0),
	                               branch,
	                               0);
	return resp_cps(cenv, c, condK);
}

/** Convert @p ast to continuation passing style with continuation @p k.
 *
 * Values (atoms and primitive applications over values) are simply passed
 * to @p k.  Tuples headed by the special forms `def`, `fn` and `if` get
 * dedicated handling.  Every other tuple is treated as a function call.
 */
const AST*
resp_cps(CEnv& cenv, const AST* ast, const AST* k) throw()
{
	const ATuple* const call = ast->to_tuple();
	if (!call || is_value(cenv, ast))
		return cps_value(cenv, ast, k);

	const ASymbol* const head = call->fst()->to_symbol();
	if (head) {
		const std::string form = head->sym();
		if (form == "def")
			return cps_def(cenv, call, k);
		if (form == "fn")
			return cps_fn(cenv, call, k);
		if (form == "if")
			return cps_if(cenv, call, k);
	}

	return cps_call(cenv, call, k);
}