/* Tuplr: A programming language
 * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net>
 *
 * Tuplr is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Tuplr is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Tuplr.  If not, see <http://www.gnu.org/licenses/>.
 */

#ifndef TUPLR_HPP
#define TUPLR_HPP

#include <assert.h>
#include <ctype.h>
#include <stdarg.h>
#include <stdint.h>
#include <stdlib.h>
#include <iostream>
#include <list>
#include <map>
#include <string>
#include <vector>
#include <boost/format.hpp>

/// Loop `i`, an iterator of type `IT`, over container `c` from begin() to end().
/// `c` appears twice in the expansion, so it is re-evaluated on every loop test.
#define FOREACH(IT, i, c) for (IT i = (c).begin(); i != (c).end(); ++i)

// This header uses standard library names and boost::format unqualified
// throughout; note that these using-declarations also apply to every file
// that includes it.
using namespace std;
using boost::format;


/***************************************************************************
 * Basic Utility Classes                                                   *
 ***************************************************************************/

struct Cursor {
	Cursor(const string& n="", unsigned l=1, unsigned c=0) : name(n), line(l), col(c) {}
	operator bool() const { return !(line == 1 && col == 0); }
	string str() const { return (format("%1%:%2%:%3%") % name % line % col).str(); }
	string   name;
	unsigned line;
	unsigned col;
};

/// Error value (thrown throughout this header): a message plus an optional location
struct Error {
	Error(const string& m, Cursor c=Cursor()) : msg(m), loc(c) {}

	/// "<loc>: error: <msg>", or just "error: <msg>" when loc tests false
	const string what() const throw() {
		string text;
		if (loc)
			text = loc.str() + ": ";
		text += "error: ";
		text += msg;
		return text;
	}

	string msg; ///< Description of the problem
	Cursor loc; ///< Where it occurred; a default Cursor is omitted from what()
};

/// A pair of output streams, held by reference
struct Log {
	Log(ostream& o, ostream& e)
		: out(o)
		, err(e)
	{}

	ostream& out; ///< Bound to the first constructor argument
	ostream& err; ///< Bound to the second constructor argument
};

/// Generic expression node: either a single atom or a list of sub-expressions
template<typename Atom>
struct Exp {
	typedef std::vector< Exp<Atom> > List;

	/// Construct a list expression (initially empty) located at c
	Exp(Cursor c)
		: type(LIST)
		, loc(c)
	{}

	/// Construct an atom expression holding a, located at c
	Exp(Cursor c, const Atom& a)
		: type(ATOM)
		, loc(c)
		, atom(a)
	{}

	enum { ATOM, LIST } type; ///< Set by the constructor used
	Cursor loc;               ///< Location passed to the constructor
	Atom   atom;              ///< Set by the atom constructor only
	List   list;              ///< Sub-expressions; empty after construction
};


/***************************************************************************
 * Lexer: Text (istream) -> S-Expressions (SExp) (Prefix S for Syntactic)  *
 ***************************************************************************/

typedef Exp<string> SExp; ///< Textual S-Expression

/// Read one expression from `in` (defined outside this header).
/// NOTE(review): `cur` is taken by non-const reference, so it is presumably
/// advanced as input is consumed -- confirm against the definition.
SExp readExpression(Cursor& cur, std::istream& in);


/***************************************************************************
 * Backend (Prefix C for Compiled)                                         *
 ***************************************************************************/

// Backend handles: this header only passes these around and never looks
// inside them.
typedef void* CValue;    ///< Compiled value (opaque)
typedef void* CFunction; ///< Compiled function (opaque)
struct        CEngine;   ///< Backend data (opaque)


/***************************************************************************
 * Abstract Syntax Tree (Prefix A for Abstract)                            *
 ***************************************************************************/

// Forward declarations: the AST classes below take these by reference; the
// full definitions appear later in this header.
struct TEnv; ///< Type-Time Environment
struct CEnv; ///< Compile-Time Environment

/// Base class for all AST nodes
struct AST {
	AST(Cursor c=Cursor()) : loc(c) {}
	virtual ~AST() {}

