/* Tuplr: A programming language
 * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net>
 *
 * Tuplr is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Tuplr is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Tuplr.  If not, see <http://www.gnu.org/licenses/>.
 */

#ifndef TUPLR_HPP
#define TUPLR_HPP

#include <assert.h>
#include <ctype.h>
#include <stdarg.h>
#include <stdint.h>
#include <stdlib.h>
#include <iostream>
#include <list>
#include <map>
#include <string>
#include <vector>
#include <boost/format.hpp>

typedef void*       CValue;    ///< Compiled value (opaque)
typedef const void* CType;     ///< Compiled type (opaque)
typedef void*       CFunction; ///< Compiled function (opaque)

struct CEngine; ///< Backend data (opaque)
struct CArg;    ///< Parser function argument (opaque)

/// Iterate `i` (declared with iterator type `IT`) over container `c`.
/// Note: `c` is evaluated more than once and `(c).end()` is re-evaluated on
/// every iteration, so `c` should be a side-effect-free expression.
#define FOREACH(IT, i, c) for (IT i = (c).begin(); i != (c).end(); ++i)

// NOTE(review): a using-directive in a header leaks all of `std` into every
// includer; it is kept because the rest of this header relies on unqualified
// std names (string, vector, map, list, ...).
using namespace std;
using boost::format;

/// Error output stream (declared here, defined in another translation unit;
/// presumably used for diagnostics).
extern std::ostream& err;
/// Normal output stream (declared here, defined in another translation unit).
extern std::ostream& out;

/// A position within a named input: source name, line and column.
/// A default-constructed Cursor has an empty name, line 1 and column 0.
struct Cursor {
	Cursor(const string& n="", unsigned l=1, unsigned c=0)
		: name(n)
		, line(l)
		, col(c)
	{}
	/// Render as "name:line:col".
	string str() const {
		format fmt("%1%:%2%:%3%");
		fmt % name % line % col;
		return fmt.str();
	}
	string   name; ///< Source (file/stream) name
	unsigned line; ///< Line number
	unsigned col;  ///< Column number
};

/// Exception carrying a message and the source location it refers to.
struct Error {
	Error(const string& m, Cursor c=Cursor()) : msg(m), loc(c) {}
	/// Full diagnostic text: "<location>: error: <message>".
	const string what() const throw() {
		string text = loc.str();
		text += ": error: ";
		text += msg;
		return text;
	}
	string msg; ///< Human-readable message
	Cursor loc; ///< Location the message refers to
};

/// Generic expression tree node: either a single atom or a list of children.
template<typename Atom>
struct Exp { // ::= Atom | (Exp*)
	typedef std::vector< Exp<Atom> > List;
	/// An (initially empty) list node located at @a c.
	Exp(Cursor c) : type(LIST), loc(c) {}
	/// An atom node holding @a a, located at @a c.
	Exp(Cursor c, const Atom& a) : type(ATOM), loc(c), atom(a) {}
	enum { ATOM, LIST } type; ///< Which alternative this node is
	Cursor loc;               ///< Source position
	Atom   atom;              ///< Meaningful when type == ATOM
	List   list;              ///< Children; meaningful when type == LIST
};


/***************************************************************************
 * S-Expression Lexer :: text -> S-Expressions (SExp)                      *
 ***************************************************************************/

typedef Exp<string> SExp; ///< Textual S-Expression

/// Read one S-expression from @a is (defined in another translation unit).
/// @a cursor is taken by non-const reference; presumably it is advanced as
/// input is consumed — confirm against the definition.
SExp readExpression(Cursor& cursor, std::istream& is);


/***************************************************************************
 * Abstract Syntax Tree                                                    *
 ***************************************************************************/

struct TEnv; ///< Type-Time Environment
struct CEnv; ///< Compile-Time Environment

/// Base class for all AST nodes
///
/// Subclasses must implement str(), operator== and compile(). By default a
/// node contains no children (contains() returns false) and constrain() and
/// lift() do nothing.
struct AST {
	virtual ~AST() {}
	virtual string str()                    const = 0;
	virtual bool   operator==(const AST& o) const = 0;
	virtual bool   contains(AST* /*child*/) const { return false; }
	virtual void   constrain(TEnv& /*tenv*/) const {}
	virtual void   lift(CEnv& /*cenv*/)          {}
	virtual CValue compile(CEnv& cenv)            = 0;
};

