/* Resp: A programming language
 * Copyright (C) 2008-2009 David Robillard <http://drobilla.net>
 *
 * Resp is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Resp is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Resp.  If not, see <http://www.gnu.org/licenses/>.
 */

/** @file
 * @brief Unify type constraints
 */

#include <set>
#include "resp.hpp"

/** Build a type substitution for calling a generic function type
 * with a specific set of argument types.
 *
 * @param genericT  Generic function type; its element at index 1 is taken
 *                  as the parameter prototype tuple.
 * @param argsT     Tuple of concrete argument types, matched positionally
 *                  against the prototype.
 * @return A substitution mapping each prototype element to the argument
 *         type in the same position.  When an argument type is a compound
 *         type expression, its elements are instead mapped pairwise onto
 *         the elements of the corresponding prototype entry (null elements
 *         on either side are skipped).
 *
 * Matching stops at the end of the shorter sequence at both levels, so an
 * arity mismatch never dereferences an end iterator.
 */
Subst
TEnv::buildSubst(const AST* genericT, const AST& argsT)
{
	Subst subst;

	// Build substitution to apply to generic type
	const ATuple* const genericProtT = genericT->as_tuple()->list_ref(1)->as_tuple();
	const ATuple* const argsTup      = argsT.as_tuple();

	ATuple::const_iterator g = genericProtT->begin();
	ATuple::const_iterator a = argsTup->begin();
	for (; a != argsTup->end() && g != genericProtT->end(); ++a, ++g) {
		if (AType::is_expr(*a)) {
			// Compound argument type: bind the prototype's sub-elements
			// to the argument's sub-elements one by one
			assert(AType::is_expr(*g));
			const ATuple* const gTup = (*g)->as_tuple();
			const ATuple* const aTup = (*a)->as_tuple();
			ATuple::const_iterator gi = gTup->begin();
			ATuple::const_iterator ci = aTup->begin();
			for (; gi != gTup->end() && ci != aTup->end(); ++gi, ++ci) {
				if ((*gi) && (*ci))
					subst.add(*gi, *ci);
			}
		} else {
			subst.add(*g, *a);
		}
	}

	return subst;
}

/** Append a constraint equating the type variable of @a o with @a t.
 *
 * @param tenv  Type environment used to obtain the variable for @a o.
 * @param o     Expression being constrained (must be non-null).
 * @param t     Type it must have (must be non-null).
 */
void
Constraints::constrain(TEnv& tenv, const AST* o, const AST* t)
{
	assert(o);
	assert(t);
	const Constraint c(tenv.var(o), t);
	push_back(c);
}

/** Substitute @a to for occurrences of @a from within type @a in.
 *
 * @param in    Type to rewrite.
 * @param from  Type to replace; tuple elements are compared to it with
 *              AST equality (operator==).
 * @param to    Replacement type.
 * @return @a to if @a in is (pointer-identical to) @a from; @a in itself if
 *         it is not a tuple; otherwise a newly built tuple in which elements
 *         equal to @a from are replaced by @a to and nested type expressions
 *         are rewritten recursively.
 */
static const AST*
substitute(const AST* in, const AST* from, const AST* to)
{
	if (in == from)
		return to;

	const ATuple* tup = in->to_tuple();
	if (!tup)
		return in; // Non-tuple that is not @a from: nothing to substitute

	List ret;
	FOREACHP(ATuple::const_iterator, i, tup) {
		if (**i == *from) {
			ret.push_back(to); // FIXME: should be a copy w/ (*i)->loc
		} else if (*i != to && AType::is_expr(*i)) {
			// Recurse into nested type expressions, except @a to itself
			// (presumably so occurrences of @a from inside @a to are not
			// rewritten)
			ret.push_back(substitute(*i, from, to));
		} else {
			ret.push_back(*i);
		}
	}
	return ret.head;
}

/** Compose two substitutions, delta after gamma (TAPL 22.1.1).
 *
 * The result first contains every binding of @a gamma, in order.  When
 * gamma's right-hand side is found in @a delta, delta's image replaces it.
 * It then contains every binding of @a delta whose left-hand side is not
 * bound by @a gamma.
 *
 * NOTE(review): delta is looked up by gamma's whole right-hand side, so it
 * is not applied inside compound types -- confirm this is intended.
 */
