/* Resp: A programming language
 * Copyright (C) 2008-2009 David Robillard <http://drobilla.net>
 *
 * Resp is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Resp is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Resp.  If not, see <http://www.gnu.org/licenses/>.
 */

/** @file
 * @brief Unify type constraints
 */

#include <set>
#include "resp.hpp"

/** Build a type substitution for calling a generic function type
 * with a specific set of argument types.
 *
 * @param genericT  Generic function type; element 1 of this tuple is used
 *                  as the generic parameter type tuple (presumably the
 *                  function prototype -- TODO confirm against callers).
 * @param argsT     Tuple of concrete argument types, in parameter order.
 * @return          A substitution mapping each generic parameter type to the
 *                  corresponding argument type.  When an argument type is a
 *                  type expression, its elements are instead matched
 *                  pairwise against the elements of the generic parameter,
 *                  skipping null elements on either side.
 *
 * Matching stops at the end of the shorter sequence (at both the argument
 * level and the element level), so a count mismatch never walks past the
 * end of a tuple.
 */
Subst
TEnv::buildSubst(const AST* genericT, const AST& argsT)
{
	Subst subst;

	const ATuple* const genericProtT = genericT->as_tuple()->list_ref(1)->as_tuple();
	const ATuple* const argsTup      = argsT.as_tuple();

	ATuple::const_iterator g = genericProtT->begin();
	ATuple::const_iterator a = argsTup->begin();
	for (; a != argsTup->end() && g != genericProtT->end(); ++a, ++g) {
		if (AType::is_expr(*a)) {
			assert(AType::is_expr(*g));
			const ATuple* const genTup = (*g)->as_tuple();
			const ATuple* const argTup = (*a)->as_tuple();
			ATuple::const_iterator gi = genTup->begin();
			ATuple::const_iterator ci = argTup->begin();
			for (; gi != genTup->end() && ci != argTup->end(); ++gi, ++ci) {
				// Null elements carry nothing to map; skip them
				if ((*gi) && (*ci))
					subst.add(*gi, *ci);
			}
		} else {
			subst.add(*g, *a);
		}
	}

	return subst;
}

/** Append a constraint equating the type variable that @a tenv assigns to
 * @a o (via TEnv::var) with the type @a t.
 *
 * Both @a o and @a t must be non-null.
 */
void
Constraints::constrain(TEnv& tenv, const AST* o, const AST* t)
{
	assert(o);
	assert(t);
	const Constraint constraint(tenv.var(o), t);
	push_back(constraint);
}

/** Return type expression @a in with occurrences of @a from replaced by @a to.
 *
 * Elements of @a in that are structurally equal to @a from become @a to;
 * nested type expressions are processed recursively.  Null elements are
 * dropped from the result.
 *
 * @return @a to if @a in is the very same node as @a from; @a in itself if
 *         it is not a tuple or is an empty tuple (nothing to substitute
 *         into); otherwise the head of a newly built list.
 */
static const AST*
substitute(const AST* in, const AST* from, const AST* to)
{
	if (in == from)
		return to;

	const ATuple* tup = in->to_tuple();
	if (!tup || !tup->fst())
		return in;  // Leaf or empty expression: leave it untouched

	List ret;
	for (const auto& i : *tup) {
		if (!i) {
			continue;
		} else if (*i == *from) {
			ret.push_back(to); // FIXME: should be a copy w/ (*i)->loc
		} else if (i != to && AType::is_expr(i)) {
			ret.push_back(substitute(i, from, to));
		} else {
			ret.push_back(i);
		}
	}
	return ret.head;
}

/** Merge @a subst into this substitution without overriding.
 *
 * Each mapping of @a subst is added only if its key is not already
 * mapped here, so existing mappings take precedence.
 */
void
Subst::augment(const Subst& subst)
{
	for (const auto& mapping : subst) {
		if (contains(mapping.first))
			continue;
		add(mapping.first, mapping.second);
	}
}

/** Compose two substitutions, yielding delta . gamma (TAPL 22.1.1).
 *
 * For every mapping X -> T in @a gamma, the result maps X to the image of T
 * in @a delta when T is itself a key of @a delta, and to T otherwise.  Every
 * mapping of @a delta whose key is not a key of @a gamma is then added
 * unchanged.
 *
 * NOTE(review): delta is applied to T only by exact key lookup
 * (Subst::find), not recursively into compound type expressions as in
 * TAPL's definition -- confirm this is intended.
 */