/// Literal value
template<typename VT>
struct ASTLiteral : public AST {
	ASTLiteral(VT v) : val(v) {}
	/// Equal iff @a rhs is an ASTLiteral of the same VT holding an equal value.
	bool operator==(const AST& rhs) const {
		if (const ASTLiteral<VT>* other = dynamic_cast<const ASTLiteral<VT>*>(&rhs))
			return val == other->val;
		return false;
	}
	/// The value rendered with boost::format's default ("%1%") formatting.
	string str() const {
		format fmt("%1%");
		return (fmt % val).str();
	}
	void   constrain(TEnv& tenv) const;
	CValue compile(CEnv& cenv);
	const VT val; ///< The literal's value
};

/// Symbol, e.g. "a"
struct ASTSymbol : public AST {
	ASTSymbol(const string& s, Cursor c=Cursor()) : loc(c), cppstr(s) {}
	bool   operator==(const AST& rhs) const { return this == &rhs; }
	string str()                      const { return cppstr; }
	CValue compile(CEnv& cenv);
private:
	Cursor       loc;
	const string cppstr;
};

/// Tuple (heterogeneous sequence of fixed length), e.g. "(a b c)"
///
/// Elements are AST pointers held in the std::vector base; the tuple never
/// deletes them.
struct ASTTuple : public AST, public vector<AST*> {
	/// Copy the given element pointers (default: empty tuple).
	ASTTuple(const vector<AST*>& elems=vector<AST*>()) : vector<AST*>(elems) {}
	/// A tuple of @a n null elements, to be filled in by the caller.
	ASTTuple(size_t n) : vector<AST*>(n) {}
	/// A tuple built from a null-terminated argument list.
	///
	/// @a ast is always stored (even if null). The variadic arguments are
	/// read as AST* up to, but not including, the first null one, so every
	/// variadic argument — terminator included — must really be an AST*.
	ASTTuple(AST* ast, ...) {
		push_back(ast);
		va_list ap;
		va_start(ap, ast);
		while (AST* next = va_arg(ap, AST*))
			push_back(next);
		va_end(ap);
	}
	/// Element strings separated by single spaces, in parentheses.
	string str() const {
		string ret = "(";
		for (const_iterator e = begin(); e != end(); ++e) {
			if (e != begin())
				ret += " ";
			ret += (*e)->str();
		}
		return ret + ")";
	}
	/// Structural equality: same length and pairwise-equal elements.
	bool operator==(const AST& rhs) const {
		const ASTTuple* other = dynamic_cast<const ASTTuple*>(&rhs);
		if (!other || other->size() != size())
			return false;
		for (size_t i = 0; i != size(); ++i)
			if (!(*(*this)[i] == *(*other)[i]))
				return false;
		return true;
	}
	/// Lift every element, in order.
	void lift(CEnv& cenv) {
		for (iterator e = begin(); e != end(); ++e)
			(*e)->lift(cenv);
	}
	bool   contains(AST* child) const;
	void   constrain(TEnv& tenv) const;
	CValue compile(CEnv& cenv) { throw Error("tuple compiled"); }
};

/// Type Expression, e.g. "Int", "(Fn (Int Int) Float)"
///
/// A type is one of:
///  - VAR:  a type variable identified by `id`, printed as "?<id>";
///  - PRIM: a primitive type; element 0 is its name and `ctype` its compiled type;
///  - EXPR: a compound type expression whose elements are its sub-expressions.
struct AType : public ASTTuple {
	// `id` is only meaningful for VAR, but it is initialised in every
	// constructor so that copying a PRIM/EXPR type never reads an
	// indeterminate value.
	AType(const ASTTuple& t) : ASTTuple(t), kind(EXPR), id(0), ctype(0) {}
	AType(unsigned i) : kind(VAR), id(i), ctype(0) {}
	AType(ASTSymbol* n, CType t) : kind(PRIM), id(0), ctype(t) { push_back(n); }
	string str() const {
		switch (kind) {
		case VAR:  return (format("?%1%") % id).str();
		case PRIM: return at(0)->str();
		case EXPR: return ASTTuple::str();
		}
		return ""; // never reached
	}
	void   constrain(TEnv& tenv) const {}
	CValue compile(CEnv& cenv) { return NULL; }
	bool   var() const { return kind == VAR; }
	/// True if this type contains no type variables. Only elements that are
	/// themselves ATypes are inspected; other elements are ignored.
	bool concrete() const {
		switch (kind) {
		case VAR:  return false;
		case PRIM: return true;
		case EXPR:
			FOREACH(const_iterator, t, *this) {
				AType* kid = dynamic_cast<AType*>(*t);
				if (kid && !kid->concrete())
					return false;
			}
		}
		return true;
	}
	/// Equal iff @a rhs is an AType of the same kind and: the same variable id
	/// (VAR), the same primitive name (PRIM), or structurally equal (EXPR).
	bool operator==(const AST& rhs) const {
		const AType* rt = dynamic_cast<const AType*>(&rhs);
		if (!rt || kind != rt->kind)
			return false;
		else
			switch (kind) {
			case VAR:  return id == rt->id;
			case PRIM: return at(0)->str() == rt->at(0)->str();
			case EXPR: return ASTTuple::operator==(rhs);
			}
		return false; // never reached
	}
	CType type();
	enum Kind { VAR, PRIM, EXPR };
	Kind     kind; ///< Which kind of type this is
	unsigned id;   ///< Variable id (VAR only; 0 otherwise)
private:
	const CType ctype; ///< Compiled type (PRIM only; 0 otherwise)
};