Subst
Subst::compose(const Subst& delta, const Subst& gamma)
{
	Subst result;

	for (Subst::const_iterator gb = gamma.begin(); gb != gamma.end(); ++gb) {
		const Subst::const_iterator hit = delta.find(gb->second);
		const Subst::const_iterator src = (hit == delta.end()) ? gb : hit;
		result.add(gb->first, src->second);
	}

	for (Subst::const_iterator db = delta.begin(); db != delta.end(); ++db) {
		const bool boundByGamma = (gamma.find(db->first) != gamma.end());
		if (!boundByGamma)
			result.add(db->first, db->second);
	}

	return result;
}

/// Replace all occurrences of @a s with @a t
Constraints&
Constraints::replace(const AST* s, const AST* t)
{
	for (Constraints::iterator c = begin(); c != end(); ++c) {
		if (*c->first == *s) {
			c->first = t; // FIXME: should be copy w/ c->first->loc;
		} else if (AType::is_expr(c->first)) {
			c->first = substitute(c->first, s, t);
		}
		if (*c->second == *s) {
			c->second = t; // FIXME: should be copy w/ c->second->loc;
		} else if (AType::is_expr(c->second)) {
			c->second = substitute(c->second, s, t);
		}
	}
	return *this;
}

/** Return true iff @a type is the name `...` (the variadic marker). */
static inline bool
is_dots(const AST* type)
{
	if (!AType::is_name(type))
		return false;
	return type->as_symbol()->str() == "...";
}

/// Unify a type constraint set (TAPL 22.4)
///
/// Takes the first constraint (s, t) and solves it against a copy of the
/// remaining constraints, recursing once per constraint processed:
///  - s == t: the constraint is trivially satisfied and dropped.
///  - s is a variable not found in t (occurs check via list_contains;
///    presumably shallow -- TODO confirm): bind s to t, rewrite the
///    remaining constraints, and compose the results.  The symmetric case
///    is handled the same way with the roles swapped.
///  - both are type expressions: add pairwise element constraints.  A `...`
///    element on either side ends matching (variadic tail).  Unequal lengths
///    are accepted only if the longer side continues with `...`.
///
/// @param constraints  Constraint set to solve (not modified).
/// @return The composed substitution solving all constraints.
/// @throw Error  At s->loc, if the first constraint cannot be unified.
Subst
unify(const Constraints& constraints)
{
	if (constraints.empty())
		return Subst();

	// Split off the first constraint.  cp is a mutable copy of the rest,
	// which the branches below rewrite or extend before recursing.
	Constraints::const_iterator i  = constraints.begin();
	const AST*                  s  = i->first;
	const AST*                  t  = i->second;
	Constraints                 cp(++i, constraints.end());

	if (*s == *t) {
		return unify(cp);
	} else if (AType::is_var(s) && !list_contains(t->to_tuple(), s)) {
		return Subst::compose(unify(cp.replace(s, t)), Subst(s, t));
	} else if (AType::is_var(t) && !list_contains(s->to_tuple(), t)) {
		return Subst::compose(unify(cp.replace(t, s)), Subst(t, s));
	} else if (AType::is_expr(s) && AType::is_expr(t)) {
		const ATuple* const st = s->as_tuple();
		const ATuple* const tt = t->as_tuple();
		ATuple::const_iterator si = st->begin();
		ATuple::const_iterator ti = tt->begin();
		// Zip the two expressions, turning element pairs into constraints
		for (; si != st->end() && ti != tt->end(); ++si, ++ti) {
			if (is_dots(*si) || is_dots(*ti))
				return unify(cp);
			else
				cp.push_back(Constraint(*si, *ti));
		}
		// One side ran out: accept if both did, or if the side with
		// elements left continues with `...`
		if ((si == st->end() && ti == tt->end())
		    || (si != st->end() && is_dots(*si))
		    || (ti != tt->end() && is_dots(*ti)))
			return unify(cp);
	}
	// No rule applies (mismatched atoms, failed occurs check, or an
	// arity mismatch without `...`): report both locations
	throw Error(s->loc,
			(format("type is `%1%' but should be `%2%'\n%3%: error: to match `%4%' here")
			 % s->str() % t->str() % t->loc.str() % t->str()).str());
}