/* Tuplr Type Inferencing
 * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net>
 *
 * Tuplr is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Tuplr is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Tuplr.  If not, see <http://www.gnu.org/licenses/>.
 */

#include "tuplr.hpp"

/***************************************************************************
 * AST Type Constraints                                                    *
 ***************************************************************************/

/** Constrain a tuple: its type is the tuple of its elements' types, in order.
 * Each element is constrained before its type is looked up.
 */
void
ASTTuple::constrain(TEnv& tenv) const
{
	AType* tupT = new AType(loc, 0);
	for (size_t i = 0; i < size(); ++i) {
		AST* const elem = at(i);
		elem->constrain(tenv);
		tupT->push_back(tenv.type(elem));
	}
	tenv.constrain(this, tupT);
}

/** Constrain a closure: its type is (Fn P B), where P is the type of element 1
 * (the prototype) and B is the type of element 2 (the body).
 */
void
ASTClosure::constrain(TEnv& tenv) const
{
	AST* const prot = at(1);
	prot->constrain(tenv);
	AST* const body = at(2);
	body->constrain(tenv);

	// Look the types up only after both children have been constrained.
	AType* const protT = tenv.type(prot);
	AType* const bodyT = tenv.type(body);
	tenv.constrain(this, new AType(loc, tenv.penv.sym("Fn"), protT, bodyT, 0));
}

/** Constrain a call (f a1 ... an): f must have type (Fn (T1 ... Tn) R), where
 * Ti is the type of argument ai and R is the type of this call expression.
 */
void
ASTCall::constrain(TEnv& tenv) const
{
	const size_t n = size();
	for (size_t i = 0; i < n; ++i)
		at(i)->constrain(tenv);

	AType* const resultT = tenv.type(this);
	AType* const paramsT = new AType(loc, 0);
	for (size_t i = 1; i < n; ++i)
		paramsT->push_back(tenv.type(at(i)));

	AST* const callee = at(0);
	tenv.constrain(callee, new AType(callee->loc, tenv.penv.sym("Fn"), paramsT, resultT, 0));
}

/** Constrain (def NAME VALUE): NAME, VALUE and the definition expression
 * itself are all constrained to one shared type.
 *
 * Throws Error if the form does not have exactly two operands or if NAME is
 * not a symbol.
 */
void
ASTDefinition::constrain(TEnv& tenv) const
{
	if (size() != 3)
		throw Error("`def' requires exactly 2 arguments", exp.loc);

	const bool nameIsSymbol = dynamic_cast<const ASTSymbol*>(at(1)) != 0;
	if (!nameIsSymbol)
		throw Error("`def' name is not a symbol", exp.loc);

	for (size_t i = 0; i < size(); ++i)
		at(i)->constrain(tenv);

	AType* const defT = tenv.type(this);
	tenv.constrain(at(1), defT);
	tenv.constrain(at(2), defT);
}

/** Constrain (if C1 E1 C2 E2 ... ELSE).
 *
 * Every condition Ci must be Bool. Every consequent Ei, and the final ELSE,
 * must have the type of the whole `if' expression. The form therefore has an
 * even number of elements, counting the leading `if'.
 *
 * Throws Error if there are too few operands or the final else is missing.
 */
void
ASTIf::constrain(TEnv& tenv) const
{
	// Multi-clause forms are accepted, so the minimum is 3 operands, not "exactly".
	if (size() < 3)      throw Error("`if' requires at least 3 arguments", exp.loc);
	if (size() % 2 != 0) throw Error("`if' missing final else clause", exp.loc);

	for (size_t i = 0; i < size(); ++i)
		at(i)->constrain(tenv);

	AType* retT = tenv.type(this);
	const size_t last = size() - 1;  // index of the final else expression

	// Condition/consequent pairs at (1,2), (3,4), ... up to before the else.
	for (size_t i = 1; i < last; i += 2) {
		tenv.constrain(at(i), tenv.named("Bool"));
		tenv.constrain(at(i + 1), retT);
	}
	tenv.constrain(at(last), retT);
}

/** Constrain a primitive operation (OP a b ...).
 *
 * - Arithmetic (+ - * /): at least two operands; all operands and the result
 *   share one type.
 * - Binary (%): exactly two operands, both of the result type.
 * - Bitwise (and or xor): exactly two Bool operands, Bool result.
 * - Comparison (= != > >= < <=): exactly two operands of the same type,
 *   Bool result.
 *
 * Throws Error for a non-symbol or unknown operator, or a wrong operand count.
 */