	/// Textual form of this node
	virtual string str() const = 0;

	/// Equality with another node; each subclass defines what "equal" means
	virtual bool operator==(const AST& o) const = 0;

	/// Whether `child` occurs inside this node; the default is that it never does
	virtual bool contains(AST* child) const {
		return false;
	}

	/// Hook taking the type environment; does nothing unless overridden
	virtual void constrain(TEnv& tenv) const {
	}

	/// Hook taking the compile environment; does nothing unless overridden
	virtual void lift(CEnv& cenv) {
	}

	/// Produce the compiled value for this node
	virtual CValue compile(CEnv& cenv) = 0;

	Cursor loc; ///< Location given at construction
};

/// Literal value
template<typename VT>
struct ASTLiteral : public AST {
	ASTLiteral(VT v, Cursor c) : AST(c), val(v) {}
	bool operator==(const AST& rhs) const {
		const ASTLiteral<VT>* r = dynamic_cast<const ASTLiteral<VT>*>(&rhs);
		return (r && (val == r->val));
	}
	string str() const { return (format("%1%") % val).str(); }
	void   constrain(TEnv& tenv) const;
	CValue compile(CEnv& cenv);
	const VT val;
};

/// Symbol, e.g. "a"
struct ASTSymbol : public AST {
	bool   operator==(const AST& rhs) const { return this == &rhs; }
	string str()                      const { return cppstr; }
	void   lift(CEnv& cenv);
	CValue compile(CEnv& cenv);
private:
	friend class PEnv;
	ASTSymbol(const string& s, Cursor c) : AST(c), cppstr(s) {}
	const string cppstr;
};

/// Tuple (heterogeneous sequence of fixed length), e.g. "(a b c)"
///
/// Elements may be NULL: ASTClosure stores NULL as its first element, and the
/// size constructor fills the tuple with NULLs.  Every member that walks the
/// elements therefore tolerates NULL entries.
struct ASTTuple : public AST, public vector<AST*> {
	/// Tuple holding a copy of the pointers in t
	ASTTuple(const vector<AST*>& t=vector<AST*>(), Cursor c=Cursor()) : AST(c), vector<AST*>(t) {}

	/// Tuple of `size` NULL elements, to be filled in by the caller
	ASTTuple(size_t size, Cursor c) : AST(c), vector<AST*>(size) {}

	/// Tuple of `ast` followed by the variadic arguments.  Each variadic
	/// argument is read as an AST*, and the list must end with a null AST*.
	ASTTuple(Cursor c, AST* ast, ...) : AST(c) {
		va_list args; va_start(args, ast);
		push_back(ast);
		for (AST* a = va_arg(args, AST*); a; a = va_arg(args, AST*))
			push_back(a);
		va_end(args);
	}

	/// Elements separated by single spaces inside parentheses; NULL prints as "NULL"
	string str() const {
		string ret = "(";
		for (size_t i = 0; i != size(); ++i)
			ret += (at(i) ? at(i)->str() : "NULL") + ((i != size() - 1) ? " " : "");
		return ret + ")";
	}

	/// Equal to another tuple of the same length whose elements are pairwise
	/// equal.  Two NULL elements are equal; NULL never equals a non-NULL element.
	bool operator==(const AST& rhs) const {
		const ASTTuple* rt = dynamic_cast<const ASTTuple*>(&rhs);
		if (!rt || rt->size() != size()) return false;
		const_iterator l = begin();
		FOREACH(const_iterator, r, *rt) {
			const AST* lhs = *l++;
			if (!lhs || !*r) {
				if (lhs != *r) // exactly one side is NULL
					return false;
			} else if (!(*lhs == **r)) {
				return false;
			}
		}
		return true;
	}

	/// Lift every non-NULL element
	void lift(CEnv& cenv) {
		FOREACH(iterator, t, *this)
			if (*t)
				(*t)->lift(cenv);
	}

