/* Tuplr Unification
 * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net>
 *
 * Tuplr is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Tuplr is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Tuplr.  If not, see <http://www.gnu.org/licenses/>.
 */

/** @file
 * @brief Unify type constraints
 */

#include <set>
#include "tuplr.hpp"

/// Add the constraint "type of @a o = @a t" to this set.
///
/// The left-hand side is whatever @a tenv yields from TEnv::var for @a o, and
/// the constraint is tagged with @a o's source location so unification errors
/// can point back at it.  @a o must be an expression, not an AType (checked
/// by assert).
void
Constraints::constrain(TEnv& tenv, const AST* o, AType* t)
{
	// Only expressions are constrained; a type itself is never the subject.
	assert(o->to<const AType*>() == 0);
	const Constraint c(tenv.var(o), t, o->loc);
	push_back(c);
}

/// Rewrite @a tup in place, recursively: every element equal (by AST
/// operator==) to @a from is replaced by @a to.  Elements that already are
/// @a to (pointer identity) are left alone and not descended into; any other
/// element is recursed into if it is itself a tuple.  A null @a tup is a no-op.
static void
substitute(ATuple* tup, const AST* from, AST* to)
{
	if (!tup)
		return;

	for (size_t idx = 0; idx < tup->size(); ++idx) {
		AST* elem = tup->at(idx);
		if (*elem == *from) {
			tup->at(idx) = to;  // direct hit: swap in the replacement
			continue;
		}
		if (elem == to)
			continue;           // already the replacement, nothing to do
		// to<ATuple*>() is null for non-tuples, which the recursion ignores
		substitute(elem->to<ATuple*>(), from, to);
	}
}

/// Compose two substitutions (TAPL 22.1.1), yielding roughly delta o gamma.
///
/// The result contains:
///  - every binding X -> T of @a gamma, with T replaced by delta's image of T
///    when T is itself a key of @a delta;
///  - every binding of @a delta whose variable is not bound by @a gamma.
///
/// NOTE(review): delta is only looked up for T as a whole, not applied inside
/// compound types; presumably nested occurrences are already handled by
/// Constraints::replace rewriting types in place during unify() -- confirm.
Subst
Subst::compose(const Subst& delta, const Subst& gamma)
{
	Subst result;

	// Bindings from gamma, with their right-hand side redirected through delta.
	for (Subst::const_iterator gi = gamma.begin(); gi != gamma.end(); ++gi) {
		Subst::const_iterator di = delta.find(gi->second);
		if (di == delta.end())
			result.add(gi->first, gi->second);
		else
			result.add(gi->first, di->second);
	}

	// Bindings from delta for variables gamma does not mention.
	for (Subst::const_iterator di = delta.begin(); di != delta.end(); ++di)
		if (gamma.find(di->first) == gamma.end())
			result.add(di->first, di->second);

	return result;
}

/// Replace all occurrences of @a s with @a t in every constraint.
///
/// A side that compares equal to @a s is swapped for @a t outright; nested
/// occurrences inside compound types are rewritten by substitute().
///
/// NOTE(review): substitute() assigns into the tuples themselves, so the
/// AType objects are modified in place.  Any other container holding the same
/// pointers (e.g. the set unify() copied from) observes the change too.
void
Constraints::replace(AType* s, AType* t)
{
	for (Constraints::iterator c = begin(); c != end(); ++c) {
		if (*c->first == *s)
			c->first = t;
		if (*c->second == *s)
			c->second = t;
		substitute(c->first, s, t);
		substitute(c->second, s, t);
	}
}

/// Unify a type constraint set (TAPL 22.4)
Subst
unify(const Constraints& constraints)
{
	if (constraints.empty()) return Subst();
	AType*      s  = constraints.begin()->first;
	AType*      t  = constraints.begin()->second;
	Constraints cp = constraints;
	cp.erase(cp.begin());

	if (*s == *t) {
		return unify(cp);
	} else if (s->var() && !t->contains(s)) {
		cp.replace(s, t);
		return Subst::compose(unify(cp), Subst(s, t));
	} else if (t->var() && !s->contains(t)) {
		cp.replace(t, s);
		return Subst::compose(unify(cp), Subst(t, s));
	} else if (s->kind == AType::EXPR && s->kind == t->kind && s->size() == t->size()) {
		assert(*s->at(0)->to<ASymbol*>() == *t->at(0)->to<ASymbol*>());
		for (size_t i = 1; i < s->size(); ++i) {
			AType* si = s->at(i)->to<AType*>();
			AType* ti = t->at(i)->to<AType*>();
			assert(si && ti);
			cp.push_back(Constraint(si, ti, si->loc));
		}
		return unify(cp);
	} else {
		throw Error(s->loc ? s->loc : t->loc,
				(format("type is `%1%' but should be `%2%'") % s->str() % t->str()).str());
	}
}