/* Tuplr Type Inferencing
 * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net>
 *
 * Tuplr is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Tuplr is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Tuplr.  If not, see <http://www.gnu.org/licenses/>.
 */

/** @file
 * @brief Constrain type of AST expressions
 */

#include <set>
#include "tuplr.hpp"

/** A string literal always has the named type `String'. */
void
AString::constrain(TEnv& tenv, Constraints& c) const
{
	AType* const self = tenv.var(this);
	c.push_back(Constraint(self, tenv.named("String"), loc));
}

/** Constrain a symbol reference to the type it is bound to.
 *
 * The binding found by tenv.lookup() is stored in `addr'.
 * @throw Error if the symbol is not bound in the environment.
 */
void
ASymbol::constrain(TEnv& tenv, Constraints& c) const
{
	addr = tenv.lookup(this);
	if (addr) {
		// The binding's second element is the bound type.
		c.push_back(Constraint(tenv.var(this), tenv.deref(addr).second, loc));
		return;
	}
	throw Error(loc, (format("undefined symbol `%1%'") % cppstr).str());
}

/** Constrain a tuple: its type is the tuple of its elements' types.
 *
 * Each element is constrained in order before its type variable is
 * appended to the tuple type.
 */
void
ATuple::constrain(TEnv& tenv, Constraints& c) const
{
	AType* elemsT = tup<AType>(loc, NULL);
	for (size_t idx = 0; idx < size(); ++idx) {
		const AST* elem = at(idx);
		elem->constrain(tenv, c);
		elemsT->push_back(tenv.var(elem));
	}
	c.push_back(Constraint(tenv.var(this), elemsT, loc));
}

/** Constrain the type of a function (closure) expression.
 *
 * The first time a given AFn is seen, its body is inferred on its own.
 * Parameters and top-level internal `def's are entered into a new
 * environment frame. The body's constraints are collected into a private
 * constraint set and unified. The resulting (Fn PROT-TYPE BODY-TYPE) type
 * is cached in tenv.genericTypes. Later calls for the same AFn reuse the
 * cached type. In both cases a copy of that type is finally constrained
 * against this expression's type variable.
 *
 * @throw Error if a parameter is not a symbol, a parameter name repeats,
 *        or an internal definition reuses a name already bound in this
 *        function (by a parameter or an earlier internal definition).
 */
void
AFn::constrain(TEnv& tenv, Constraints& c) const
{
	const AType* genericType;
	// Reuse the generic type if this function has already been constrained
	TEnv::GenericTypes::const_iterator gt = tenv.genericTypes.find(this);
	if (gt != tenv.genericTypes.end()) {
		genericType = gt->second;
	} else {
		// Names bound so far in this frame, used to detect duplicates.
		// Comparison is by pointer (presumably symbols are interned -- TODO confirm).
		set<const ASymbol*> defined;
		TEnv::Frame frame;

		// Add parameters to environment frame
		for (size_t i = 0; i < prot()->size(); ++i) {
			const ASymbol* sym = prot()->at(i)->to<const ASymbol*>();
			if (!sym)
				throw Error(prot()->at(i)->loc, "parameter name is not a symbol");
			if (defined.find(sym) != defined.end())
				throw Error(sym->loc, (format("duplicate parameter `%1%'") % sym->str()).str());
			defined.insert(sym);
			// Value and type are unknown yet; filled in below
			frame.push_back(make_pair(sym, make_pair((AST*)NULL, (AType*)NULL)));
		}

		// Add internal definitions to environment frame
		// (only top-level ADef elements of the body, i.e. elements 2..size()-1;
		// after this loop e == size())
		size_t e = 2;
		for (; e < size(); ++e) {
			const AST*  exp = at(e);
			const ADef* def = exp->to<const ADef*>();
			if (def) {
				const ASymbol* sym = def->sym();
				if (defined.find(sym) != defined.end())
					throw Error(def->loc, (format("`%1%' defined twice") % sym->str()).str());
				defined.insert(def->sym());
				// Bind the name to the definition's value expression (element 2)
				frame.push_back(make_pair(def->sym(),
						make_pair(const_cast<AST*>(def->at(2)), (AType*)NULL)));
			}
		}

		tenv.push(frame);

		// The body's constraints go into their own set, unified separately below
		Constraints cp;
		cp.push_back(Constraint(tenv.var(this), tenv.var(), loc));

		// Give each parameter a fresh type variable and build the prototype type.
		// NOTE(review): `frame' was already handed to tenv.push() above; if push()
		// stores a copy, these writes to `frame' never reach the environment --
		// confirm against TEnv::push.
		AType* protT = tup<AType>(loc, NULL);
		size_t s = 0;
		for (ATuple::const_iterator i = prot()->begin(); i != prot()->end(); ++i, ++s) {
			AType* tvar = tenv.fresh((*i)->to<ASymbol*>());
			protT->push_back(tvar);
			assert(frame[s].first == (*i));
			frame[s].second.first = (*i);  // value slot set to the parameter node itself
			frame[s].second.second = tvar;
		}
		// Note: this constraint goes into the outer set `c', not `cp'
		c.push_back(Constraint(tenv.var(at(1)), protT, at(1)->loc));

		for (size_t i = 2; i < size(); ++i)
			at(i)->constrain(tenv, cp);

		// Result type is the type of the last element (e == size() here).
		// NOTE(review): with an empty body (size() == 2) this is at(1), the prototype.
		AType* bodyT = tenv.var(at(e-1));
		Subst tsubst = unify(cp);
		genericType = tup<AType>(loc, tenv.penv.sym("Fn"),
				tsubst.apply(protT), tsubst.apply(bodyT), 0);
		tenv.genericTypes.insert(make_pair(this, genericType));
		// Register the cached type with the object pool (presumably as a GC root
		// so it outlives collections -- confirm)
		Object::pool.addRoot(genericType);

		tenv.pop();
		// Keep the body's substitution on this node
		subst = tsubst;
	}

	// Copy of the cached generic type; the copy appears to be top-level only
	AType* t = new AType(*genericType); // FIXME: deep copy
	c.constrain(tenv, this, t);
}

