/* Tuplr: A programming language
 * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net>
 *
 * Tuplr is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Tuplr is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Tuplr.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <cerrno>
#include <cstring>
#include <fstream>
#include <set>
#include <sstream>
#include <stack>
#include "tuplr.hpp"

using namespace std;
using boost::format;

// Definition of the static member Object::pool (a GC). main() locks it, and
// eval()/repl() register roots with it and call collect() on it.
GC Object::pool;

/* Write @a exp to @a out in s-expression syntax.
 *
 * An atom is written as-is; a list is written as "(", its elements separated
 * by single spaces, then ")".  Any other expression type writes nothing.
 */
template<typename Atom>
ostream&
operator<<(ostream& out, const Exp<Atom>& exp)
{
	if (exp.type == Exp<Atom>::ATOM) {
		out << exp.atom;
	} else if (exp.type == Exp<Atom>::LIST) {
		out << "(";
		for (size_t n = 0; n < exp.size(); ++n) {
			if (n != 0)
				out << " ";  // separator goes before every element but the first
			out << exp.at(n);
		}
		out << ")";
	}
	return out;
}


/***************************************************************************
 * Lexer                                                                   *
 ***************************************************************************/

/* Read one character (or EOF) from @a in and advance @a cur past it.
 *
 * A newline moves the cursor to column 0 of the next line; any other result,
 * including EOF, advances the column by one.
 */
inline int
readChar(Cursor& cur, istream& in)
{
	const int next = in.get();
	if (next == '\n') {
		cur.col = 0;
		++cur.line;
	} else {
		++cur.col;
	}
	return next;
}

/* Read a single expression (an atom or a parenthesised list) from @a in.
 *
 * @a cur tracks the position of the last character read and is used for
 * source locations and error reporting.  Whitespace, `;' line comments and
 * `#| ... |#' block comments separate tokens; a "..." string literal is read
 * as one atom, quotes included.  When the input ends before any expression
 * starts, an empty list located at @a cur is returned, which callers treat
 * as end of input.
 *
 * Throws Error on an unmatched `)', or when the input ends inside a list,
 * a string literal or a block comment.
 */
SExp
readExpression(Cursor& cur, istream& in)
{
// Append pending token t (if any) to the innermost open list and clear it.
#define PUSH(s, t)  { if (t != "") { s.top().push_back(SExp(loc, t)); t = ""; } }
// Return pending token t as a top-level atom, or otherwise PUSH it.
#define YIELD(s, t) { if (s.empty()) { return SExp(loc, t); } else PUSH(s, t) }
// Finish pending token t before a character that starts a new expression.
// At top level that character is put back so the next call begins with it.
#define FLUSH(s, t) { if (t != "") { if (s.empty()) { in.unget(); --cur.col; return SExp(loc, t); } else PUSH(s, t) } }
	stack<SExp> stk;
	string      tok;
	Cursor      loc; // start of tok
	while (int c = readChar(cur, in)) {
		switch (c) {
		case EOF:
			THROW_IF(!stk.empty(), cur, "unexpected end of file")
			if (tok != "")  // a final atom need not be followed by whitespace
				return SExp(loc, tok);
			return SExp(cur);
		case ';':
			// Line comment: skip to end of line (or of input), then act as whitespace
			while (c != '\n' && c != EOF)
				c = readChar(cur, in);
			// fall through
		case '\n': case ' ': case '\t':
			if (tok != "") YIELD(stk, tok);
			break;
		case '"':
			FLUSH(stk, tok);
			loc = cur;
			do {
				tok.push_back(c);
				c = readChar(cur, in);
				THROW_IF(c == EOF, loc, "unterminated string literal")
			} while (c != '"');
			tok.push_back('"');  // closing quote; tok itself is cleared by PUSH
			YIELD(stk, tok);
			break;
		case '(':
			FLUSH(stk, tok);
			stk.push(SExp(cur));
			break;
		case ')':
			switch (stk.size()) {
			case 0:
				throw Error(cur, "unexpected `)'");
			case 1:
				PUSH(stk, tok);
				return stk.top();
			default: {
				// Close the innermost list and append it to its parent
				PUSH(stk, tok);
				SExp l = stk.top();
				stk.pop();
				stk.top().push_back(l);
			}
			}
			break;
		case '#':
			if (in.peek() == '|') {
				// Block comment: skip everything up to and including the next "|#"
				int prev = 0;
				do {
					prev = c;
					c = readChar(cur, in);
					THROW_IF(c == EOF, cur, "unterminated block comment")
				} while (!(prev == '|' && c == '#'));
				break;
			}
			// fall through (e.g. #t, #f are ordinary atoms)
		default:
			if (tok == "") loc = cur;
			tok += c;
		}
	}
	// Only reached if a NUL character stops the loop above
	switch (stk.size()) {
	case 0:  return SExp(loc, tok);
	case 1:  return stk.top();
	default: throw  Error(cur, "missing `)'");
	}
	return SExp(cur);
#undef FLUSH
#undef YIELD
#undef PUSH
}


