/* Resp: A programming language
 * Copyright (C) 2008-2009 David Robillard <http://drobilla.net>
 *
 * Resp is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Resp is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Resp.  If not, see <http://www.gnu.org/licenses/>.
 */

/** @file
 * @brief Generate type constraints for AST expressions
 */

#include <set>
#include <string>

#include "resp.hpp"

/** Constrain a symbol reference to the type it is bound to in @p tenv.
 * @throw Error if the symbol has no binding.
 */
static void
constrain_symbol(TEnv& tenv, Constraints& c, const ASymbol* sym) throw(Error)
{
	const AST** const binding = tenv.ref(sym);
	THROW_IF(binding == NULL, sym->loc,
	         (format("undefined symbol `%1%'") % sym->sym()).str());

	const AST* const boundT = *binding;
	c.constrain(tenv, sym, boundT);
}

/** Constrain a constructor application such as (Tup a b) or (Ctor x ...).
 *
 * All argument expressions are constrained first.  For the built-in `Tup'
 * constructor the result type is (Tup argT...).  For any other constructor,
 * the result type is a one-element tuple holding the first element of the
 * constructor's registered type (see constrain_def_type).
 * @throw Error if a non-Tup constructor is not bound in @p tenv.
 */
static void
constrain_cons(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error)
{
	const ASymbol* const ctor = (*call->begin())->as_symbol();

	for (ATuple::const_iterator arg = call->iter_at(1); arg != call->end(); ++arg)
		resp_constrain(tenv, c, *arg);

	if (strcmp(ctor->sym(), "Tup") == 0) {
		// Anonymous tuple: its type lists the type of every element
		List tupT(new ATuple(tenv.Tup, NULL, call->loc));
		for (ATuple::const_iterator arg = call->iter_at(1); arg != call->end(); ++arg)
			tupT.push_back(tenv.var(*arg));
		c.constrain(tenv, call, tupT);
		return;
	}

	const AST** const ctorTRef = tenv.ref(ctor);
	THROW_IF(!ctorTRef, call->loc,
	         (format("call to undefined constructor `%1%'") % ctor->sym()).str());
	const AST* const typeName = (*ctorTRef)->as_tuple()->fst();
	c.constrain(tenv, call, new ATuple(typeName, 0, call->loc));
}

/** Constrain a tuple element access: (. obj N).
 *
 * The result has a fresh type variable retT.  @p obj is constrained to a
 * tuple (Tup _ ... retT ...) whose element at index N is retT, with any
 * number of trailing elements (tenv.Dots).
 * @throw Error on wrong arity, a non-literal index, or a negative index.
 */
static void
constrain_dot(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error)
{
	THROW_IF(call->list_len() != 3, call->loc, "`.' requires exactly 2 arguments");
	const AST* const obj = call->list_ref(1);
	const AST* const idx = call->list_ref(2);
	THROW_IF(idx->tag() != T_INT32, call->loc, "the 2nd argument to `.' must be a literal integer");

	// Named cast: only reads the literal, so constness is preserved.
	const int32_t index = static_cast<const ALiteral<int32_t>*>(idx)->val;
	// A negative index would otherwise silently be treated as index 0.
	THROW_IF(index < 0, idx->loc, "the 2nd argument to `.' must be non-negative");

	resp_constrain(tenv, c, obj);

	const AST* const retT = tenv.var(call);
	c.constrain(tenv, call, retT);

	// obj : (Tup t0 ... t{index-1} retT ...)
	List objT(new ATuple(tenv.Tup, NULL, call->loc));
	for (int32_t n = 0; n < index; ++n)
		objT.push_back(tenv.var());
	objT.push_back(retT);
	objT.push_back(tenv.Dots);
	c.constrain(tenv, obj, objT);
}

/** Constrain a definition: (def name body).
 *
 * The name is bound to the type variable of @p body before the body is
 * constrained, so references to the name inside the body resolve (e.g.
 * recursion).  The `def' form itself has type Nothing.
 * @throw Error on wrong arity or if the name is not a symbol.
 */
static void
constrain_def(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error)
{
	THROW_IF(call->list_len() != 3, call->loc, "`def' requires exactly 2 arguments");
	// A single checked conversion; the former second check after as_symbol()
	// could never fire.
	const ASymbol* const sym = call->frst()->to_symbol();
	THROW_IF(!sym, call->frst()->loc, "`def' name is not a symbol");
	const AST* const body = call->list_ref(2);

	const AST* const tvar = tenv.var(body);
	tenv.def(sym, tvar);
	resp_constrain(tenv, c, body);
	c.constrain(tenv, sym, tvar);
	c.constrain(tenv, call, tenv.named("Nothing"));
}