/// Lifted system functions (of various types) for a single Tuplr function
///
/// A small association list from function type to compiled function. Types
/// are compared with AType::operator==, and the first match wins.
struct Funcs : public list< pair<AType*, CFunction> > {
	/// The compiled function recorded for @a type, or NULL if there is none.
	CFunction find(AType* type) const {
		const_iterator it = begin();
		while (it != end() && !(*it->first == *type))
			++it;
		return (it == end()) ? NULL : it->second;
	}
	/// Append a (type, function) entry; existing entries are never replaced.
	void insert(AType* type, CFunction func) {
		push_back(value_type(type, func));
	}
};

/// Closure (first-class function with captured lexical bindings)
///
/// Stored as a 3-element tuple: element 0 is a null slot, element 1 the
/// parameter prototype (see prot()), element 2 the body.
struct ASTClosure : public ASTTuple {
	ASTClosure(ASTTuple* p, AST* b, const string& n="")
		// The variadic ASTTuple constructor reads every argument with
		// va_arg(..., AST*), so each one (including the terminator) is cast to
		// exactly AST*: passing an ASTTuple* or a bare NULL (whose type and
		// width are implementation-defined) through `...` is undefined.
		: ASTTuple(static_cast<AST*>(0), static_cast<AST*>(p), b, static_cast<AST*>(0))
		, name(n) {}
	bool      operator==(const AST& rhs) const { return this == &rhs; }
	string    str() const { return (format("%1%") % this).str(); }
	void      constrain(TEnv& tenv) const;
	void      lift(CEnv& cenv);
	CValue    compile(CEnv& cenv);
	/// The parameter prototype (element 1), or NULL if it is not an ASTTuple.
	ASTTuple* prot() const { return dynamic_cast<ASTTuple*>(at(1)); }
private:
	Funcs  funcs; ///< Lifted functions, one per function type
	string name;  ///< Optional name (may be empty)
};

/// Function call/application, e.g. "(func arg1 arg2)"
struct ASTCall : public ASTTuple {
	ASTCall(const SExp& e, const ASTTuple& t) : ASTTuple(t), exp(e) {}
	void   constrain(TEnv& tenv) const;
	void   lift(CEnv& cenv);
	CValue compile(CEnv& cenv);
	/// The source expression this call was parsed from.
	///
	/// Held by value, not by reference: AST nodes are heap-allocated and may
	/// be used after parsing returns, while readExpression() returns its SExp
	/// by value, so callers typically hold it in a temporary or local that
	/// does not outlive the AST. A reference member here could dangle.
	const SExp exp;
};

/// Definition special form, e.g. "(def x 2)"
struct ASTDefinition : public ASTCall {
	/// @param sexp  the form this node was parsed from
	/// @param elems the parsed elements of the form
	/// The trailing CArg* is accepted so parseCall<> can construct this node,
	/// and is ignored.
	ASTDefinition(const SExp& sexp, const ASTTuple& elems, CArg* = 0)
		: ASTCall(sexp, elems) {}
	void   constrain(TEnv& tenv) const;
	void   lift(CEnv& cenv);
	CValue compile(CEnv& cenv);
};

/// Conditional special form, e.g. "(if cond thenexp elseexp)"
struct ASTIf : public ASTCall {
	/// @param sexp  the form this node was parsed from
	/// @param elems the parsed elements of the form
	/// The trailing CArg* is accepted for parseCall<> and ignored.
	ASTIf(const SExp& sexp, const ASTTuple& elems, CArg* = 0)
		: ASTCall(sexp, elems) {}
	void   constrain(TEnv& tenv) const;
	CValue compile(CEnv& cenv);
};