/** Constrain the type of a function application (CALLEE ARG...).
 *
 * All sub-expressions are constrained first. Then one of two cases applies.
 *
 * If the callee resolves to a function literal whose generic type is
 * already cached in tenv.genericTypes, each argument is constrained
 * against the matching parameter type of that generic type. The call's
 * type is then the callee's return type.
 *
 * Otherwise the callee is constrained to (Fn (ARG-TYPES...) RET) for a
 * fresh RET, and the call's type is RET.
 *
 * @throw Error if the callee is a known function literal called with the
 *        wrong number of arguments.
 */
void
ACall::constrain(TEnv& tenv, Constraints& c) const
{
	// Callee (element 0) first, then the arguments, in order
	for (size_t i = 0; i < size(); ++i)
		at(i)->constrain(tenv, c);

	// NOTE(review): resolve() presumably yields the value bound to a name.
	// That value can be NULL, e.g. parameters are entered into their frame
	// with a NULL value in AFn::constrain. Guard against it rather than
	// dereferencing; an unresolved callee takes the generic path below.
	const AST* callee  = tenv.resolve(at(0));
	const AFn* closure = callee ? callee->to<const AFn*>() : NULL;
	if (closure) {
		if (size() - 1 != closure->prot()->size())
			throw Error(loc, "incorrect number of arguments");
		TEnv::GenericTypes::iterator gt = tenv.genericTypes.find(closure);
		if (gt != tenv.genericTypes.end()) {
			// Generic type is (Fn PROT RET); element 1 is the parameter-type tuple
			const ATuple*  prot   = gt->second->at(1)->to<const ATuple*>();
			const_iterator i      = begin();
			const_iterator prot_i = prot->begin();
			for (++i; i != end(); ++i, ++prot_i)
				c.constrain(tenv, *i, (*prot_i)->as<AType*>());
			AType* retT = tenv.var(this);
			c.constrain(tenv, at(0),
					tup<AType>(at(0)->loc, tenv.penv.sym("Fn"), tenv.var(), retT, 0));
			c.constrain(tenv, this, retT);
			return;
		}
	}

	// Unknown callee: it must be a function from the argument types to a fresh result
	AType* argsT = tup<AType>(loc, 0);
	for (size_t i = 1; i < size(); ++i)
		argsT->push_back(tenv.var(at(i)));
	AType* retT = tenv.var();
	c.constrain(tenv, at(0), tup<AType>(at(0)->loc, tenv.penv.sym("Fn"), argsT, retT, 0));
	c.constrain(tenv, this, retT);
}