	/// True if child equals this tuple or occurs anywhere inside it.
	/// A NULL child is never contained; NULL elements are skipped.
	bool contains(AST* child) const {
		if (!child) return false;
		if (*this == *child) return true;
		FOREACH(const_iterator, p, *this)
			if (*p && (**p == *child || (*p)->contains(child)))
				return true;
		return false;
	}

	void   constrain(TEnv& tenv) const;

	/// A bare tuple cannot be compiled; always throws Error
	CValue compile(CEnv& cenv) { throw Error("tuple compiled"); }
};

/// Type Expression, e.g. "Int", "(Fn (Int Int) Float)"
///
/// One of three kinds: a numbered type variable (VAR), a primitive named by a
/// single symbol (PRIM), or a tuple of further elements (EXPR).
struct AType : public ASTTuple {
	/// Type variable number i
	AType(unsigned i, Cursor c=Cursor())
		: ASTTuple(0, c), kind(VAR), id(i)
	{}

	/// Primitive type named by s, which becomes the only element
	AType(ASTSymbol* s)
		: ASTTuple(0, s->loc), kind(PRIM), id(0)
	{
		push_back(s);
	}

	/// Type expression holding the elements of t
	AType(const ASTTuple& t, Cursor c)
		: ASTTuple(t, c), kind(EXPR), id(0)
	{}

	/// Type expression of `ast` followed by the variadic AST* arguments,
	/// which must end with a null AST*
	AType(Cursor c, AST* ast, ...)
		: ASTTuple(0, c), kind(EXPR), id(0)
	{
		va_list rest;
		va_start(rest, ast);
		push_back(ast);
		AST* next = va_arg(rest, AST*);
		while (next) {
			push_back(next);
			next = va_arg(rest, AST*);
		}
		va_end(rest);
	}

	/// "?<id>" for a variable, the symbol's text for a primitive, and the
	/// tuple form for an expression
	string str() const {
		if (kind == VAR)
			return (format("?%1%") % id).str();
		if (kind == PRIM)
			return at(0)->str();
		if (kind == EXPR)
			return ASTTuple::str();
		return ""; // never reached
	}

	void   constrain(TEnv& tenv) const {}
	CValue compile(CEnv& cenv) { return NULL; }

	/// True if this is a type variable
	bool var() const { return kind == VAR; }

	/// False for a variable, and for an expression with any AType element
	/// that is itself not concrete; true otherwise
	bool concrete() const {
		if (kind == VAR)
			return false;
		if (kind == EXPR) {
			FOREACH(const_iterator, t, *this) {
				AType* kid = dynamic_cast<AType*>(*t);
				if (kid && !kid->concrete())
					return false;
			}
		}
		return true;
	}

	/// Equal to another AType of the same kind with: the same id (VAR), the
	/// same symbol text (PRIM), or equal elements (EXPR)
	bool operator==(const AST& rhs) const {
		const AType* rt = dynamic_cast<const AType*>(&rhs);
		if (rt == NULL)
			return false;
		if (rt->kind != kind)
			return false;
		if (kind == VAR)
			return id == rt->id;
		if (kind == PRIM)
			return at(0)->str() == rt->at(0)->str();
		if (kind == EXPR)
			return ASTTuple::operator==(rhs);
		return false; // never reached
	}

	enum { VAR, PRIM, EXPR } kind;
	unsigned                 id;
};

/// Lifted system functions (of various types) for a single Tuplr function
struct Funcs : public list< pair<AType*, CFunction> > {
	/// Function stored with a type equal to *type (first match wins), or
	/// NULL if there is none
	CFunction find(AType* type) const {
		FOREACH(const_iterator, entry, *this) {
			const AType& candidate = *entry->first;
			if (candidate == *type)
				return entry->second;
		}
		return NULL;
	}
};