Subst
Subst::compose(const Subst& delta, const Subst& gamma)
{
	Subst result;

	for (const auto& mapping : gamma) {
		Subst::const_iterator hit = delta.find(mapping.second);
		if (hit != delta.end())
			result.add(mapping.first, hit->second);
		else
			result.add(mapping.first, mapping.second);
	}

	for (const auto& mapping : delta) {
		const bool shadowed = (gamma.find(mapping.first) != gamma.end());
		if (!shadowed)
			result.add(mapping.first, mapping.second);
	}

	return result;
}

/** Replace all occurrences of @a s with @a t in every constraint.
 *
 * Each side of each constraint is handled the same way.  A side that is
 * structurally equal to @a s is replaced by @a t outright.  Otherwise, if the
 * side is a non-empty type expression, occurrences inside it are substituted
 * recursively.
 *
 * @return *this, to allow chaining.
 */
Constraints&
Constraints::replace(const AST* s, const AST* t)
{
	for (auto& c : *this) {
		// Both sides use the same guard: empty expressions contain nothing
		// to replace and must not be handed to substitute().
		if (*c.first == *s) {
			c.first = t; // FIXME: should be copy w/ c.first->loc;
		} else if (AType::is_expr(c.first) && c.first->to_tuple()->fst()) {
			c.first = substitute(c.first, s, t);
		}
		if (*c.second == *s) {
			c.second = t; // FIXME: should be copy w/ c.second->loc;
		} else if (AType::is_expr(c.second) && c.second->to_tuple()->fst()) {
			c.second = substitute(c.second, s, t);
		}
	}
	return *this;
}

/** Unify a type constraint set (TAPL 22.4).
 *
 * Takes the first constraint (s, t) and handles it by cases:
 *  - s and t structurally equal: drop the constraint.
 *  - s is a type variable not contained in t (list_contains occurs check):
 *    bind s to t, replace s throughout the remaining constraints, unify
 *    those, and compose the result with the binding.  The same applies
 *    symmetrically when t is the variable.
 *  - s and t are both type expressions: add a constraint for each pair of
 *    corresponding elements and unify.  An ellipsis (is_dots) on either
 *    side ends element-wise matching.  Expressions of different lengths
 *    only unify if the longer one continues with an ellipsis.
 *
 * @return The substitution accumulated from the bindings above.
 * @throw Error if a paired element is missing or the types do not match.
 */
Subst
unify(const Constraints& constraints)
{
	if (constraints.empty())
		return Subst();

	Constraints::const_iterator i  = constraints.begin();
	const AST*                  s  = i->first;
	const AST*                  t  = i->second;
	Constraints                 cp(++i, constraints.end());  // Remaining constraints

	if (*s == *t) {
		return unify(cp);
	} else if (AType::is_var(s) && !list_contains(t->to_tuple(), s)) {
		// Eliminate s from the rest, then compose with the binding s := t
		return Subst::compose(unify(cp.replace(s, t)), Subst(s, t));
	} else if (AType::is_var(t) && !list_contains(s->to_tuple(), t)) {
		return Subst::compose(unify(cp.replace(t, s)), Subst(t, s));
	} else if (AType::is_expr(s) && AType::is_expr(t)) {
		const ATuple* const st = s->as_tuple();
		const ATuple* const tt = t->as_tuple();
		ATuple::const_iterator si = st->begin();
		ATuple::const_iterator ti = tt->begin();
		for (; si != st->end() && ti != tt->end(); ++si, ++ti) {
			if (is_dots(*si) || is_dots(*ti))
				return unify(cp);
			else if (*si && *ti)
				cp.push_back(Constraint(*si, *ti));
			else
				// Report at s's location rather than an empty cursor
				throw Error(s->loc, "match with missing list element");
		}
		// Equal lengths, or the longer side continues with an ellipsis
		if ((si == st->end() && ti == tt->end())
		    || (si != st->end() && is_dots(*si))
		    || (ti != tt->end() && is_dots(*ti)))
			return unify(cp);
	}
	throw Error(s->loc,
			(format("type is `%1%' but should be `%2%'\n%3%: error: to match `%4%' here")
			 % s->str() % t->str() % t->loc.str() % t->str()).str());
}