/** Constrain a definition (def SYM VALUE).
 *
 * SYM is bound in the environment to VALUE together with VALUE's type
 * variable before VALUE itself is constrained. The definition expression
 * gets the same type as VALUE.
 *
 * @throw Error if the form does not have exactly two arguments or has no symbol.
 */
void
ADef::constrain(TEnv& tenv, Constraints& c) const
{
	THROW_IF(size() != 3, loc, "`def' requires exactly 2 arguments");
	const ASymbol* const name = sym();
	THROW_IF(!name, loc, "`def' has no symbol");

	const AST* const value  = at(2);
	AType* const     valueT = tenv.var(value);
	// Bind before constraining VALUE (presumably so VALUE can refer to NAME -- confirm)
	tenv.def(name, make_pair(const_cast<AST*>(value), valueT));
	value->constrain(tenv, c);
	c.constrain(tenv, this, valueT);
}

/** Constrain a conditional (if C1 E1 C2 E2 ... ELSE).
 *
 * Every condition must be `Bool', and every consequent as well as the
 * final else expression must have the type of the whole `if'.
 *
 * @throw Error if there are fewer than 3 arguments or no final else clause.
 */
void
AIf::constrain(TEnv& tenv, Constraints& c) const
{
	THROW_IF(size() < 4, loc, "`if' requires at least 3 arguments");
	THROW_IF(size() % 2 != 0, loc, "`if' missing final else clause");

	for (size_t k = 1; k < size(); ++k)
		at(k)->constrain(tenv, c);

	AType* const resultT = tenv.var(this);
	const size_t elseIdx = size() - 1;

	// (condition, consequent) pairs occupy odd/even slots before the else
	for (size_t cond = 1; cond < elseIdx; cond += 2) {
		c.constrain(tenv, at(cond), tenv.named("Bool"));
		c.constrain(tenv, at(cond + 1), resultT);
	}
	c.constrain(tenv, at(elseIdx), resultT);
}

/** Constrain a primitive operator application (OP ARG...).
 *
 * Operators are grouped into four classes:
 *  - arithmetic (+ - * /): variadic (at least 2 operands); every operand
 *    has the result type.
 *  - binary (%): exactly 2 operands, both of the result type.
 *  - logical (and or xor): exactly 2 Bool operands; the result is Bool.
 *  - comparison (= != > >= < <=): exactly 2 operands of the same type;
 *    the result is Bool.
 *
 * @throw Error for an unknown operator or a wrong number of operands.
 */
void
APrimitive::constrain(TEnv& tenv, Constraints& c) const
{
	const string op = at(0)->to<const ASymbol*>()->str();

	enum { ARITHMETIC, BINARY, LOGICAL, COMPARISON } kind;
	if (op == "+" || op == "-" || op == "*" || op == "/") {
		kind = ARITHMETIC;
	} else if (op == "%") {
		kind = BINARY;
	} else if (op == "and" || op == "or" || op == "xor") {
		kind = LOGICAL;
	} else if (op == "=" || op == "!=" || op == ">" || op == ">=" || op == "<" || op == "<=") {
		kind = COMPARISON;
	} else {
		throw Error(loc, (format("unknown primitive `%1%'") % op).str());
	}

	// Operands are constrained before the arity check
	for (size_t k = 1; k < size(); ++k)
		at(k)->constrain(tenv, c);

	if (kind == ARITHMETIC) {
		if (size() < 3)
			throw Error(loc, (format("`%1%' requires at least 2 arguments") % op).str());
	} else if (size() != 3) {
		throw Error(loc, (format("`%1%' requires exactly 2 arguments") % op).str());
	}

	switch (kind) {
	case ARITHMETIC:
	case BINARY:  // BINARY has exactly operands 1 and 2 here
		for (size_t k = 1; k < size(); ++k)
			c.constrain(tenv, at(k), tenv.var(this));
		break;
	case LOGICAL:
		c.constrain(tenv, this,  tenv.named("Bool"));
		c.constrain(tenv, at(1), tenv.named("Bool"));
		c.constrain(tenv, at(2), tenv.named("Bool"));
		break;
	case COMPARISON:
		c.constrain(tenv, this, tenv.named("Bool"));
		c.constrain(tenv, at(1), tenv.var(at(2)));
		break;
	default:
		throw Error(loc, (format("unknown primitive `%1%'") % op).str());
	}
}