/***************************************************************************
 * Macro Functions                                                         *
 ***************************************************************************/

/* Macro expander for `def'.
 *
 * (def name value ...) is returned unchanged.  The function shorthand
 * (def (f x ...) body ...) is rewritten to (def f (fn (x ...) body ...)).
 * Throws Error when fewer than 2 arguments are given, or when the shorthand
 * form has an empty name list (which would otherwise index past its end).
 */
inline SExp
macDef(PEnv& penv, const SExp& exp)
{
	THROW_IF(exp.size() < 3, exp.loc, "[MAC] `def' requires at least 2 arguments")
	const SExp& head = exp.at(1);
	if (head.type == SExp::ATOM)
		return exp;

	// (def (f x) y) => (def f (fn (x) y))
	THROW_IF(head.empty(), head.loc, "[MAC] `def' requires a name")

	SExp params(exp.loc);
	for (size_t i = 1; i < head.size(); ++i)
		params.push_back(head.at(i));

	SExp fn(exp.at(2).loc);
	fn.push_back(SExp(exp.at(2).loc, "fn"));
	fn.push_back(params);
	for (size_t i = 2; i < exp.size(); ++i)
		fn.push_back(exp.at(i));

	SExp ret(exp.loc);
	ret.push_back(exp.at(0));
	ret.push_back(head.at(0));
	ret.push_back(fn);
	return ret;
}


/***************************************************************************
 * Parser Functions                                                        *
 ***************************************************************************/

/* Generic parser handler for call-like forms.
 *
 * Parses @a exp as a tuple via @a penv and wraps it, together with the
 * original expression, in a newly allocated node of type C.  The handler
 * user data @a arg is not used.
 */
template<typename C>
inline AST*
parseCall(PEnv& penv, const SExp& exp, void* arg)
{
	C* const node = new C(exp, penv.parseTuple(exp));
	return node;
}

/* Parser handler for literal keywords such as `#t' and `#f'.
 *
 * @a arg is the handler's user data and must point to a T holding the
 * literal's value; that value is copied into a new ALiteral located at
 * @a exp.  @a penv is not used.
 */
template<typename T>
inline AST*
parseLiteral(PEnv& penv, const SExp& exp, void* arg)
{
	// arg round-trips a T* through void*, for which static_cast is the
	// correct (and sufficient) conversion; the value is only read.
	return new ALiteral<T>(*static_cast<const T*>(arg), exp.loc);
}

/* Parser handler for the `fn' special form: (fn (params ...) body ...).
 *
 * Builds an AFn whose children are the `fn' symbol, a tuple of the parsed
 * parameter list, then each parsed body expression in order.  Throws Error
 * when the parameter list or the body is missing.  @a arg is not used.
 */
