/* Resp Type Inferencing
 * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net>
 *
 * Resp is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Resp is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Resp.  If not, see <http://www.gnu.org/licenses/>.
 */

/** @file
 * @brief Generate type constraints for AST expressions
 */

#include <set>
#include "resp.hpp"

#define CONSTRAIN_LITERAL(CT, NAME) \
template<> void \
ALiteral<CT>::constrain(TEnv& tenv, Constraints& c) const throw(Error) { \
	c.constrain(tenv, this, tenv.named(NAME)); \
}

// Literal template instantiations
CONSTRAIN_LITERAL(int32_t, "Int")
CONSTRAIN_LITERAL(float,   "Float")
CONSTRAIN_LITERAL(bool,    "Bool")

void
AString::constrain(TEnv& tenv, Constraints& c) const throw(Error)
{
	// A string literal always has the named type String
	const AType* stringT = tenv.named("String");
	c.constrain(tenv, this, stringT);
}

void
ALexeme::constrain(TEnv& tenv, Constraints& c) const throw(Error)
{
	// A lexeme always has the named type Lexeme
	const AType* lexemeT = tenv.named("Lexeme");
	c.constrain(tenv, this, lexemeT);
}

void
ASymbol::constrain(TEnv& tenv, Constraints& c) const throw(Error)
{
	// A symbol takes the type it is bound to in the type environment;
	// an unbound symbol is an error.
	const AType** binding = tenv.ref(this);
	THROW_IF(binding == NULL, loc,
			(format("undefined symbol `%1%'") % cppstr).str());
	c.constrain(tenv, this, *binding);
}

void
ATuple::constrain(TEnv& tenv, Constraints& c) const throw(Error)
{
	// Constrain each element in order, collecting each element's type
	// variable into a tuple type that this expression is constrained to.
	AType* elemTypes = tup<AType>(loc, NULL);
	for (const_iterator elem = this->begin(); elem != this->end(); ++elem) {
		(*elem)->constrain(tenv, c);
		elemTypes->push_back(const_cast<AType*>(tenv.var(*elem)));
	}
	c.constrain(tenv, this, elemTypes);
}

/** Constrain a function expression `(fn (params...) body...)`.
 *
 * Parameters get fresh type variables, and internal `def`s are pre-declared
 * (with no type yet) in a new environment frame. Every body form is then
 * constrained inside that frame. The function is constrained to the type
 * (Fn (paramTypes...) lastBodyType).
 *
 * Throws Error on non-symbol or duplicate parameters and on duplicate
 * internal definitions. The pushed frame is always popped, even when
 * constraining the body throws, so the environment is left as it was found.
 */
void
AFn::constrain(TEnv& tenv, Constraints& c) const throw(Error)
{
	set<const ASymbol*> defs;
	TEnv::Frame frame;

	// Add parameters to environment frame
	AType* protT = tup<AType>(loc, NULL);
	for (ATuple::const_iterator i = prot()->begin(); i != prot()->end(); ++i) {
		const ASymbol* sym = (*i)->to<const ASymbol*>();
		THROW_IF(!sym, (*i)->loc, "parameter name is not a symbol");
		THROW_IF(defs.count(sym) != 0, sym->loc,
				(format("duplicate parameter `%1%'") % sym->str()).str());
		defs.insert(sym);
		const AType* tvar = tenv.fresh(sym);
		frame.push_back(make_pair(sym, tvar));
		protT->push_back(const_cast<AType*>(tvar));
	}

	// The parameter list expression itself has the tuple-of-params type
	const_iterator i = begin() + 1;
	c.constrain(tenv, *i, protT);

	// Pre-declare internal definitions (type unknown yet) so body forms can
	// refer to them
	for (++i; i != end(); ++i) {
		const AST*  form = *i;
		const ADef* def  = form->to<const ADef*>();
		if (def) {
			const ASymbol* sym = def->sym();
			THROW_IF(defs.count(sym) != 0, def->loc,
					(format("`%1%' defined twice") % sym->str()).str());
			defs.insert(sym);
			frame.push_back(make_pair(sym, static_cast<const AType*>(NULL)));
		}
	}

	tenv.push(frame);

	const AType* fnT = NULL;
	try {
		// The function's return type is the type of its last body form.
		// NOTE(review): with no body forms, exp stays NULL and tenv.var(NULL)
		// is used; confirm that is intended.
		AST* exp = NULL;
		for (i = begin() + 2; i != end(); ++i)
			(exp = *i)->constrain(tenv, c);

		const AType* bodyT = tenv.var(exp);
		fnT = tup<const AType>(loc, tenv.Fn, protT, bodyT, 0);
		Object::pool.addRoot(fnT);
	} catch (...) {
		// Never leak the pushed frame into the caller's environment
		tenv.pop();
		throw;
	}

	tenv.pop();

	c.constrain(tenv, this, fnT);
}