void
ASTPrimitive::constrain(TEnv& tenv) const
{
	// Check the operator before dereferencing it. An unchecked dynamic_cast here
	// would be a null dereference on a malformed form, instead of a located Error.
	ASTSymbol* const sym = (size() > 0) ? dynamic_cast<ASTSymbol*>(at(0)) : 0;
	if (!sym)
		throw Error("primitive name is not a symbol", exp.loc);
	const string n = sym->str();

	enum { ARITHMETIC, BINARY, BITWISE, COMPARISON } type;
	if (n == "+" || n == "-" || n == "*" || n == "/")
		type = ARITHMETIC;
	else if (n == "%")
		type = BINARY;
	else if (n == "and" || n == "or" || n == "xor")
		type = BITWISE;
	else if (n == "=" || n == "!=" || n == ">" || n == ">=" || n == "<" || n == "<=")
		type = COMPARISON;
	else
		throw Error((format("unknown primitive `%1%'") % n).str(), exp.loc);

	FOREACH(const_iterator, p, *this)
		(*p)->constrain(tenv);

	switch (type) {
	case ARITHMETIC:
		if (size() < 3)
			throw Error((format("`%1%' requires at least 2 arguments") % n).str(), exp.loc);
		for (size_t i = 1; i < size(); ++i)
			tenv.constrain(at(i), tenv.type(this));
		break;
	case BINARY:
		if (size() != 3)
			throw Error((format("`%1%' requires exactly 2 arguments") % n).str(), exp.loc);
		tenv.constrain(at(1), tenv.type(this));
		tenv.constrain(at(2), tenv.type(this));
		break;
	case BITWISE:
		if (size() != 3)
			throw Error((format("`%1%' requires exactly 2 arguments") % n).str(), exp.loc);
		tenv.constrain(this,  tenv.named("Bool"));
		tenv.constrain(at(1), tenv.named("Bool"));
		tenv.constrain(at(2), tenv.named("Bool"));
		break;
	case COMPARISON:
		if (size() != 3)
			throw Error((format("`%1%' requires exactly 2 arguments") % n).str(), exp.loc);
		tenv.constrain(this, tenv.named("Bool"));
		tenv.constrain(at(1), tenv.type(at(2)));  // operands must agree
		break;
	default:  // unreachable: every enumerator is handled above
		throw Error((format("unknown primitive `%1%'") % n).str(), exp.loc);
	}
}

/** Constrain (cons A B): the result has type (Pair TA TB), where TA and TB
 * are the types of the two operands.
 *
 * Throws Error unless there are exactly two operands.
 */
void
ASTConsCall::constrain(TEnv& tenv) const
{
	if (size() != 3) throw Error("`cons' requires exactly 2 arguments", loc);

	AType* pairT = new AType(loc, tenv.penv.sym("Pair"), 0);

	// Arity is fixed at two, so handle the head and tail explicitly.
	at(1)->constrain(tenv);
	pairT->push_back(tenv.type(at(1)));
	at(2)->constrain(tenv);
	pairT->push_back(tenv.type(at(2)));

	tenv.constrain(this, pairT);
}

/** Constrain (car P): P must have type (Pair A B) for fresh variables A and B,
 * and the result has type A.
 *
 * Throws Error unless there is exactly one operand.
 */
void
ASTCarCall::constrain(TEnv& tenv) const
{
	if (size() != 2) throw Error("`car' requires exactly 1 argument", loc);

	AST* const arg = at(1);
	arg->constrain(tenv);

	AType* const headT = tenv.var(loc);  // the result type
	AType* const tailT = tenv.var();     // second element, otherwise unconstrained
	tenv.constrain(arg, new AType(arg->loc, tenv.penv.sym("Pair"), headT, tailT, 0));
	tenv.constrain(this, headT);
}

/** Constrain (cdr P): P must have type (Pair A B) for fresh variables A and B,
 * and the result has type B.
 *
 * Throws Error unless there is exactly one operand.
 */
void
ASTCdrCall::constrain(TEnv& tenv) const
{
	if (size() != 2) throw Error("`cdr' requires exactly 1 argument", loc);

	AST* const arg = at(1);
	arg->constrain(tenv);

	AType* const tailT = tenv.var(loc);  // the result type
	AType* const headT = tenv.var();     // first element, otherwise unconstrained
	tenv.constrain(arg, new AType(arg->loc, tenv.penv.sym("Pair"), headT, tailT, 0));
	tenv.constrain(this, tailT);
}