inline AST*
parseFn(PEnv& penv, const SExp& exp, void* arg)
{
	if (exp.size() < 3)
		throw Error(exp.loc, (exp.size() < 2)
		            ? "Missing function parameters and body"
		            : "Missing function body");

	SExp::const_iterator it = exp.begin();
	++it;  // skip the `fn' keyword; `it' now refers to the parameter list
	AFn* fn = tup<AFn>(exp.loc, penv.sym("fn"), new ATuple(penv.parseTuple(*it)), 0);
	while (++it != exp.end())
		fn->push_back(penv.parse(*it));
	return fn;
}


/***************************************************************************
 * Standard Definitions                                                    *
 ***************************************************************************/

/* Populate @a penv and @a tenv with the built-in language.
 *
 * Registers the primitive types (each bound to an AType of the same name
 * and no value), the #t/#f literals, the `def' macro, the fn/if/def special
 * forms, and the primitive operators, in that order.
 */
void
initLang(PEnv& penv, TEnv& tenv)
{
	// Types
	static const char* const typeNames[] = { "Nothing", "Bool", "Int", "Float" };
	for (size_t i = 0; i < sizeof(typeNames) / sizeof(typeNames[0]); ++i)
		tenv.def(penv.sym(typeNames[i]), make_pair((AST*)0, new AType(penv.sym(typeNames[i]))));

	// Literals: the handler's user data points at the literal's value
	static bool trueVal  = true;
	static bool falseVal = false;
	penv.reg(false, "#t", PEnv::Handler(parseLiteral<bool>, &trueVal));
	penv.reg(false, "#f", PEnv::Handler(parseLiteral<bool>, &falseVal));

	// Macros
	penv.defmac("def", macDef);

	// Special forms
	penv.reg(true, "fn",  PEnv::Handler(parseFn));
	penv.reg(true, "if",  PEnv::Handler(parseCall<AIf>));
	penv.reg(true, "def", PEnv::Handler(parseCall<ADef>));

	// Numeric primitives: all share the generic APrimitive call parser
	static const char* const primNames[] = {
		"+", "-", "*", "/", "%",
		"and", "or", "xor",
		"=", "!=", ">", ">=", "<", "<="
	};
	for (size_t i = 0; i < sizeof(primNames) / sizeof(primNames[0]); ++i)
		penv.reg(true, primNames[i], PEnv::Handler(parseCall<APrimitive>));
}


/***************************************************************************
 * EVAL/REPL                                                               *
 ***************************************************************************/

/* Evaluate every expression read from @a is as a single program.
 *
 * Each expression is parsed, type-constrained, has its substitution composed
 * into cenv.tsubst, and is lifted.  Definitions and all types in the
 * substitution are registered as GC roots.  At end of input, if the last
 * expression's type is concrete, all expressions are compiled into a `main'
 * function, which is called and its result printed.  The result type is
 * printed in either case.  Empty input (e.g. only comments) compiles and
 * prints nothing.  With the "-d" argument the engine's module is then
 * written to cenv.out.
 *
 * @param name  Source name used in diagnostics.
 * @return 0 on success, 1 if an Error was reported to cenv.err.
 */