/// Primitive (builtin arithmetic function), e.g. "(+ 2 3)"
struct ASTPrimitive : public ASTCall {
	/// @param sexp  the form this node was parsed from
	/// @param elems the parsed elements of the form
	/// @param ca    the parse handler's argument, stored in `arg`
	ASTPrimitive(const SExp& sexp, const ASTTuple& elems, CArg* ca=0)
		: ASTCall(sexp, elems)
		, arg(ca)
	{}
	void   constrain(TEnv& tenv) const;
	CValue compile(CEnv& cenv);
	CArg* arg; ///< Opaque argument forwarded from the parse handler
};

/// Cons special form, e.g. "(cons 1 2)"
struct ASTConsCall : public ASTCall {
	/// @param sexp  the form this node was parsed from
	/// @param elems the parsed elements of the form
	/// The trailing CArg* is accepted for parseCall<> and ignored.
	ASTConsCall(const SExp& sexp, const ASTTuple& elems, CArg* = 0)
		: ASTCall(sexp, elems) {}
	AType* functionType(CEnv& cenv);
	void   constrain(TEnv& tenv) const;
	void   lift(CEnv& cenv);
	CValue compile(CEnv& cenv);
	static Funcs funcs; ///< Functions shared by all cons calls
};

/// Car special form, e.g. "(car p)"
struct ASTCarCall : public ASTCall {
	/// @param sexp  the form this node was parsed from
	/// @param elems the parsed elements of the form
	/// The trailing CArg* is accepted for parseCall<> and ignored.
	ASTCarCall(const SExp& sexp, const ASTTuple& elems, CArg* = 0)
		: ASTCall(sexp, elems) {}
	void   constrain(TEnv& tenv) const;
	CValue compile(CEnv& cenv);
};

/// Cdr special form, e.g. "(cdr p)"
struct ASTCdrCall : public ASTCall {
	/// @param sexp  the form this node was parsed from
	/// @param elems the parsed elements of the form
	/// The trailing CArg* is accepted for parseCall<> and ignored.
	ASTCdrCall(const SExp& sexp, const ASTTuple& elems, CArg* = 0)
		: ASTCall(sexp, elems) {}
	void   constrain(TEnv& tenv) const;
	CValue compile(CEnv& cenv);
};


/***************************************************************************
 * Parser - S-Expressions (SExp) -> AST Nodes (AST)                        *
 ***************************************************************************/

// Parse Time Environment (symbol table)
//
// Interns symbols by name (sym()) and holds the parse handlers that
// parseExpression dispatches to: atom handlers keyed by the atom's text,
// list handlers keyed by the text of the list's head atom.
struct PEnv : private map<const string, ASTSymbol*> {
	typedef AST* (*PF)(PEnv&, const SExp&, CArg*); // Parse Function
	struct Handler { Handler(PF f, CArg* a=0) : func(f), arg(a) {} PF func; CArg* arg; };
	typedef map<const string, Handler> Handlers;
	Handlers aHandlers; ///< Atom parse functions
	Handlers lHandlers; ///< List parse functions
	/// Register @a h for atoms (list=false) or list heads (list=true) named
	/// @a s. The name is interned as a symbol. An existing handler for the
	/// same name is kept (std::map::insert does not overwrite).
	void reg(bool list, const string& s, const Handler& h) {
		(list ? lHandlers : aHandlers).insert(make_pair(sym(s)->str(), h));
	}
	/// The handler registered for @a s, or NULL if there is none.
	const Handler* handler(bool list, const string& s) const {
		const Handlers& handlers = list ? lHandlers : aHandlers;
		// Use the map's own iterator type: map<string,...> and
		// map<const string,...> are distinct types whose iterators are not
		// interchangeable on all standard libraries.
		const Handlers::const_iterator i = handlers.find(s);
		return (i != handlers.end()) ? &i->second : NULL;
	}
	/// The unique symbol named @a s, created at location @a c on first use.
	ASTSymbol* sym(const string& s, Cursor c=Cursor()) {
		const const_iterator i = find(s);
		return ((i != end())
			? i->second
			: insert(make_pair(s, new ASTSymbol(s, c))).first->second);
	}
};

/// The fundamental parser method
static AST* parseExpression(PEnv& penv, const SExp& exp);