/** Constrain a call `(f args...)`.
 *
 * The callee and arguments are constrained first. If the callee's type is
 * already concrete (not a type variable), it must be an Fn type with the
 * same number of parameters as there are arguments. The callee is then
 * constrained to (Fn (argTypes...) R), and this call to R.
 */
void
ACall::constrain(TEnv& tenv, Constraints& c) const throw(Error)
{
	for (const_iterator e = begin(); e != end(); ++e)
		(*e)->constrain(tenv, c);

	const AType* calleeT = tenv.var(head());
	if (calleeT->kind != AType::VAR) {
		const bool isFn = calleeT->kind != AType::PRIM
			&& calleeT->size() >= 2
			&& calleeT->head()->str() == "Fn";
		if (!isFn)
			throw Error(loc, (format("call to non-function `%1%'") % head()->str()).str());

		const size_t nParams = calleeT->prot()->size();
		THROW_IF(nParams != size() - 1, loc,
				(format("expected %1% arguments, got %2%") % nParams % (size() - 1)).str());
	}

	// Result type of the call, then the tuple of argument types
	const AType* resultT = tenv.var(this);
	AType*       paramsT = tup<AType>(loc, 0);
	const_iterator arg = begin();
	while (++arg != end())
		paramsT->push_back(const_cast<AType*>(tenv.var(*arg)));

	AType* expectedFnT = tup<AType>(head()->loc, tenv.Fn, paramsT, resultT, 0);
	c.constrain(tenv, head(), expectedFnT);
	c.constrain(tenv, this, resultT);
}

/** Constrain a definition `(def sym body)`.
 *
 * The symbol is bound to the body's type variable before the body is
 * constrained (presumably so the body may refer to the symbol itself).
 * The def expression itself has type Nothing.
 */
void
ADef::constrain(TEnv& tenv, Constraints& c) const throw(Error)
{
	THROW_IF(size() != 3, loc, "`def' requires exactly 2 arguments");

	const ASymbol* const defSym = sym();
	THROW_IF(!defSym, loc, "`def' has no symbol");

	const AType* bodyT = tenv.var(body());
	tenv.def(defSym, bodyT);

	body()->constrain(tenv, c);
	c.constrain(tenv, defSym, bodyT);

	c.constrain(tenv, this, tenv.named("Nothing"));
}

/** Constrain a conditional `(if p1 e1 p2 e2 ... else)`.
 *
 * Every predicate must be Bool. Every consequent, and the final else
 * expression, must have the type of the whole if expression.
 */
void
AIf::constrain(TEnv& tenv, Constraints& c) const throw(Error)
{
	// At least one predicate/consequent pair plus an else: the head plus an
	// odd number of arguments, so the size is even and at least 4
	THROW_IF(size() < 4, loc, "`if' requires at least 3 arguments");
	THROW_IF(size() % 2 != 0, loc, "`if' missing final else clause");

	const_iterator cur = begin();
	while (++cur != end())
		(*cur)->constrain(tenv, c);

	const AType* resultT = tenv.var(this);

	cur = begin();
	++cur;
	for (;;) {
		const_iterator branch = cur;
		++branch;
		if (branch == end()) {
			c.constrain(tenv, *cur, resultT);          // final else
			break;
		}
		c.constrain(tenv, *cur, tenv.named("Bool"));   // predicate
		c.constrain(tenv, *branch, resultT);           // consequent
		cur = ++branch;                                // next predicate
	}
}

/** Constrain a constructor expression `(head elems...)` to a Tup type whose
 * members are the element types, in order.
 */
void
ACons::constrain(TEnv& tenv, Constraints& c) const throw(Error)
{
	AType* consT = tup<AType>(loc, tenv.Tup, 0);
	const_iterator elem = begin();
	while (++elem != end()) {
		(*elem)->constrain(tenv, c);
		consT->push_back(const_cast<AType*>(tenv.var(*elem)));
	}

	c.constrain(tenv, this, consT);
}