/** Process a type definition: (def-type (Name ...) (Tag field...) ...).
 *
 * Binds each constructor tag to the tuple (Name Tag field...), and binds the
 * type name to (U ctor...) listing all of those constructor tuples.  No
 * constraints are added to @p c.
 * @throw Error on malformed input or if the type name is already bound.
 */
static void
constrain_def_type(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error)
{
	THROW_IF(call->list_len() < 3, call->loc, "`def-type' requires at least 2 arguments");
	ATuple::const_iterator i = call->iter_at(1);
	const ATuple* prot = (*i)->to_tuple();
	THROW_IF(!prot, (*i)->loc, "first argument of `def-type' is not a tuple");
	// to_symbol (not as_symbol) so the user gets an error rather than a crash
	const ASymbol* sym = (*prot->begin())->to_symbol();
	THROW_IF(!sym, (*prot->begin())->loc, "type name is not a symbol");
	THROW_IF(tenv.ref(sym), call->loc, "type redefinition");

	List type(new ATuple(tenv.U, NULL, call->loc));
	for (ATuple::const_iterator ci = call->iter_at(2); ci != call->end(); ++ci) {
		const ATuple* const exp = (*ci)->to_tuple();
		THROW_IF(!exp, (*ci)->loc, "constructor expression is not a tuple");
		const ASymbol* const tag = (*exp->begin())->to_symbol();
		THROW_IF(!tag, (*exp->begin())->loc, "constructor name is not a symbol");

		// Constructor type: (TypeName Tag field...)
		List consT;
		consT.push_back(sym);
		for (ATuple::const_iterator fi = exp->begin(); fi != exp->end(); ++fi)
			consT.push_back(*fi); // FIXME: ensure symbol, or list of symbol
		consT.head->loc = exp->loc;
		type.push_back(consT);
		tenv.def(tag, consT);
	}
	tenv.def(sym, type);
}

/** Constrain a function expression: (fn (param...) body...).
 *
 * Each parameter gets a fresh type variable, and every top-level `def' in the
 * body is pre-declared in the same frame.  The function's type is
 * (Fn (paramT...) bodyT), where bodyT is the type of the last body expression.
 * @throw Error on a missing body, a non-symbol parameter, duplicate
 *        parameter/definition names, or a malformed internal `def'.
 */
static void
constrain_fn(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error)
{
	// Without a body, the final expression below would be NULL.
	THROW_IF(call->list_len() < 3, call->loc, "`fn' requires at least 2 arguments");

	set<const ASymbol*> defs;
	TEnv::Frame frame;

	// Add parameters to environment frame
	List protT;
	FOREACHP(ATuple::const_iterator, i, call->prot()) {
		const ASymbol* sym = (*i)->to_symbol();
		THROW_IF(!sym, (*i)->loc, "parameter name is not a symbol");
		THROW_IF(defs.count(sym) != 0, sym->loc,
				(format("duplicate parameter `%1%'") % sym->str()).str());
		defs.insert(sym);
		const AST* tvar = tenv.fresh(sym);
		frame.push_back(make_pair(sym->sym(), tvar));
		protT.push_back(tvar);
	}
	protT.head->loc = call->loc;

	ATuple::const_iterator i = call->iter_at(1);
	c.constrain(tenv, *i, protT);

	// Add internal definitions to environment frame (types are filled in
	// later, when constrain_def processes each one).
	for (++i; i != call->end(); ++i) {
		const ATuple* const form = (*i)->to_tuple();
		if (form && is_form(form, "def")) {
			THROW_IF(form->list_len() != 3, form->loc, "`def' requires exactly 2 arguments");
			const ASymbol* const sym = form->frst()->to_symbol();
			THROW_IF(!sym, form->frst()->loc, "`def' name is not a symbol");
			THROW_IF(defs.count(sym) != 0, form->loc,
					(format("`%1%' defined twice") % sym->str()).str());
			defs.insert(sym);
			frame.push_back(make_pair(sym->sym(), static_cast<const AST*>(NULL)));
		}
	}

	tenv.push(frame);

	const AST* exp = NULL;
	for (i = call->iter_at(2); i != call->end(); ++i) {
		exp = *i;
		resp_constrain(tenv, c, exp);
	}

	const AST*    bodyT = tenv.var(exp);
	const ATuple* fnT   = tup(call->loc, tenv.Fn, protT.head, bodyT, 0);
	Object::pool.addRoot(fnT);

	tenv.pop();

	c.constrain(tenv, call, fnT);
}