/// Parse each element of @a l, in order, into a tuple of the same length.
static ASTTuple
pmap(PEnv& penv, const SExp::List& l)
{
	ASTTuple ret(l.size());
	ASTTuple::iterator out = ret.begin();
	for (SExp::List::const_iterator in = l.begin(); in != l.end(); ++in, ++out)
		*out = parseExpression(penv, *in);
	return ret;
}

/// Parse a single S-expression into an AST node.
///
/// - A list whose head atom has a registered list handler is dispatched to
///   that handler; any other non-empty list becomes an ASTCall.
/// - An atom starting with a decimal digit becomes an int32_t literal, or a
///   float literal if it contains a '.'.
/// - An atom with a registered atom handler is dispatched to that handler.
/// - Any other atom is interned as a symbol.
/// @throws Error on an empty list.
static AST*
parseExpression(PEnv& penv, const SExp& exp)
{
	if (exp.type == SExp::LIST) {
		if (exp.list.empty()) throw Error("call to empty list", exp.loc);
		if (exp.list.front().type == SExp::ATOM) {
			const PEnv::Handler* handler = penv.handler(true, exp.list.front().atom);
			if (handler) // Dispatch to list parse function
				return handler->func(penv, exp, handler->arg);
		}
		return new ASTCall(exp, pmap(penv, exp.list)); // Parse as regular call
	} else if (!exp.atom.empty()
	           && isdigit(static_cast<unsigned char>(exp.atom[0]))) {
		// The cast matters: isdigit is undefined for negative values other
		// than EOF, which a plain (signed) char holding a non-ASCII byte is.
		if (exp.atom.find('.') == string::npos)
			// NOTE: strtol's long result is narrowed to int32_t unchecked.
			return new ASTLiteral<int32_t>(strtol(exp.atom.c_str(), NULL, 10));
		else
			return new ASTLiteral<float>(strtod(exp.atom.c_str(), NULL));
	} else {
		const PEnv::Handler* handler = penv.handler(false, exp.atom);
		if (handler) // Dispatch to atom parse function
			return handler->func(penv, exp, handler->arg);
	}
	return penv.sym(exp.atom, exp.loc);
}

/// Generic list handler: parse every element of @a exp and construct a C
/// from the form, the parsed elements and the handler's argument.
template<typename C>
inline AST*
parseCall(PEnv& penv, const SExp& exp, CArg* arg)
{
	const ASTTuple elems = pmap(penv, exp.list);
	return new C(exp, elems, arg);
}

/// Atom handler for constants: @a arg is reinterpreted as a pointer to the T
/// value registered with the handler, and a literal holding a copy of that
/// value is returned.
template<typename T>
inline AST*
parseLiteral(PEnv& penv, const SExp& exp, CArg* arg)
{
	const T* value = reinterpret_cast<const T*>(arg);
	return new ASTLiteral<T>(*value);
}

/// Parse a function form "(fn (PARAM...) BODY)" into an ASTClosure.
///
/// Only the first expression after the parameter list is used as the body;
/// any further expressions are ignored.
/// @throws Error if the parameter list or body is missing, or if the
///         parameter list is not a list.
inline AST*
parseFn(PEnv& penv, const SExp& exp, CArg* arg)
{
	if (exp.list.size() < 3)
		throw Error("function requires a parameter list and a body", exp.loc);

	SExp::List::const_iterator a = exp.list.begin(); ++a;
	const SExp& protExp = *a++;
	const SExp& bodyExp = *a;
	if (protExp.type != SExp::LIST)
		throw Error("function parameters must be a list", protExp.loc);

	// Parsed as separate statements so the parameters are always parsed
	// before the body; doing both inside one argument list would leave their
	// relative order unspecified.
	const ASTTuple prot = pmap(penv, protExp.list);
	AST* const     body = parseExpression(penv, bodyExp);
	return new ASTClosure(new ASTTuple(prot), body);
}


/***************************************************************************
 * Generic Lexical Environment                                             *
 ***************************************************************************/

