/* Tuplr: A programming language
 * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net>
 *
 * Tuplr is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Tuplr is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Tuplr.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <cerrno>
#include <cstring>
#include <fstream>
#include <set>
#include <sstream>
#include <stack>
#include "tuplr.hpp"

using namespace std;
using boost::format;

// Out-of-class definition of the static member ASTConsCall::funcs.
Funcs ASTConsCall::funcs;

// File-scope aliases: `err` refers to std::cerr, `out` refers to std::cout.
std::ostream& err(std::cerr);
std::ostream& out(std::cout);


/***************************************************************************
 * S-Expression Lexer :: text -> S-Expressions (SExp)                      *
 ***************************************************************************/

/** Read one S-expression from @p in.
 *
 * Returns either an atom (a bare token, or a string literal including its
 * surrounding quotes) or a parenthesised list of S-expressions.  @p cur is
 * advanced over the consumed characters so it can be used in error messages.
 * At end of input with nothing pending, an empty list built from @p cur is
 * returned; a trailing atom with no delimiter after it is returned first.
 *
 * @throw Error on an unmatched `)', on an unterminated string literal, or on
 *        end of input while a list is still open.
 */
SExp
readExpression(Cursor& cur, std::istream& in)
{
#define PUSH(s, t)  { if (t != "") { s.top().list.push_back(SExp(loc, t)); t = ""; } }
#define YIELD(s, t) { if (s.empty()) { return SExp(loc, t); } else PUSH(s, t) }
	stack<SExp> stk;
	string      tok;
	Cursor      loc; // start of tok
	int         c;   // kept as int so EOF can't collide with a 0xFF byte
	while ((c = in.get()) != '\0') {
		++cur.col;
		if (c == EOF) {
			if (!stk.empty()) throw Error("unexpected end of file", cur);
			// A final atom with no trailing delimiter is still an expression
			if (tok != "") return SExp(loc, tok);
			return SExp(cur);
		}
		const char ch = static_cast<char>(c);

		// '(' and '"' terminate a pending atom ("a(b" is two tokens, not "ab")
		if ((ch == '(' || ch == '"') && tok != "") {
			if (stk.empty()) {
				// The top-level atom is complete; leave ch for the next call
				in.unget();
				--cur.col;
				return SExp(loc, tok);
			}
			PUSH(stk, tok);
		}

		switch (ch) {
		case ';':
			// Skip the rest of the line, but don't spin forever at EOF: the
			// next get() reports EOF again and the loop top handles it.
			while ((c = in.get()) != '\n' && c != EOF) {}
			if (c == EOF) continue;
			// fall through
		case '\n':
			++cur.line; cur.col = 0;
			// fall through
		case ' ': case '\t':
			if (tok != "") YIELD(stk, tok);
			break;
		case '"':
			loc = cur;
			do {
				tok.push_back(static_cast<char>(c));
				if (c == '\n') { ++cur.line; cur.col = 0; } else { ++cur.col; }
				if ((c = in.get()) == EOF)
					throw Error("unterminated string literal", cur);
			} while (c != '"');
			tok.push_back('"');
			// Yield tok itself (not a temporary) so PUSH actually clears it
			YIELD(stk, tok);
			break;
		case '(':
			stk.push(SExp(cur));
			break;
		case ')':
			switch (stk.size()) {
			case 0:
				throw Error("unexpected `)'", cur);
			case 1:
				PUSH(stk, tok);
				return stk.top();
			default:
				PUSH(stk, tok);
				SExp l = stk.top();
				stk.pop();
				stk.top().list.push_back(l);
			}
			break;
		default:
			if (tok == "") loc = cur;
			tok += ch;
		}
	}
	// Reached only on a NUL character in the input
	switch (stk.size()) {
	case 0:  return SExp(loc, tok);
	case 1:  return stk.top();
	default: throw  Error("missing `)'", cur);
	}
	return SExp(cur);
}


/***************************************************************************
 * Standard Definitions                                                    *
 ***************************************************************************/

/** Populate the parse environment @p penv and type environment @p tenv with
 * the language's built-in types, boolean literals, special forms and
 * primitive operators.
 */