/** Constrain a conditional: (if c1 e1 c2 e2 ... else).
 *
 * Every condition must be Bool.  Every branch, including the trailing else,
 * shares the type of the whole `if' expression.
 * @throw Error on too few arguments or a missing final else clause.
 */
static void
constrain_if(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error)
{
	const size_t len = call->list_len();
	THROW_IF(len < 4, call->loc, "`if' requires at least 3 arguments");
	THROW_IF(len % 2 != 0, call->loc, "`if' missing final else clause");

	for (ATuple::const_iterator e = call->iter_at(1); e != call->end(); ++e)
		resp_constrain(tenv, c, *e);

	const AST* const retT = tenv.var(call);
	ATuple::const_iterator cond = call->iter_at(1);
	for (;;) {
		ATuple::const_iterator branch = cond;
		++branch;
		if (branch == call->end()) {
			// Only the trailing else expression remains
			c.constrain(tenv, *cond, retT);
			break;
		}
		c.constrain(tenv, *cond, tenv.named("Bool"));
		c.constrain(tenv, *branch, retT);
		// Skip over the branch to the next condition
		cond = branch;
		++cond;
	}
}

/** Constrain a local binding form: (let (name val ...) body...).
 *
 * Values are constrained in the enclosing environment.  The new names only
 * become visible to the body once their frame is pushed.  The `let'
 * expression has the type of its last body expression.
 * @throw Error on too few arguments, a non-list binding section, a
 *        non-symbol name, or a name without a value.
 */
static void
constrain_let(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error)
{
	THROW_IF(call->list_len() < 3, call->loc, "`let' requires at least 2 arguments");
	const AST* const    bindings = call->list_ref(1);
	const ATuple* const vars     = bindings->to_tuple();
	THROW_IF(!vars, bindings->loc, "first argument of `let' is not a list");

	TEnv::Frame frame;
	ATuple::const_iterator b = vars->begin();
	while (b != vars->end()) {
		const ASymbol* const name = (*b)->to_symbol();
		THROW_IF(!name, (*b)->loc, "`let' binding name is not a symbol");
		++b;
		THROW_IF(b == vars->end(), name->loc, "`let' variable missing value");
		const AST* const value = *b;
		++b;

		resp_constrain(tenv, c, value);
		const AST* const valueT = tenv.var(value);
		frame.push_back(make_pair(name->sym(), valueT));
		c.constrain(tenv, name, valueT);
	}

	tenv.push(frame);

	for (ATuple::const_iterator e = call->iter_at(2); e != call->end(); ++e)
		resp_constrain(tenv, c, *e);
	c.constrain(tenv, call, tenv.var(call->list_last()));

	tenv.pop();
}

/** Constrain a pattern match: (match matchee (Ctor var...) body ...).
 *
 * The matchee is constrained to a one-element tuple holding the first element
 * of the first pattern's constructor type.  Every body is constrained to one
 * shared fresh result type, which is also the type of the `match' expression.
 * While a body is constrained, its pattern variables are bound to the
 * corresponding constructor field types.
 * @throw Error on malformed patterns, unknown or non-constructor names,
 *        a missing body, or more pattern variables than constructor fields.
 */
static void
constrain_match(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error)
{
	THROW_IF(call->list_len() < 5, call->loc, "`match' requires at least 4 arguments");
	const AST* matchee  = call->list_ref(1);
	const AST* retT     = tenv.var();
	const AST* matcheeT = NULL;
	resp_constrain(tenv, c, matchee);
	for (ATuple::const_iterator i = call->iter_at(2); i != call->end();) {
		const AST*    exp     = *i++;
		const ATuple* pattern = exp->to_tuple();
		THROW_IF(!pattern, exp->loc, "pattern expression expected");
		const ASymbol* name = (*pattern->begin())->to_symbol();
		THROW_IF(!name, (*pattern->begin())->loc, "pattern does not start with a symbol");

		// Single lookup, and reject names not bound to a constructor tuple
		// (previously as_tuple() was applied unchecked).
		const AST** consTRef = tenv.ref(name);
		THROW_IF(!consTRef, name->loc,
		         (format("undefined constructor `%1%'") % name->sym()).str());
		const ATuple* consT = *consTRef ? (*consTRef)->to_tuple() : NULL;
		THROW_IF(!consT, name->loc,
		         (format("`%1%' is not a constructor") % name->sym()).str());

		if (!matcheeT)
			matcheeT = new ATuple(consT->fst(), 0, call->loc);

		THROW_IF(i == call->end(), pattern->loc, "missing pattern body");
		const AST* body = *i++;

		// Constructor types built by def-type are (TypeName Tag field...),
		// so fields start at index 2.  Never step past the last field.
		TEnv::Frame frame;
		ATuple::const_iterator ti = consT->iter_at(2);
		for (ATuple::const_iterator pi = pattern->iter_at(1); pi != pattern->end(); ++pi) {
			const ASymbol* var = (*pi)->to_symbol();
			THROW_IF(!var, (*pi)->loc, "pattern variable is not a symbol");
			THROW_IF(ti == consT->end(), (*pi)->loc,
			         (format("too many pattern variables for constructor `%1%'")
			          % name->sym()).str());
			frame.push_back(make_pair(var->sym(), *ti++));
		}

		tenv.push(frame);
		resp_constrain(tenv, c, body);
		c.constrain(tenv, body, retT);
		tenv.pop();
	}
	c.constrain(tenv, call, retT);
	c.constrain(tenv, matchee, matcheeT);
}