/// A stack of scopes (frames) mapping K to V. The front frame is the
/// innermost one; a new Env starts with a single empty frame.
template<typename K, typename V>
struct Env : public list< map<K,V> > {
	typedef map<K,V> Frame;
	Env() : list<Frame>(1) {}
	/// Enter a new, empty innermost scope.
	void push() { list<Frame>::push_front(Frame()); }
	/// Leave the innermost scope.
	void pop()  { assert(!this->empty()); list<Frame>::pop_front(); }
	/// Bind @a k to @a v in the innermost scope and return the stored value.
	/// @throws Error if @a k is already bound there to a different value.
	const V& def(const K& k, const V& v) {
		Frame& top = this->front();
		typename Frame::iterator existing = top.find(k);
		if (existing == top.end())
			return (top[k] = v);
		if (existing->second != v)
			throw Error("redefinition");
		return (existing->second = v);
	}
	/// Pointer to the innermost binding of @a name, or 0 if it is unbound.
	V* ref(const K& name) {
		for (typename Env::iterator frame = this->begin(); frame != this->end(); ++frame) {
			typename Frame::iterator hit = frame->find(name);
			if (hit != frame->end())
				return &hit->second;
		}
		return 0;
	}
};


/***************************************************************************
 * Typing                                                                  *
 ***************************************************************************/

/// Type substitution: maps a type (presumably a type variable) to the type
/// that replaces it.
struct TSubst : public map<AType*, AType*> {
	/// Empty substitution, or the single mapping s -> t when both are non-null.
	TSubst(AType* s=0, AType* t=0) {
		if (!s || !t)
			return;
		insert(value_type(s, t));
	}
};

/// Type-Time Environment
///
/// Scoped map from AST node to its type, plus the list of pending type
/// equations collected by constrain() and resolved by solve().
struct TEnv : public Env<const AST*,AType*> {
	TEnv(PEnv& p) : penv(p), varID(1) {}
	typedef list< pair<AType*, AType*> > Constraints;
	/// A fresh type variable (ids are handed out from 1 upwards).
	AType* var() { return new AType(varID++); }
	/// The type of @a ast, binding a fresh variable in the innermost scope
	/// if it has none yet.
	AType* type(const AST* ast) {
		AType** t = ref(ast);
		return t ? *t : def(ast, var());
	}
	/// The type bound to the symbol named @a name.
	/// @throws Error if no type is bound to that symbol (previously this
	///         dereferenced a null pointer).
	AType* named(const string& name) {
		AType** t = ref(penv.sym(name));
		if (!t)
			throw Error("undefined type `" + name + "'");
		return *t;
	}
	/// Record the equation type(o) = t. @a o must not itself be a type.
	void constrain(const AST* o, AType* t) {
		assert(!dynamic_cast<const AType*>(o));
		constraints.push_back(make_pair(type(o), t));
	}
	void          solve() { apply(unify(constraints)); }
	void          apply(const TSubst& substs);
	static TSubst unify(const Constraints& c);

	PEnv&       penv;        ///< Symbol table used by named()
	Constraints constraints; ///< Pending type equations
	unsigned    varID;       ///< Next type variable id
};


/***************************************************************************
 * Code Generation                                                         *
 ***************************************************************************/

struct CEnvPimpl;

/// Compile-Time Environment
struct CEnv {
	CEnv(PEnv& p, TEnv& t, CEngine& engine);
	~CEnv();

	typedef Env<const AST*, AST*>   Code;
	typedef Env<const AST*, CValue> Vals;

	/// A fresh name: @a s followed by this environment's next counter value.
	string gensym(const char* s="_") { return (format("%s%d") % s % symID++).str(); }
	/// Enter / leave a scope in both the code and value environments.
	void   push() { code.push(); vals.push(); }
	void   pop()  { code.pop();  vals.pop();  }
	/// Bind @a obj to an already-compiled @a value in the innermost scope.
	void   precompile(AST* obj, CValue value) { vals.def(obj, value); }
	CValue compile(AST* obj);
	void   optimise(CFunction f);
	void   write(std::ostream& os);

	CEngine&  engine;
	PEnv&     penv;
	TEnv      tenv;
	Code      code;
	Vals      vals;
	unsigned  symID; ///< Counter used by gensym()
	CFunction alloc;

private:
	// Non-copyable: CEnv has a user-declared destructor and a raw _pimpl
	// pointer (presumably owned and released by ~CEnv), so an implicit
	// member-wise copy would share it and risk a double delete. Declared but
	// intentionally not defined (C++03 idiom).
	CEnv(const CEnv&);
	CEnv& operator=(const CEnv&);

	CEnvPimpl* _pimpl;
};


/***************************************************************************
 * EVAL/REPL/MAIN                                                          *
 ***************************************************************************/

void  initLang(PEnv& penv, TEnv& tenv);
CEnv* newCenv(PEnv& penv, TEnv& tenv);
int   eval(CEnv& cenv, const string& name, istream& is);
int   repl(CEnv& cenv);

#endif // TUPLR_HPP