void
initLang(PEnv& penv, TEnv& tenv)
{
	// Built-in types: each name is bound to an AType of the same symbol
	const char* const typeNames[] = { "Bool", "Int", "Float" };
	for (size_t t = 0; t < sizeof(typeNames) / sizeof(typeNames[0]); ++t)
		tenv.def(penv.sym(typeNames[t]), new AType(penv.sym(typeNames[t])));

	// Boolean literals; statics so the registered pointers stay valid
	static bool trueVal  = true;
	static bool falseVal = false;
	penv.reg(false, "#t", PEnv::Handler(parseLiteral<bool>, &trueVal));
	penv.reg(false, "#f", PEnv::Handler(parseLiteral<bool>, &falseVal));

	// Special forms, each with its own parser
	penv.reg(true, "fn",   PEnv::Handler(parseFn));
	penv.reg(true, "if",   PEnv::Handler(parseCall<ASTIf>));
	penv.reg(true, "def",  PEnv::Handler(parseCall<ASTDefinition>));
	penv.reg(true, "cons", PEnv::Handler(parseCall<ASTConsCall>));
	penv.reg(true, "car",  PEnv::Handler(parseCall<ASTCarCall>));
	penv.reg(true, "cdr",  PEnv::Handler(parseCall<ASTCdrCall>));

	// Numeric/bitwise/comparison primitives all share the ASTPrimitive parser
	const char* const primitives[] = {
		"+", "-", "*", "/", "%", "&", "|", "^",
		"=", "!=", ">", ">=", "<", "<="
	};
	for (size_t p = 0; p < sizeof(primitives) / sizeof(primitives[0]); ++p)
		penv.reg(true, primitives[p], PEnv::Handler(parseCall<ASTPrimitive>));
}


/***************************************************************************
 * EVAL/REPL/MAIN                                                          *
 ***************************************************************************/

/** Write the command-line help text for program @p name.
 *
 * The text goes to stderr when @p error is set, otherwise to stdout.
 * @return the exit status to use: 1 for an error, 0 otherwise.
 */
int
print_usage(char* name, bool error)
{
	std::ostream& stream = error ? std::cerr : std::cout;
	stream << "Usage: " << name << " [OPTION]... [FILE]..." << std::endl
	       << "Evaluate and/or compile Tuplr code" << std::endl
	       << std::endl
	       << "  -h               Display this help and exit"        << std::endl
	       << "  -r               Enter REPL after evaluating files" << std::endl
	       << "  -e EXPRESSION    Evaluate EXPRESSION"               << std::endl
	       << "  -o FILE          Write assembly output to FILE"     << std::endl;
	if (error)
		return 1;
	return 0;
}

/** Tuplr command-line driver.
 *
 * Options: "-h" prints usage and exits; "-r" enters the REPL after evaluating
 * files; "-e EXPRESSION" evaluates EXPRESSION; "-o FILE" writes the compile
 * environment's output to FILE.  Every argument not starting with '-' is a
 * source file, evaluated in order.  With no files and no -e, the REPL runs.
 *
 * @return non-zero on a usage error or on evaluation/I/O failures.  If the
 *         REPL runs, its result replaces the status accumulated before it.
 */
int
main(int argc, char** argv)
{
	PEnv penv;
	TEnv tenv(penv);
	initLang(penv, tenv);

	CEnv* cenv = newCenv(penv, tenv);

	map<string,string> args;
	list<string>       files;
	for (int i = 1; i < argc; ++i) {
		if (!strcmp(argv[i], "-h")) {
			return print_usage(argv[0], false);
		} else if (argv[i][0] != '-') {
			files.push_back(argv[i]);
		} else if (!strcmp(argv[i], "-r")) {
			args.insert(make_pair(argv[i], ""));
		} else if (strcmp(argv[i], "-e") && strcmp(argv[i], "-o")) {
			// Unknown option: reject it instead of silently swallowing the
			// following argument (which may well be a source file).
			return print_usage(argv[0], true);
		} else if (i == argc-1 || argv[i+1][0] == '-') {
			// -e / -o given without a value
			return print_usage(argv[0], true);
		} else {
			args.insert(make_pair(argv[i], argv[i+1]));
			++i;
		}
	}

	int ret = 0;
	map<string,string>::iterator a = args.find("-e");
	if (a != args.end()) {
		istringstream is(a->second);
		ret = eval(*cenv, "(command line)", is);
	}

	for (list<string>::iterator f = files.begin(); f != files.end(); ++f) {
		ifstream is(f->c_str());
		if (is.good()) {
			ret = ret | eval(*cenv, *f, is);
		} else {
			std::cerr << argv[0] << ": " << *f << ": " << strerror(errno) << endl;
			++ret;
		}
		is.close();
	}

	if (args.find("-r") != args.end() || (files.empty() && args.find("-e") == args.end()))
		ret = repl(*cenv);

	a = args.find("-o");
	if (a != args.end()) {
		std::ofstream os(a->second.c_str());
		if (os.good()) {
			cenv->write(os);
		} else {
			cerr << argv[0] << ": " << a->second << ": " << strerror(errno) << endl;
			++ret;
		}
		os.close();
	}

	delete cenv;
	return ret;
}