/** Constrain a quoted datum.
 *
 * Quoted symbols have type Symbol.  Quoted lists have type List, and their
 * elements are themselves treated as quoted.  Anything else is constrained
 * as an ordinary expression.
 */
static void
resp_constrain_quoted(TEnv& tenv, Constraints& c, const AST* ast) throw(Error)
{
	if (ast->tag() == T_SYMBOL) {
		c.constrain(tenv, ast, tenv.named("Symbol"));
	} else if (ast->tag() == T_TUPLE) {
		c.constrain(tenv, ast, tenv.named("List"));
		const ATuple* const elems = ast->as_tuple();
		for (ATuple::const_iterator e = elems->begin(); e != elems->end(); ++e)
			resp_constrain_quoted(tenv, c, *e);
	} else {
		resp_constrain(tenv, c, ast);
	}
}

/** Constrain (quote datum): the expression has the same type as its datum.
 * @throw Error unless exactly one argument is given.
 */
static void
constrain_quote(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error)
{
	THROW_IF(call->list_len() != 2, call->loc, "`quote' requires exactly 1 argument");
	const AST* const datum = call->frst();
	resp_constrain_quoted(tenv, c, datum);
	c.constrain(tenv, call, tenv.var(datum));
}

/** Constrain a function application: (f arg...).
 *
 * The callee and all arguments are constrained first.  If the callee's type
 * is already concrete, it must be an Fn type of matching arity.  The callee is
 * then constrained to (Fn (argT...) retT), and the call to retT.
 * @throw Error on calling a non-function or an arity mismatch.
 */
static void
constrain_call(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error)
{
	const AST* const callee = call->fst();
	const size_t     nArgs  = call->list_len() - 1;

	for (ATuple::const_iterator e = call->begin(); e != call->end(); ++e)
		resp_constrain(tenv, c, *e);

	const AST* const calleeT = tenv.var(callee);
	if (!AType::is_var(calleeT)) {
		if (!is_form(calleeT, "Fn"))
			throw Error(call->loc, (format("call to non-function `%1%'") % callee->str()).str());

		const size_t nParams = calleeT->as_tuple()->prot()->list_len();
		THROW_IF(nParams != nArgs, call->loc,
				(format("expected %1% arguments, got %2%") % nParams % nArgs).str());
	}

	const AST* const retT = tenv.var(call);
	List argsT;
	for (ATuple::const_iterator a = call->iter_at(1); a != call->end(); ++a)
		argsT.push_back(tenv.var(*a));
	argsT.head->loc = call->loc;

	c.constrain(tenv, callee, tup(callee->loc, tenv.Fn, argsT.head, retT, 0));
	c.constrain(tenv, call, retT);
}

/** Constrain a primitive operator application.
 *
 * - Arithmetic (+ - * /, 2 or more operands): every operand has the type of
 *   the result.
 * - Binary (%, exactly 2 operands): both operands have the type of the result.
 * - Logical (and or xor, exactly 2 operands): operands and result are Bool.
 * - Comparison (= != > >= < <=, exactly 2 operands): both operands share one
 *   type, and the result is Bool.
 * @throw Error on an unknown operator or the wrong number of operands.
 */