/// Closure (first-class function with captured lexical bindings)
///
/// Element layout set up by the constructor: [0] NULL, [1] prototype, [2] body.
struct ASTClosure : public ASTTuple {
	/// Build a closure with prototype p and body b.
	///
	/// Every argument handed to the variadic base constructor is cast to
	/// exactly AST*.  That constructor reads its arguments with
	/// va_arg(args, AST*), which is only defined when the argument really
	/// has that type.  A bare NULL may be a plain integer constant, and an
	/// ASTTuple* is a different pointer type.
	ASTClosure(Cursor c, ASTTuple* p, AST* b, const string& n="")
		: ASTTuple(c,
		           static_cast<AST*>(NULL),
		           static_cast<AST*>(p),
		           b,
		           static_cast<AST*>(NULL))
		, name(n)
	{}

	/// Closures compare by identity
	bool      operator==(const AST& rhs) const { return this == &rhs; }

	/// The address of this object, formatted by boost::format
	string    str() const { return (format("%1%") % this).str(); }

	void      constrain(TEnv& tenv) const;
	void      lift(CEnv& cenv);
	CValue    compile(CEnv& cenv);

	/// The prototype (element 1), or NULL if that element is not an ASTTuple
	ASTTuple* prot() const { return dynamic_cast<ASTTuple*>(at(1)); }
private:
	Funcs  funcs;
	string name;
};

/// Function call/application, e.g. "(func arg1 arg2)"
struct ASTCall : public ASTTuple {
	/// Copies the elements of t and takes its location from e.  Only a
	/// reference to e is kept, so e must outlive this node.
	ASTCall(const SExp& e, const ASTTuple& t)
		: ASTTuple(t, e.loc)
		, exp(e)
	{}

	virtual void   constrain(TEnv& tenv) const;
	virtual void   lift(CEnv& cenv);
	virtual CValue compile(CEnv& cenv);

	const SExp& exp; ///< The expression given at construction (not owned)
};

/// Definition special form, e.g. "(def x 2)"
struct ASTDefinition : public ASTCall {
	/// Same arguments and lifetime rules as ASTCall's constructor
	ASTDefinition(const SExp& e, const ASTTuple& t)
		: ASTCall(e, t)
	{}

	virtual void   constrain(TEnv& tenv) const;
	virtual void   lift(CEnv& cenv);
	virtual CValue compile(CEnv& cenv);
};

/// Conditional special form, e.g. "(if cond thenexp elseexp)"
struct ASTIf : public ASTCall {
	/// Same arguments and lifetime rules as ASTCall's constructor
	ASTIf(const SExp& e, const ASTTuple& t)
		: ASTCall(e, t)
	{}

	virtual void   constrain(TEnv& tenv) const;
	virtual CValue compile(CEnv& cenv);
};

/// Primitive (builtin arithmetic function), e.g. "(+ 2 3)"
struct ASTPrimitive : public ASTCall {
	/// Same arguments and lifetime rules as ASTCall's constructor
	ASTPrimitive(const SExp& e, const ASTTuple& t)
		: ASTCall(e, t)
	{}

	virtual void   constrain(TEnv& tenv) const;
	virtual CValue compile(CEnv& cenv);
};

/// Cons special form, e.g. "(cons 1 2)"
struct ASTConsCall : public ASTCall {
	/// Same arguments and lifetime rules as ASTCall's constructor
	ASTConsCall(const SExp& e, const ASTTuple& t)
		: ASTCall(e, t)
	{}

	AType* functionType(CEnv& cenv);

	virtual void   constrain(TEnv& tenv) const;
	virtual void   lift(CEnv& cenv);
	virtual CValue compile(CEnv& cenv);

	static Funcs funcs; ///< One table shared by every cons call (static member)
};

/// Car special form, e.g. "(car p)"
struct ASTCarCall : public ASTCall {
	/// Same arguments and lifetime rules as ASTCall's constructor
	ASTCarCall(const SExp& e, const ASTTuple& t)
		: ASTCall(e, t)
	{}

	virtual void   constrain(TEnv& tenv) const;
	virtual CValue compile(CEnv& cenv);
};

/// Cdr special form, e.g. "(cdr p)"
struct ASTCdrCall : public ASTCall {
	/// Same arguments and lifetime rules as ASTCall's constructor
	ASTCdrCall(const SExp& e, const ASTTuple& t)
		: ASTCall(e, t)
	{}

