/* Resp Unification
 * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net>
 *
 * Resp is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Resp is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Resp.  If not, see <http://www.gnu.org/licenses/>.
 */

/** @file
 * @brief Unify type constraints
 */

#include <set>
#include "resp.hpp"

/** Build a type substitution for calling a generic function type
 * with a specific set of argument types.
 *
 * Each call argument type in @a argsT is paired positionally with the
 * corresponding element of @a genericT's prototype (the tuple at index 1
 * of @a genericT).  When a call argument is an expression type, its generic
 * counterpart must be one too (asserted), and their elements are paired
 * positionally instead; element pairs where either side is not a type are
 * skipped.  Every pairing adds a generic -> concrete entry to the result.
 *
 * Pairing stops as soon as either sequence runs out, so a call with more
 * arguments than the prototype (or a compound argument with more elements
 * than its generic counterpart) no longer walks past the end of the shorter
 * sequence.  This function does not itself report arity mismatches.
 *
 * @param genericT  generic function type; NOTE(review): assumed to hold the
 *                  parameter-type tuple at index 1 — confirm against callers.
 * @param argsT     concrete argument types at the call site.
 * @return          substitution mapping generic types to call-site types.
 */
Subst
TEnv::buildSubst(const AType* genericT, const AType& argsT)
{
	Subst subst;

	// Build substitution to apply to generic type
	const ATuple* genericProtT = genericT->list_ref(1)->as_tuple();
	ATuple::const_iterator g = genericProtT->begin();
	ATuple::const_iterator a = argsT.begin();
	// Bound by BOTH sequences: never dereference past the prototype's end
	for (; a != argsT.end() && g != genericProtT->end(); ++a, ++g) {
		const AType* genericArgT = (*g)->to_type();
		const AType* callArgT    = (*a)->to_type();
		if (callArgT->kind == AType::EXPR) {
			assert(genericArgT->kind == AType::EXPR);
			// Pair compound types element-wise, again bounded by both sides
			ATuple::const_iterator gi = genericArgT->begin();
			ATuple::const_iterator ci = callArgT->begin();
			for (; gi != genericArgT->end() && ci != callArgT->end(); ++gi, ++ci) {
				const AType* gT = (*gi)->to_type();
				const AType* aT = (*ci)->to_type();
				if (gT && aT)  // skip elements that are not types
					subst.add(gT, aT);
			}
		} else {
			subst.add(genericArgT, callArgT);
		}
	}

	return subst;
}

/** Record that the type of @a o must equal @a t.
 *
 * The type that @a tenv associates with @a o (via TEnv::var) is paired
 * with @a t and appended to this constraint set.  Both @a o and @a t must
 * be non-NULL (asserted).
 */
void
Constraints::constrain(TEnv& tenv, const AST* o, const AType* t)
{
	assert(o);
	assert(t);
	const AType* const oT = tenv.var(o);
	push_back(Constraint(oT, t));
}

/** Return a rebuilt copy of the type list @a tup with @a from replaced by @a to.
 *
 * Elements of @a tup that compare equal to @a from become fresh copies of
 * @a to carrying the replaced element's location.  Other expression-kind
 * elements are rewritten recursively, except an element that is the very
 * object @a to, which is kept as-is.  All remaining elements are reused
 * unchanged.  A NULL @a tup yields NULL.
 *
 * NOTE(review): the result is assembled with TList and returned via its
 * head; whether that keeps @a tup's own location, and what an empty @a tup
 * yields, depends on TList — confirm.
 */
static const AType*
substitute(const AType* tup, const AType* from, const AType* to)
{
	if (tup == NULL)
		return NULL;

	TList out;
	FOREACHP(AType::const_iterator, i, tup) {
		if (**i == *from) {
			// Fresh copy of the replacement, positioned where the match was
			AType* copy = new AType(*to);
			copy->loc = (*i)->loc;
			out.push_back(copy);
		} else {
			const AType* elem    = (*i)->as_type();
			const bool   descend = (*i != to) && (elem->kind == AType::EXPR);
			out.push_back(descend ? substitute(elem, from, to) : elem);
		}
	}
	return out.head;
}

/** Compose two substitutions, yielding @a delta after @a gamma (TAPL 22.1.1).
 *
 * The result contains:
 *  - every binding X -> T of @a gamma, where T is replaced by delta's
 *    binding for T when @a delta has one;
 *  - every binding of @a delta whose key @a gamma does not bind.
 *
 * NOTE(review): delta is applied to gamma's values only by whole-value
 * lookup (delta.find), not to types nested inside compound values.
 */