int
eval(CEnv& cenv, const string& name, istream& is)
{
	AType* resultType = NULL;
	list< pair<SExp, AST*> > exprs;
	Cursor cursor(name);
	try {
		while (true) {
			SExp exp = readExpression(cursor, is);
			if (exp.type == SExp::LIST && exp.empty())
				break;

			AST* result = cenv.penv.parse(exp); // Parse input
			Constraints c;
			result->constrain(cenv.tenv, c); // Constrain types
			cenv.tsubst = Subst::compose(cenv.tsubst, TEnv::unify(c)); // Solve type constraints
			resultType = cenv.type(result);
			result->lift(cenv); // Lift functions
			exprs.push_back(make_pair(exp, result));

			// Add definitions as GC roots
			if (result->to<ADef*>())
				cenv.lock(result);

			// Add types in type substitution as GC roots
			for (Subst::iterator i = cenv.tsubst.begin(); i != cenv.tsubst.end(); ++i) {
				Object::pool.addRoot(i->first);
				Object::pool.addRoot(i->second);
			}
		}

		// With no expressions there is no result (and resultType is NULL)
		if (!exprs.empty()) {
			THROW_IF(!resultType, cursor, "call to untyped body")

			// Print CPS form
			CValue val = NULL;
			/*for (list< pair<SExp, AST*> >::const_iterator i = exprs.begin(); i != exprs.end(); ++i) {
				cout << "; CPS" << endl;
				pprint(cout, i->second->cps(cenv.tenv, cenv.penv.sym("cont")));
			}*/

			if (resultType->concrete()) {
				// Create function for top-level of program
				CFunction f = cenv.engine()->startFunction(cenv, "main", resultType, ATuple(cursor));

				// Compile all expressions into it
				for (list< pair<SExp, AST*> >::const_iterator i = exprs.begin(); i != exprs.end(); ++i)
					val = cenv.compile(i->second);

				// Finish and call it
				cenv.engine()->finishFunction(cenv, f, resultType, val);
				cenv.out << cenv.engine()->call(cenv, f, resultType);
			}
			cenv.out << " : " << resultType << endl;
		}

		Object::pool.collect(Object::pool.roots());

		if (cenv.args.find("-d") != cenv.args.end())
			cenv.engine()->writeModule(cenv, cenv.out);

	} catch (Error& e) {
		cenv.err << e.what() << endl;
		return 1;
	}
	return 0;
}

/* Interactive read-eval-print loop on std::cin.
 *
 * Each expression is parsed, typed, lifted, compiled into a fresh anonymous
 * function and called.  Its value is printed with its type; if compiling or
 * calling fails, the defined symbol (or "?") is printed instead.
 * Definitions are kept as GC roots.  The type substitution is rolled back
 * after every expression, whether or not it succeeded.  Errors are reported
 * to cenv.err and the loop continues until end of input.
 *
 * @return Always 0.
 */
int
repl(CEnv& cenv)
{
	while (1) {
		cenv.out << "() ";
		cenv.out.flush();
		Cursor cursor("(stdin)");

		// The substitution computed for one expression is temporary; it is
		// restored on both the success and the error path below.
		Subst oldSubst = cenv.tsubst;

		try {
			SExp exp = readExpression(cursor, std::cin);
			if (exp.type == SExp::LIST && exp.empty())
				break;

			AST* body = cenv.penv.parse(exp); // Parse input
			Constraints c;
			body->constrain(cenv.tenv, c); // Constrain types

			cenv.tsubst = Subst::compose(cenv.tsubst, TEnv::unify(c)); // Solve type constraints

			AType* bodyT = cenv.type(body);
			THROW_IF(!bodyT, cursor, "call to untyped body")

			body->lift(cenv);

			CFunction f = NULL;
			try {
				// Create anonymous function to insert code into
				f = cenv.engine()->startFunction(cenv, cenv.penv.gensymstr("_repl"), bodyT, ATuple(cursor));
				CValue retVal = cenv.compile(body);
				cenv.engine()->finishFunction(cenv, f, bodyT, retVal);
				cenv.out << cenv.engine()->call(cenv, f, bodyT);
			} catch (Error& e) {
				// Show the defined symbol (or "?") in place of a value and
				// discard any partially built function.
				ADef* def = body->to<ADef*>();
				if (def)
					cenv.out << def->sym();
				else
					cenv.out << "?";
				if (f)  // startFunction itself may have thrown
					cenv.engine()->eraseFunction(cenv, f);
			}
			cenv.out << " : " << cenv.type(body) << endl;

			// Add definitions as GC roots
			if (body->to<ADef*>())
				cenv.lock(body);

			Object::pool.collect(Object::pool.roots());

			cenv.tsubst = oldSubst;
			if (cenv.args.find("-d") != cenv.args.end())
				cenv.engine()->writeModule(cenv, cenv.out);

		} catch (Error& e) {
			cenv.tsubst = oldSubst;
			cenv.err << e.what() << endl;
		}
	}
	return 0;
}


