/* Tuplr Type Inferencing
 * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net>
 *
 * Tuplr is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Tuplr is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Tuplr.  If not, see <http://www.gnu.org/licenses/>.
 */

#include "tuplr.hpp"
/* Constrain a tuple expression: each element is constrained first, and the
 * tuple's own type is a new tuple type whose entries are the element types,
 * in element order. */
void
ASTTuple::constrain(TEnv& tenv) const
{
	AType* const tupleT = new AType(ASTTuple());
	for (size_t i = 0; i < size(); ++i) {
		AST* const elem = at(i);
		elem->constrain(tenv);
		tupleT->push_back(tenv.type(elem));
	}
	tenv.constrain(this, tupleT);
}

/* Constrain a closure: after constraining children 1 and 2, the closure's
 * type is (Fn T1 T2), where T1 is the type of child 1 (the prototype) and
 * T2 is the type of child 2 (the body).
 *
 * The ASTTuple constructor's argument list is terminated with NULL rather
 * than a literal 0. A plain 0 is an int, and an int passed through a variadic
 * argument list is not guaranteed to have the size or representation of a
 * pointer. (This assumes the constructor is sentinel-terminated varargs; its
 * declaration is not visible in this file.) */
void
ASTClosure::constrain(TEnv& tenv) const
{
	at(1)->constrain(tenv);
	at(2)->constrain(tenv);
	AType* protT = tenv.type(at(1));
	AType* bodyT = tenv.type(at(2));
	tenv.constrain(this, new AType(ASTTuple(
			tenv.penv.sym("Fn"), protT, bodyT, NULL)));
}

/* Constrain a call (F ARG1 ... ARGn). Every child is constrained first. F is
 * then required to have type (Fn (T1 ... Tn) R), where Ti is the type of ARGi
 * and R is the type of the call expression itself. */
void
ASTCall::constrain(TEnv& tenv) const
{
	const size_t n = size();
	for (size_t i = 0; i < n; ++i)
		at(i)->constrain(tenv);

	// The call's own type is requested before the argument types, as before,
	// so any fresh type variables are created in the same order.
	AType* const resultT = tenv.type(this);

	AType* const paramsT = new AType(ASTTuple());
	for (size_t i = 1; i < n; ++i)
		paramsT->push_back(tenv.type(at(i)));

	AType* const fnT = new AType(ASTTuple(
			tenv.penv.sym("Fn"), paramsT, resultT, NULL));
	tenv.constrain(at(0), fnT);
}

/* Constrain a definition (def NAME VALUE). NAME, VALUE and the definition
 * form itself are all constrained to the single type of the form.
 * Throws Error if the form does not have exactly two arguments, or if NAME
 * is not a symbol. */
void
ASTDefinition::constrain(TEnv& tenv) const
{
	if (size() != 3)
		throw Error("`def' requires exactly 2 arguments", exp.loc);

	const bool nameIsSymbol = (dynamic_cast<const ASTSymbol*>(at(1)) != NULL);
	if (!nameIsSymbol)
		throw Error("`def' name is not a symbol", exp.loc);

	for (size_t i = 0; i < size(); ++i)
		at(i)->constrain(tenv);

	AType* const defT = tenv.type(this);
	tenv.constrain(at(1), defT);
	tenv.constrain(at(2), defT);
}

/* Constrain a conditional (if COND THEN ELSE). Every child is constrained
 * first. COND must then have the named type Bool, and both THEN and ELSE must
 * have the type of the whole expression.
 * NOTE(review): only children 1..3 are related here; any further children
 * are constrained individually but not tied to the result type. */
void
ASTIf::constrain(TEnv& tenv) const
{
	for (size_t i = 0; i < size(); ++i)
		at(i)->constrain(tenv);

	AType* const resultT = tenv.type(this);
	tenv.constrain(at(1), tenv.named("Bool"));
	tenv.constrain(at(2), resultT);
	tenv.constrain(at(3), resultT);
}

/* Constrain a `cons' call: its type is (Pair T1 ... Tn), where Ti is the type
 * of argument i (children 1..n). Each argument is constrained before its type
 * is appended.
 *
 * NULL (not a literal 0) terminates the ASTTuple constructor's argument list.
 * An int passed through a variadic argument list is not guaranteed to be
 * read back correctly as a pointer. (This assumes the constructor is
 * sentinel-terminated varargs; its declaration is not visible here.) */
void
ASTConsCall::constrain(TEnv& tenv) const
{
	AType* t = new AType(ASTTuple(tenv.penv.sym("Pair"), NULL));
	for (size_t i = 1; i < size(); ++i) {
		at(i)->constrain(tenv);
		t->push_back(tenv.type(at(i)));
	}
	tenv.constrain(this, t);
}

/* Constrain a `car' call (car X). X must have type (Pair A B) for fresh type
 * variables A and B, and the call itself has type A.
 *
 * NULL (not a literal 0) terminates the ASTTuple constructor's argument list.
 * An int passed through a variadic argument list is not guaranteed to be
 * read back correctly as a pointer. (This assumes the constructor is
 * sentinel-terminated varargs; its declaration is not visible here.) */
void
ASTCarCall::constrain(TEnv& tenv) const
{
	at(1)->constrain(tenv);
	AType* ct = tenv.var();
	AType* tt = new AType(ASTTuple(tenv.penv.sym("Pair"), ct, tenv.var(), NULL));
	tenv.constrain(at(1), tt);
	tenv.constrain(this, ct);
}

/* Constrain a `cdr' call (cdr X). X must have type (Pair A B) for fresh type
 * variables A and B, and the call itself has type B.
 *
 * NULL (not a literal 0) terminates the ASTTuple constructor's argument list.
 * An int passed through a variadic argument list is not guaranteed to be
 * read back correctly as a pointer. (This assumes the constructor is
 * sentinel-terminated varargs; its declaration is not visible here.) */