/** Constrain a field access `(. obj index)`.
 *
 * index must be a non-negative integer literal. The whole expression gets
 * the type of field `index` of obj. obj is constrained to a Tup whose
 * field `index` has that type.
 *
 * Throws Error on wrong arity, a non-literal index, or a negative index.
 */
void
ADot::constrain(TEnv& tenv, Constraints& c) const throw(Error)
{
	THROW_IF(size() != 3, loc, "`.' requires exactly 2 arguments");
	const_iterator     i   = begin();
	AST*               obj = *++i;
	ALiteral<int32_t>* idx = (*++i)->to<ALiteral<int32_t>*>();
	THROW_IF(!idx, loc, "the 2nd argument to `.' must be a literal integer");
	// Without this check a negative index would add no filler fields below
	// and be silently typed as an access to field 0
	THROW_IF(idx->val < 0, loc, "the 2nd argument to `.' must not be negative");
	obj->constrain(tenv, c);

	const AType* retT = tenv.var(this);
	c.constrain(tenv, this, retT);

	// obj : (Tup v0 ... v{idx-1} retT DOTS). The preceding fields get
	// tenv.var() variables (presumably fresh), and DOTS presumably stands
	// for any further fields; confirm against the unifier.
	AType* objT = tup<AType>(loc, tenv.Tup, 0);
	for (int32_t field = 0; field < idx->val; ++field)
		objT->push_back(const_cast<AType*>(tenv.var()));
	objT->push_back(const_cast<AType*>(retT));
	objT->push_back(new AType(obj->loc, AType::DOTS));
	c.constrain(tenv, obj, objT);
}

/** Constrain a primitive operation by operator class:
 *  - arithmetic (+ - * /), at least 2 operands: every operand has the result type
 *  - binary (%), exactly 2 operands: both operands have the result type
 *  - logical (and or xor), exactly 2 operands: result and operands are Bool
 *  - comparison (= != > >= < <=), exactly 2 operands: result is Bool and
 *    the second operand has the first operand's type
 *
 * Throws Error for an unknown operator or the wrong number of operands.
 */
void
APrimitive::constrain(TEnv& tenv, Constraints& c) const throw(Error)
{
	const string op = head()->to<const ASymbol*>()->str();

	enum { ARITHMETIC, BINARY, LOGICAL, COMPARISON } kind;
	if (op == "+" || op == "-" || op == "*" || op == "/") {
		kind = ARITHMETIC;
	} else if (op == "%") {
		kind = BINARY;
	} else if (op == "and" || op == "or" || op == "xor") {
		kind = LOGICAL;
	} else if (op == "=" || op == "!=" || op == ">" || op == ">="
			|| op == "<" || op == "<=") {
		kind = COMPARISON;
	} else {
		throw Error(loc, (format("unknown primitive `%1%'") % op).str());
	}

	// Constrain all operands before relating their types
	const_iterator arg = begin();
	while (++arg != end())
		(*arg)->constrain(tenv, c);

	arg = begin();
	switch (kind) {
	case ARITHMETIC:
		if (size() < 3)
			throw Error(loc, (format("`%1%' requires at least 2 arguments") % op).str());
		while (++arg != end())
			c.constrain(tenv, *arg, tenv.var(this));
		break;
	case BINARY:
		if (size() != 3)
			throw Error(loc, (format("`%1%' requires exactly 2 arguments") % op).str());
		c.constrain(tenv, *++arg, tenv.var(this));
		c.constrain(tenv, *++arg, tenv.var(this));
		break;
	case LOGICAL:
		if (size() != 3)
			throw Error(loc, (format("`%1%' requires exactly 2 arguments") % op).str());
		c.constrain(tenv, this,   tenv.named("Bool"));
		c.constrain(tenv, *++arg, tenv.named("Bool"));
		c.constrain(tenv, *++arg, tenv.named("Bool"));
		break;
	case COMPARISON: {
		if (size() != 3)
			throw Error(loc, (format("`%1%' requires exactly 2 arguments") % op).str());
		const AType* lhsT = tenv.var(*++arg);
		c.constrain(tenv, this, tenv.named("Bool"));
		c.constrain(tenv, *++arg, lhsT);
		break;
	}
	default:
		throw Error(loc, (format("unknown primitive `%1%'") % op).str());
	}
}