	virtual void   constrain(TEnv& tenv) const;
	virtual CValue compile(CEnv& cenv);
};


/***************************************************************************
 * Parser: S-Expressions (SExp) -> AST Nodes (AST) (Prefix P for Parsing)  *
 ***************************************************************************/

/// Parse Time Environment (symbol table)
///
/// The privately inherited map holds one ASTSymbol per distinct name.  The
/// two handler tables hold the parse functions registered for atoms and lists.
struct PEnv : private map<const string, ASTSymbol*> {
	typedef AST* (*PF)(PEnv&, const SExp&, void*); // Parse Function

	/// A parse function plus the opaque argument that is passed back to it
	struct Handler {
		Handler(PF f, void* a=0) : func(f), arg(a) {}
		PF    func;
		void* arg;
	};

	typedef map<const string, Handler> Handlers;

	Handlers aHandlers; ///< Atom parse functions
	Handlers lHandlers; ///< List parse functions

	/// Register h under name s in the list table (list == true) or the atom
	/// table.  The name is interned as a symbol first.  map::insert does not
	/// overwrite, so an existing registration for s is left unchanged.
	void reg(bool list, const string& s, const Handler& h) {
		(list ? lHandlers : aHandlers).insert(make_pair(sym(s)->str(), h));
	}

	/// Handler registered for s in the chosen table, or NULL if there is none
	const Handler* handler(bool list, const string& s) const {
		const Handlers& handlers = list ? lHandlers : aHandlers;
		// The iterator type must come from the table's own map type:
		// map<string, Handler> is a different type from
		// map<const string, Handler>, and their iterators are not
		// guaranteed to be interchangeable.
		const Handlers::const_iterator i = handlers.find(s);
		return (i != handlers.end()) ? &i->second : NULL;
	}

	/// The unique symbol named s.  It is created with location c on first
	/// use; later calls return the same object and ignore c.
	ASTSymbol* sym(const string& s, Cursor c=Cursor()) {
		const const_iterator i = find(s);
		return ((i != end())
			? i->second
			: insert(make_pair(s, new ASTSymbol(s, c))).first->second);
	}
};

/// The fundamental parser method: turns one S-Expression into an AST node.
/// Declared ahead of its definition because pmap() below calls it.
static AST* parseExpression(PEnv& penv, const SExp& exp);

/// Parse every element of the list expression e, in order, and return the
/// results as a tuple located at e.loc.  e must be a list (asserted).
static ASTTuple
pmap(PEnv& penv, const SExp& e)
{
	assert(e.type == SExp::LIST);
	const size_t count = e.list.size();
	ASTTuple ret(count, e.loc);
	for (size_t n = 0; n < count; ++n)
		ret[n] = parseExpression(penv, e.list[n]);
	return ret;
}

/// Parse one S-Expression into an AST node.
///
/// - A list whose first element is an atom with a registered list handler is
///   passed to that handler.  Any other non-empty list becomes an ASTCall.
/// - An atom that starts with a decimal digit becomes an int32_t literal, or
///   a float literal if it contains a '.'.
/// - An atom with a registered atom handler is passed to that handler.
/// - Any other atom becomes a symbol.
///
/// Throws Error for an empty list.
static AST*
parseExpression(PEnv& penv, const SExp& exp)
{
	if (exp.type == SExp::LIST) {
		if (exp.list.empty()) throw Error("call to empty list", exp.loc);
		if (exp.list.front().type == SExp::ATOM) {
			const PEnv::Handler* handler = penv.handler(true, exp.list.front().atom);
			if (handler) // Dispatch to list parse function
				return handler->func(penv, exp, handler->arg);
		}
		return new ASTCall(exp, pmap(penv, exp)); // Parse as regular call
	}

	// isdigit() is only defined for EOF and values representable as unsigned
	// char.  Plain char may be negative for non-ASCII input, so convert first.
	if (isdigit(static_cast<unsigned char>(exp.atom[0]))) {
		if (exp.atom.find('.') == string::npos)
			return new ASTLiteral<int32_t>(strtol(exp.atom.c_str(), NULL, 10), exp.loc);
		else
			return new ASTLiteral<float>(strtod(exp.atom.c_str(), NULL), exp.loc);
	}

	const PEnv::Handler* handler = penv.handler(false, exp.atom);
	if (handler) // Dispatch to atom parse function
		return handler->func(penv, exp, handler->arg);

	return penv.sym(exp.atom, exp.loc);
}