void
ASTCdrCall::constrain(TEnv& tenv) const
{
	at(1)->constrain(tenv);
	AType* ct = tenv.var();
	AType* tt = new AType(ASTTuple(tenv.penv.sym("Pair"), tenv.var(), ct, NULL));
	tenv.constrain(at(1), tt);
	tenv.constrain(this, ct);
}

/* Replace, in place, every element of TUP (at any nesting depth) that
 * compares equal to FROM with the pointer TO. Matching elements are replaced
 * without descending into them; non-matching elements that are tuples are
 * searched recursively. A null TUP is a no-op. */
static void
substitute(ASTTuple* tup, AST* from, AST* to)
{
	if (tup == NULL)
		return;

	for (size_t i = 0; i < tup->size(); ++i) {
		AST* const child = tup->at(i);
		if (!(*child == *from)) {
			substitute(dynamic_cast<ASTTuple*>(child), from, to);
			continue;
		}
		tup->at(i) = to;
	}
}

/* Return true if CHILD compares equal to this tuple, to one of its elements,
 * or (via each element's own contains()) to something nested inside an
 * element. */
bool
ASTTuple::contains(AST* child) const
{
	if (*this == *child)
		return true;

	for (size_t i = 0; i < size(); ++i) {
		AST* const elem = at(i);
		if (*elem == *child)
			return true;
		if (elem->contains(child))
			return true;
	}
	return false;
}

/* Compose two substitutions, DELTA after GAMMA (TAPL 22.1.1).
 *
 * For each mapping X -> T in GAMMA the result maps X to DELTA's image of T if
 * T is itself a key of DELTA, and to T otherwise. Every DELTA mapping whose
 * key is not bound by GAMMA is then added unchanged.
 * NOTE(review): DELTA is only applied when T is exactly a DELTA key; it is
 * not applied to variables nested inside a compound T. */
TSubst
compose(const TSubst& delta, const TSubst& gamma) // TAPL 22.1.1
{
	TSubst result;

	for (TSubst::const_iterator gi = gamma.begin(); gi != gamma.end(); ++gi) {
		TSubst::const_iterator hit = delta.find(gi->second);
		TSubst::const_iterator src = (hit == delta.end()) ? gi : hit;
		result.insert(make_pair(gi->first, src->second));
	}

	for (TSubst::const_iterator di = delta.begin(); di != delta.end(); ++di) {
		const bool boundByGamma = (gamma.find(di->first) != gamma.end());
		if (!boundByGamma)
			result.insert(*di);
	}

	return result;
}

/* Replace type S with type T throughout CONSTRAINTS.
 * A constraint side that compares equal to S is rebound to T. Occurrences of
 * S nested inside either side are then rewritten by substitute(), which
 * mutates the referenced type objects in place. The list's length is never
 * changed. */
void
substConstraints(TEnv::Constraints& constraints, AType* s, AType* t)
{
	const TEnv::Constraints::iterator last = constraints.end();
	for (TEnv::Constraints::iterator it = constraints.begin(); it != last; ++it) {
		if (*it->first == *s)
			it->first = t;
		if (*it->second == *s)
			it->second = t;

		substitute(it->first, s, t);
		substitute(it->second, s, t);
	}
}

/* Solve a list of type constraints by unification (TAPL 22.4) and return the
 * resulting substitution. Throws Error("Type unification failed") if the
 * constraints cannot be satisfied.
 *
 * The first constraint (s, t) is examined:
 *  - if s and t are already equal, it is dropped;
 *  - if one side is a type variable that does not occur in the other (occurs
 *    check), that variable is bound to the other side. The binding is
 *    substituted into the remaining constraints and composed onto their
 *    solution;
 *  - otherwise two types of the same size are decomposed element-wise.
 *    Element pairs that are both types become new constraints. Any other
 *    element pair (such as the `Fn'/`Pair' head symbols built by the
 *    constrain methods) must compare equal, otherwise unification fails.
 *    Skipping such elements would let e.g. (Fn a b) unify with (Pair a b).
 *
 * NOTE(review): this recurses once per constraint and copies the remaining
 * list at each step. */
TSubst
TEnv::unify(const Constraints& constraints) // TAPL 22.4
{
	if (constraints.empty()) return TSubst();
	AType*      s  = constraints.begin()->first;
	AType*      t  = constraints.begin()->second;
	Constraints cp = constraints;
	cp.erase(cp.begin());

	if (*s == *t) {
		return unify(cp);
	} else if (s->var() && !t->contains(s)) {
		substConstraints(cp, s, t);
		return compose(unify(cp), TSubst(s, t));
	} else if (t->var() && !s->contains(t)) {
		substConstraints(cp, t, s);
		return compose(unify(cp), TSubst(t, s));
	} else if (s->size() == t->size()) {
		for (size_t i = 0; i < s->size(); ++i) {
			AST*   sa = s->at(i);
			AST*   ta = t->at(i);
			AType* si = dynamic_cast<AType*>(sa);
			AType* ti = dynamic_cast<AType*>(ta);
			if (si && ti)
				cp.push_back(make_pair(si, ti));
			else if (si || ti || !(*sa == *ta))
				// Non-type elements (e.g. head symbols) must match exactly
				throw Error("Type unification failed");
		}
		return unify(cp);
	} else {
		throw Error("Type unification failed");
	}
}

void
TEnv::apply(const TSubst& substs)
{
	FOREACH(TSubst::const_iterator, s, substs)
		FOREACH(Frame::iterator, t, front())
			if (*t->second == *s->first)
				t->second = s->second;
			else
				substitute(t->second, s->first, s->second);
}