Subst
Subst::compose(const Subst& delta, const Subst& gamma)
{
	Subst result;

	for (Subst::const_iterator gi = gamma.begin(); gi != gamma.end(); ++gi) {
		const Subst::const_iterator di = delta.find(gi->second);
		if (di != delta.end())
			result.add(gi->first, di->second);
		else
			result.add(gi->first, gi->second);
	}

	for (Subst::const_iterator di = delta.begin(); di != delta.end(); ++di)
		if (gamma.find(di->first) == gamma.end())
			result.add(di->first, di->second);

	return result;
}

/// Replace all occurrences of @a s with @a t
Constraints&
Constraints::replace(const AType* s, const AType* t)
{
	for (Constraints::iterator c = begin(); c != end(); ++c) {
		if (*c->first == *s) {
			c->first = new AType(*t, c->first->loc);
		} else if (c->first->kind == AType::EXPR) {
			c->first = substitute(c->first, s, t);
		}
		if (*c->second == *s) {
			c->second = new AType(*t, c->second->loc);
		} else if (c->second->kind == AType::EXPR) {
			c->second = substitute(c->second, s, t);
		}
	}
	return *this;
}

/** Unify a type constraint set (TAPL 22.4).
 *
 * Solves the first constraint `s = t`, then recurses on the rest:
 *  - if s and t compare equal, the constraint is dropped;
 *  - if s is a type variable not occurring in t, s is replaced by t in the
 *    remaining constraints, which are unified and then composed with
 *    [s -> t] (symmetrically when t is the variable);
 *  - if both are expression types, their elements are paired into new
 *    constraints.  A DOTS element on either side stops the pairing early
 *    (elements after it are not constrained).  Otherwise both lists must
 *    end together, or the unfinished one must continue with DOTS.
 *
 * @param constraints  constraint set to solve; not modified (the tail is
 *                     copied at each step).
 * @return             substitution solving all constraints; empty when
 *                     @a constraints is empty.
 * @throws Error       at s's location when no rule applies (mismatched
 *                     types, failed occurs check, or incompatible
 *                     expression lengths).
 *
 * NOTE(review): each call copies the remaining constraints and recurses at
 * least once per constraint, so cost and stack depth grow with the set.
 */
Subst
unify(const Constraints& constraints)
{
	if (constraints.empty())
		return Subst();

	// Split off the first constraint (s = t); cp is a mutable copy of the rest
	Constraints::const_iterator i  = constraints.begin();
	const AType*                s  = i->first;
	const AType*                t  = i->second;
	Constraints                 cp(++i, constraints.end());

	if (*s == *t) {
		// Trivially satisfied
		return unify(cp);
	} else if (s->kind == AType::VAR && !list_contains(t, s)) {
		// Bind s := t (occurs check passed); replace() rewrites cp in place
		return Subst::compose(unify(cp.replace(s, t)), Subst(s, t));
	} else if (t->kind == AType::VAR && !list_contains(s, t)) {
		// Symmetric case: bind t := s
		return Subst::compose(unify(cp.replace(t, s)), Subst(t, s));
	} else if (s->kind == AType::EXPR && t->kind == AType::EXPR) {
		// Pair elements of both expressions into new constraints
		AType::const_iterator si = s->begin();
		AType::const_iterator ti = t->begin();
		for (; si != s->end() && ti != t->end(); ++si, ++ti) {
			const AType* st = (*si)->as_type();
			const AType* tt = (*ti)->as_type();
			if (st->kind == AType::DOTS || tt->kind == AType::DOTS)
				return unify(cp);  // DOTS: stop pairing, skip length check
			else
				cp.push_back(Constraint(st, tt));
		}
		// Lengths are compatible if both ran out together, or the side that
		// still has elements continues with DOTS
		if (   (si == s->end() && (ti == t->end() || (*ti)->as_type()->kind == AType::DOTS))
		    || (ti == t->end() && (*si)->as_type()->kind == AType::DOTS))
			return unify(cp);
	}
	// No unification rule applies: report the mismatch at both locations
	throw Error(s->loc,
			(format("type is `%1%' but should be `%2%'\n%3%: error: to match `%4%' here")
			 % s->str() % t->str() % t->loc.str() % t->str()).str());
}