static void
constrain_primitive(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error)
{
	const string op = call->fst()->to_symbol()->str();

	enum { ARITHMETIC, BINARY, LOGICAL, COMPARISON } kind;
	if (op == "+" || op == "-" || op == "*" || op == "/") {
		kind = ARITHMETIC;
	} else if (op == "%") {
		kind = BINARY;
	} else if (op == "and" || op == "or" || op == "xor") {
		kind = LOGICAL;
	} else if (op == "=" || op == "!=" || op == ">" || op == ">=" || op == "<" || op == "<=") {
		kind = COMPARISON;
	} else {
		throw Error(call->loc, (format("unknown primitive `%1%'") % op).str());
	}

	for (ATuple::const_iterator arg = call->iter_at(1); arg != call->end(); ++arg)
		resp_constrain(tenv, c, *arg);

	// Arity checks happen after the operands have been constrained
	const size_t len = call->list_len();
	if (kind == ARITHMETIC) {
		if (len < 3)
			throw Error(call->loc, (format("`%1%' requires at least 2 arguments") % op).str());
	} else if (len != 3) {
		throw Error(call->loc, (format("`%1%' requires exactly 2 arguments") % op).str());
	}

	switch (kind) {
	case ARITHMETIC:
		for (ATuple::const_iterator arg = call->iter_at(1); arg != call->end(); ++arg)
			c.constrain(tenv, *arg, tenv.var(call));
		break;
	case BINARY:
		c.constrain(tenv, call->list_ref(1), tenv.var(call));
		c.constrain(tenv, call->list_ref(2), tenv.var(call));
		break;
	case LOGICAL:
		c.constrain(tenv, call, tenv.named("Bool"));
		c.constrain(tenv, call->list_ref(1), tenv.named("Bool"));
		c.constrain(tenv, call->list_ref(2), tenv.named("Bool"));
		break;
	case COMPARISON: {
		const AST* const operandT = tenv.var(call->list_ref(1));
		c.constrain(tenv, call, tenv.named("Bool"));
		c.constrain(tenv, call->list_ref(2), operandT);
		break;
	}
	}
}

/** Dispatch constraint generation for a list expression by its head.
 *
 * A non-symbol head, or an unrecognised symbol, is an ordinary call.
 * Primitives are checked first.  `cons' and capitalised heads are
 * constructors.  The remaining special forms each have a dedicated handler.
 */
static void
constrain_list(TEnv& tenv, Constraints& c, const ATuple* tup) throw(Error)
{
	const ASymbol* const sym = tup->fst()->to_symbol();
	if (!sym) {
		constrain_call(tenv, c, tup);
		return;
	}

	const std::string form = sym->sym();
	// isupper() requires a value representable as unsigned char (or EOF);
	// passing a negative char, e.g. a UTF-8 byte, is undefined behaviour.
	const bool capitalised = isupper(static_cast<unsigned char>(form[0])) != 0;

	if (is_primitive(tenv.penv, tup))
		constrain_primitive(tenv, c, tup);
	else if (form == "cons" || capitalised)
		constrain_cons(tenv, c, tup);
	else if (form == ".")
		constrain_dot(tenv, c, tup);
	else if (form == "def")
		constrain_def(tenv, c, tup);
	else if (form == "def-type")
		constrain_def_type(tenv, c, tup);
	else if (form == "fn")
		constrain_fn(tenv, c, tup);
	else if (form == "if")
		constrain_if(tenv, c, tup);
	else if (form == "let")
		constrain_let(tenv, c, tup);
	else if (form == "match")
		constrain_match(tenv, c, tup);
	else if (form == "quote")
		constrain_quote(tenv, c, tup);
	else
		constrain_call(tenv, c, tup);
}

/** Generate type constraints for @p ast into @p c.
 *
 * Literals are constrained to their named built-in type.  Symbols are
 * constrained to their binding, and lists are dispatched by form.  Unknown
 * nodes and type variables produce no constraints.
 */
void
resp_constrain(TEnv& tenv, Constraints& c, const AST* ast) throw(Error)
{
	const char* literalT = NULL;
	switch (ast->tag()) {
	case T_BOOL:   literalT = "Bool";   break;
	case T_FLOAT:  literalT = "Float";  break;
	case T_INT32:  literalT = "Int";    break;
	case T_STRING: literalT = "String"; break;
	case T_SYMBOL:
	case T_LITSYM:
		constrain_symbol(tenv, c, ast->as_symbol());
		return;
	case T_TUPLE:
		constrain_list(tenv, c, ast->as_tuple());
		return;
	case T_UNKNOWN:
	case T_TVAR:
		return;
	}

	if (literalT)
		c.constrain(tenv, ast, tenv.named(literalT));
}