/***************************************************************************
 * MAIN                                                                    *
 ***************************************************************************/

/* Print command-line usage for the program invoked as @a name.
 *
 * The text goes to stderr when @a error is true, otherwise to stdout.
 * Returns the exit status main() should use: 1 for an error, 0 otherwise.
 */
int
print_usage(char* name, bool error)
{
	ostream& out = error ? cerr : cout;
	out << "Usage: " << name << " [OPTION]... [FILE]..." << endl
	    << "Evaluate and/or compile Tuplr code" << endl
	    << endl
	    << "  -h               Display this help and exit" << endl
	    << "  -r               Enter REPL after evaluating files" << endl
	    << "  -p               Pretty-print input only" << endl
	    << "  -g               Debug (disable optimisation)" << endl
	    << "  -d               Dump assembly output" << endl
	    << "  -e EXPRESSION    Evaluate EXPRESSION" << endl
	    << "  -o FILE          Write output to FILE" << endl;
	if (error)
		return 1;
	return 0;
}

int
main(int argc, char** argv)
{
	PEnv penv;
	TEnv tenv(penv);
	initLang(penv, tenv);

	Engine* engine = tuplr_new_engine();
	CEnv*   cenv   = new CEnv(penv, tenv, engine);

	cenv->push();
	Object::pool.lock();

	map<string,string> args;
	list<string>       files;
	for (int i = 1; i < argc; ++i) {
		if (!strncmp(argv[i], "-h", 3)) {
			return print_usage(argv[0], false);
		} else if (argv[i][0] != '-') {
			files.push_back(argv[i]);
		} else if (!strncmp(argv[i], "-r", 3)
				|| !strncmp(argv[i], "-p", 3)
				|| !strncmp(argv[i], "-g", 3)
				|| !strncmp(argv[i], "-d", 3)) {
			args.insert(make_pair(argv[i], ""));
		} else if (i == argc-1 || argv[i+1][0] == '-') {
			return print_usage(argv[0], true);
		} else {
			args.insert(make_pair(argv[i], argv[i+1]));
			++i;
		}
	}

	cenv->args = args;

	int ret = 0;

	string output;
	map<string,string>::const_iterator a = args.find("-o");
	if (a != args.end())
		output = a->second;

	a = args.find("-p");
	if (a != args.end()) {
		ifstream is(files.front().c_str());
		if (is.good()) {
			Cursor loc;
			SExp exp = readExpression(loc, is);
			AST* ast = penv.parse(exp);
			pprint(cout, ast);
		}
		return 0;
	}

	a = args.find("-e");
	if (a != args.end()) {
		istringstream is(a->second);
		ret = eval(*cenv, "(command line)", is);
	}

	for (list<string>::iterator f = files.begin(); f != files.end(); ++f) {
		ifstream is(f->c_str());
		if (is.good()) {
			ret = ret | eval(*cenv, *f, is);
		} else {
			cerr << argv[0] << ": " << *f << ": " << strerror(errno) << endl;
			++ret;
		}
		is.close();
	}

	if (args.find("-r") != args.end() || (files.empty() && args.find("-e") == args.end()))
		ret = repl(*cenv);

	if (output != "") {
		ofstream os(output.c_str());
		if (os.good()) {
			cenv->engine()->writeModule(*cenv, os);
		} else {
			cerr << argv[0] << ": " << a->second << ": " << strerror(errno) << endl;
			++ret;
		}
		os.close();
	}

	delete cenv;
	tuplr_free_engine(engine);

	return ret;
}