/***************************************************************************
 * Type Inferencing/Substitution                                           *
 ***************************************************************************/

/** Replace, in place, every element of @a tup (recursively through nested
 * tuples) that compares equal to @a from with the pointer @a to.
 *
 * A replaced element is not descended into. A null @a tup is a no-op.
 */
static void
substitute(ASTTuple* tup, AST* from, AST* to)
{
	if (tup) {
		for (size_t i = 0; i < tup->size(); ++i) {
			AST*& elem = tup->at(i);
			if (*elem == *from) {
				elem = to;
				continue;
			}
			substitute(dynamic_cast<ASTTuple*>(elem), from, to);
		}
	}
}

/** Compose two substitutions, delta after gamma (TAPL 22.1.1).
 *
 * The result contains:
 * - each binding X -> T of gamma, with T replaced by delta's binding for T
 *   when T is itself a key of delta;
 * - each binding of delta whose key gamma does not bind.
 */
TSubst
compose(const TSubst& delta, const TSubst& gamma) // TAPL 22.1.1
{
	TSubst composed;

	for (TSubst::const_iterator gi = gamma.begin(); gi != gamma.end(); ++gi) {
		const TSubst::const_iterator di = delta.find(gi->second);
		if (di != delta.end())
			composed.insert(make_pair(gi->first, di->second));
		else
			composed.insert(make_pair(gi->first, gi->second));
	}

	for (TSubst::const_iterator di = delta.begin(); di != delta.end(); ++di) {
		const bool boundByGamma = gamma.find(di->first) != gamma.end();
		if (!boundByGamma)
			composed.insert(*di);
	}

	return composed;
}

/** Replace type @a s with @a t throughout @a constraints.
 *
 * A constraint side that compares equal to @a s is swapped for @a t outright.
 * Occurrences nested inside a side are rewritten in place by substitute().
 */
void
substConstraints(TEnv::Constraints& constraints, AType* s, AType* t)
{
	TEnv::Constraints::iterator c = constraints.begin();
	while (c != constraints.end()) {
		if (*c->first == *s)
			c->first = t;
		if (*c->second == *s)
			c->second = t;
		substitute(c->first, s, t);
		substitute(c->second, s, t);
		++c;
	}
}

/** Solve a list of type constraints by unification (TAPL 22.4), returning the
 * resulting substitution.
 *
 * Throws Error, located at the first remaining constraint, when its two types
 * cannot be made equal. This covers:
 * - an occurs-check failure (a variable bound to a type that contains it);
 * - an arity mismatch;
 * - a type-constructor mismatch.
 */
TSubst
TEnv::unify(const Constraints& constraints) // TAPL 22.4
{
	if (constraints.empty()) return TSubst();
	AType*      s  = constraints.begin()->first;
	AType*      t  = constraints.begin()->second;
	Constraints cp = constraints;
	cp.erase(cp.begin());

	if (*s == *t) {
		return unify(cp);
	} else if (s->var() && !t->contains(s)) {
		substConstraints(cp, s, t);
		return compose(unify(cp), TSubst(s, t));
	} else if (t->var() && !s->contains(t)) {
		substConstraints(cp, t, s);
		return compose(unify(cp), TSubst(t, s));
	} else if (s->kind == AType::EXPR && s->kind == t->kind && s->size() == t->size()) {
		// Decompose structurally: type-valued children must unify pairwise, and
		// every other child (e.g. the constructor symbol in (Fn A R) or
		// (Pair A B)) must be identical. If those children were skipped instead,
		// (Fn A B) would silently unify with (Pair A B).
		bool compatible = true;
		for (size_t i = 0; compatible && i < s->size(); ++i) {
			AType* si = dynamic_cast<AType*>(s->at(i));
			AType* ti = dynamic_cast<AType*>(t->at(i));
			if (si && ti)
				cp.push_back(Constraint(si, ti, si->loc));
			else if (!(*s->at(i) == *t->at(i)))
				compatible = false;
		}
		if (compatible)
			return unify(cp);
	}

	throw Error((format("type is `%1%' but should be `%2%'") % s->str() % t->str()).str(),
			constraints.begin()->loc);
}

void
TEnv::apply(const TSubst& substs)
{
	FOREACH(TSubst::const_iterator, s, substs)
		FOREACH(Frame::iterator, t, front())
			if (*t->second == *s->first)
				t->second = s->second;
			else
				substitute(t->second, s->first, s->second);
}