/// Parse function that builds a node of type C from exp and its parsed
/// elements.  `arg` is not used.
template<typename C>
inline AST*
parseCall(PEnv& penv, const SExp& exp, void* arg)
{
	const ASTTuple elements = pmap(penv, exp);
	return new C(exp, elements);
}

/// Parse function that builds a literal of type T located at exp.loc.
/// `arg` must point to a T holding the value; it is dereferenced unconditionally.
template<typename T>
inline AST*
parseLiteral(PEnv& penv, const SExp& exp, void* arg)
{
	const T* const value = static_cast<T*>(arg);
	return new ASTLiteral<T>(*value, exp.loc);
}

/// Parse a function expression, whose elements are a head, a prototype list
/// and a body, into an ASTClosure.  Elements after the body are ignored and
/// `arg` is not used.
///
/// Throws Error if the prototype or body is missing, or if the prototype is
/// not a list.
inline AST*
parseFn(PEnv& penv, const SExp& exp, void* arg)
{
	if (exp.list.size() < 3)
		throw Error("missing function prototype and/or body", exp.loc);

	const SExp& protExp = exp.list[1];
	const SExp& bodyExp = exp.list[2];
	if (protExp.type != SExp::LIST)
		throw Error("function prototype must be a list", protExp.loc);

	// Parse in separate statements.  The order in which a call's arguments
	// are evaluated is unspecified, so advancing one iterator inside two
	// arguments could hand the body to the prototype parser and vice versa.
	const ASTTuple prot = pmap(penv, protExp);
	AST* const     body = parseExpression(penv, bodyExp);
	return new ASTClosure(exp.loc, new ASTTuple(prot), body);
}


/***************************************************************************
 * Generic Lexical Environment                                             *
 ***************************************************************************/

/// Stack of frames, each mapping keys to values; the innermost frame is at
/// the front.  A new Env starts with one empty frame.
template<typename K, typename V>
struct Env : public list< map<K,V> > {
	typedef map<K,V> Frame;

	Env() : list<Frame>(1) {}

	/// Open a new, empty innermost frame
	void push() {
		list<Frame>::push_front(Frame());
	}

	/// Discard the innermost frame (there must be one)
	void pop() {
		assert(!this->empty());
		list<Frame>::pop_front();
	}

	/// Bind k to v in the innermost frame and return the stored value.
	/// Throws Error if k is already bound there to a different value.
	const V& def(const K& k, const V& v) {
		Frame& top = this->front();
		typename Frame::iterator found = top.find(k);
		if (found != top.end() && found->second != v)
			throw Error("redefinition");
		return (top[k] = v);
	}

	/// Pointer to the value bound to name in the innermost frame that binds
	/// it, or 0 if no frame does
	V* ref(const K& name) {
		for (typename list<Frame>::iterator frame = this->begin(); frame != this->end(); ++frame) {
			typename Frame::iterator hit = frame->find(name);
			if (hit != frame->end())
				return &hit->second;
		}
		return 0;
	}
};


/***************************************************************************
 * Typing (Prefix T for Type)                                              *
 ***************************************************************************/

/// Map from AType* to AType*; produced by TEnv::unify() and consumed by
/// TEnv::apply()
struct TSubst : public map<AType*, AType*> {
	/// Empty map, or the single entry s -> t when both pointers are non-null
	TSubst(AType* s=0, AType* t=0) {
		if (s == 0 || t == 0)
			return;
		insert(value_type(s, t));
	}
};

/// Type-Time Environment
///
/// Associates AST nodes with types (through the inherited Env), hands out
/// numbered type variables, and collects the constraints that solve() passes
/// to unify().
struct TEnv : public Env<const AST*,AType*> {
	TEnv(PEnv& p) : penv(p), varID(1) {}

	/// A pair of types handed to unify(), plus the location it came from
	struct Constraint : public pair<AType*,AType*> {
		Constraint(AType* a, AType* b, Cursor c=Cursor()) : pair<AType*,AType*>(a, b), loc(c) {}
		Cursor loc;
	};
	typedef list<Constraint> Constraints;

	/// Create a type variable with the next unused ID
	AType* var(Cursor c=Cursor()) { return new AType(varID++, c); }

	/// Type recorded for ast in any frame.  If there is none, a new type
	/// variable is defined for it in the innermost frame and returned.
	AType* type(const AST* ast) {
		AType** t = ref(ast);
		return t ? *t : def(ast, var(ast->loc));
	}

	/// Type bound to the symbol called `name`.
	/// Throws Error if the symbol has no type bound in any frame, because
	/// ref() returns 0 in that case and must not be dereferenced.
	AType* named(const string& name) {
		AType** t = ref(penv.sym(name));
		if (!t)
			throw Error((format("undefined type `%1%'") % name).str());
		return *t;
	}

	/// Record that the type of o must match t.  o must not itself be an AType.
	void constrain(const AST* o, AType* t) {
		assert(!dynamic_cast<const AType*>(o));
		constraints.push_back(Constraint(type(o), t, o->loc));
	}

	/// Unify the collected constraints and apply the resulting substitution
	void          solve() { apply(unify(constraints)); }
	void          apply(const TSubst& substs);
	static TSubst unify(const Constraints& c);

	PEnv&       penv;
	Constraints constraints;
	unsigned    varID; ///< ID for the next type variable; starts at 1
};


/***************************************************************************
 * Code Generation                                                         *
 ***************************************************************************/

struct CEnvPimpl; // Private implementation data; only ever used through a pointer here

/// Compile-Time Environment
struct CEnv {
	CEnv(PEnv& p, TEnv& t, CEngine& e, ostream& os=std::cout, ostream& es=std::cerr);
	~CEnv();

	typedef Env<const ASTSymbol*, AST*> Code;
	typedef Env<const AST*, CValue>     Vals;

	/// Make a name of the form <s><counter> and advance the counter
	string gensym(const char* s="_") {
		const unsigned id = symID++;
		return (format("%s%d") % s % id).str();
	}

	/// Open a new frame in both the code and the value environment
	void push() {
		code.push();
		vals.push();
	}

	/// Discard the innermost frame of both environments
	void pop() {
		code.pop();
		vals.pop();
	}

	/// Record `value` as the compiled value of obj in the innermost value frame
	void precompile(AST* obj, CValue value) {
		vals.def(obj, value);
	}

	CValue compile(AST* obj);
	void   optimise(CFunction f);
	void   write(std::ostream& os);

	CEngine&  engine;
	PEnv&     penv;
	TEnv      tenv;  ///< NOTE(review): held by value although the constructor takes a TEnv& -- confirm the copy is intended
	Code      code;
	Vals      vals;
	unsigned  symID; ///< Counter used by gensym()
	CFunction alloc;
	Log       log;

private:
	CEnvPimpl* _pimpl;
};


/***************************************************************************
 * EVAL/REPL/MAIN                                                          *
 ***************************************************************************/

// Entry points implemented outside this header.
// NOTE(review): the remarks below are inferred from names and signatures
// only -- confirm against the definitions.

// Presumably register the built-in types and language forms in penv / tenv.
void  initTypes(PEnv& penv, TEnv& tenv);
void  initLang(PEnv& penv, TEnv& tenv);
// Returns a pointer to a CEnv; whether the caller owns it is not visible here.
CEnv* newCenv(PEnv& penv, TEnv& tenv);
// Presumably evaluate the input read from `is`, with `name` labelling it.
int   eval(CEnv& cenv, const string& name, istream& is);
// Presumably run an interactive read-eval-print loop.
int   repl(CEnv& cenv);

#endif // TUPLR_HPP