/* Tuplr: A programming language
 * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net>
 *
 * Tuplr is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Tuplr is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Tuplr.  If not, see <http://www.gnu.org/licenses/>.
 */

#ifndef TUPLR_HPP
#define TUPLR_HPP

#include <stdarg.h>
#include <iostream>
#include <list>
#include <map>
#include <set>
#include <sstream>
#include <string>
#include <vector>
#include <boost/format.hpp>

/// Loop `i` (an iterator of type IT) over container `c`; `c` is re-evaluated for every end() check
#define FOREACH(IT, i, c) for (IT i = (c).begin(); i != (c).end(); ++i)
/// Throw Error(error, ...) when `cond` is true; expands to a brace block rather than a single statement
#define THROW_IF(cond, error, ...) { if (cond) throw Error(error, __VA_ARGS__); }

// NOTE(review): these using-directives are at header scope, so they apply to every file that includes this header
using namespace std;
using boost::format;


/***************************************************************************
 * Basic Utility Classes                                                   *
 ***************************************************************************/

/// Position within a named piece of source text
struct Cursor {
	/// The defaults ("", line 1, column 0) form the value that tests false below
	Cursor(const string& n="", unsigned l=1, unsigned c=0)
		: name(n)
		, line(l)
		, col(c)
	{}

	/// False only at line 1, column 0 (the name is not considered)
	operator bool() const { return line != 1 || col != 0; }

	/// Render as "name:line:col"
	string str() const {
		format fmt("%1%:%2%:%3%");
		fmt % name % line % col;
		return fmt.str();
	}

	string   name; ///< Name of the text (presumably a file name; not interpreted here)
	unsigned line; ///< Line number (1 by default)
	unsigned col;  ///< Column number (0 by default)
};

/// Compiler error
struct Error {
	Error(Cursor c, const string& m) : loc(c), msg(m) {}
	const string what() const { return (loc ? loc.str() + ": " : "") + "error: " + msg; }
	Cursor loc;
	string msg;
};

/// Expression ::= Atom | (SubExp*)
///
/// The sub-expressions of a LIST live in the inherited vector; an ATOM
/// carries its value in `atom`.
template<typename Atom>
struct Exp : public std::vector< Exp<Atom> > {
	/// Build an empty LIST located at `c` (`atom` is default-constructed)
	Exp(Cursor c)
		: type(LIST), loc(c)
	{}

	/// Build an ATOM with value `a` located at `c`
	Exp(Cursor c, const Atom& a)
		: type(ATOM), loc(c), atom(a)
	{}

	enum { ATOM, LIST } type; ///< Which of the two forms this expression has
	Cursor              loc;  ///< Location given at construction
	Atom                atom; ///< Atom value (only set by the ATOM constructor)
};

/// Write an expression to `out` (declaration only; no definition appears in this header)
template<typename Atom>
extern ostream& operator<<(ostream& out, const Exp<Atom>& exp);

/// Lexical Address
///
/// Both coordinates are 1-based as produced by Env::lookup(); the
/// default-constructed (0, 0) address tests false and means "not found".
struct LAddr {
	LAddr(unsigned u=0, unsigned o=0) : up(u), over(o) {}
	operator bool() const { return !(up == 0 && over == 0); }
	unsigned up;   ///< Frame number, counted from the innermost frame (1 = innermost)
	unsigned over; ///< Position of the binding within that frame (1 = first)
};

/// Generic Lexical Environment
///
/// A list of frames, innermost first.  Each frame is a vector of (key, value)
/// bindings in definition order.  A new environment starts with one empty frame.
template<typename K, typename V>
struct Env : public std::list< std::vector< std::pair<K,V> > > {
	typedef std::vector< std::pair<K,V> > Frame;

	Env() : std::list<Frame>(1) {}

	/// Virtual because the class has virtual members and is derived from
	virtual ~Env() {}

	/// Enter a new innermost frame (empty by default)
	virtual void push(Frame f=Frame()) { std::list<Frame>::push_front(f); }

	/// Leave the innermost frame
	virtual void pop() { assert(!this->empty()); std::list<Frame>::pop_front(); }

	/// Bind `k` to `v` in the innermost frame, replacing an existing binding
	/// for `k` in that frame if there is one.
	///
	/// @return a reference to the value stored in the environment (never to
	/// the argument `v`, which may be a temporary).
	const V& def(const K& k, const V& v) {
		assert(!this->empty());
		Frame& top = this->front();
		for (typename Frame::iterator b = top.begin(); b != top.end(); ++b)
			if (b->first == k)
				return (b->second = v);
		top.push_back(std::make_pair(k, v));
		return top.back().second;
	}

	/// Find the innermost binding of `key`.
	/// @return a pointer to its value, or NULL if `key` is unbound
	V* ref(const K& key) {
		for (typename Env::iterator f = this->begin(); f != this->end(); ++f)
			for (typename Frame::iterator b = f->begin(); b != f->end(); ++b)
				if (b->first == key)
					return &b->second;
		return NULL;
	}

	/// Find the lexical address of the innermost binding of `key`.
	/// @return a 1-based (up, over) address, or LAddr() if `key` is unbound
	LAddr lookup(const K& key) const {
		unsigned up = 0;
		for (typename Env::const_iterator f = this->begin(); f != this->end(); ++f, ++up)
			for (unsigned over = 0; over < f->size(); ++over)
				if ((*f)[over].first == key)
					return LAddr(up + 1, over + 1);
		return LAddr();
	}

	/// Return the value at `addr`, which must be a valid address in this
	/// environment (as returned by lookup()).
	V& deref(LAddr addr) {
		assert(addr);
		assert(addr.over > 0); // `over` is 1-based; 0 would wrap around below
		typename Env::iterator f = this->begin();
		assert(f != this->end());
		for (unsigned u = 1; u < addr.up; ++u) {
			++f;
			assert(f != this->end()); // check the frame we are about to use
		}
		assert(addr.over <= f->size());
		return (*f)[addr.over - 1].second;
	}
};


/***************************************************************************
 * Lexer: Text (istream) -> S-Expressions (SExp)                           *
 ***************************************************************************/

typedef Exp<string> SExp; ///< Textual S-Expression

/// Read an expression from `in` (declaration only).
/// NOTE(review): `cur` is a non-const reference, presumably so the reader can
/// advance it as text is consumed — confirm against the definition.
SExp readExpression(Cursor& cur, std::istream& in);


/***************************************************************************
 * Backend Types                                                           *
 ***************************************************************************/

// All three are untyped pointers; nothing in this header dereferences them.
typedef void* CValue;    ///< Compiled value (opaque)
typedef void* CFunction; ///< Compiled function (opaque)
typedef void* CEngine;   ///< Compiler Engine (opaque)


/***************************************************************************
 * Garbage Collector                                                       *
 ***************************************************************************/

struct Object; ///< Object (AST nodes and runtime data)
struct CEnv;   ///< Compile-Time Environment

/// Garbage collector state: a list of allocated objects and a list of roots
struct GC {
	typedef std::list<const Object*> Roots;
	typedef std::list<Object*>       Heap;

	// Declared here only; the definitions are not part of this header.
	void* alloc(size_t size);
	void  collect(CEnv& cenv, const Roots& roots);

	/// Append `obj` to the root list; a null pointer is ignored
	void addRoot(const Object* obj) {
		if (!obj)
			return;
		_roots.push_back(obj);
	}

	/// Append every object currently in the heap list to the root list
	void lock() {
		for (Heap::const_iterator o = _heap.begin(); o != _heap.end(); ++o)
			_roots.push_back(*o);
	}

	/// Read-only view of the root list
	const Roots& roots() const { return _roots; }

private:
	Heap  _heap;
	Roots _roots;
};

/// Base of everything allocated from the shared GC pool
struct Object {
	Object()
		: used(false)
	{}

	virtual ~Object() {}

	/// Storage for every Object comes from `pool`
	static void* operator new(size_t size) { return pool.alloc(size); }

	/// Deliberately empty: `delete` runs the destructor but releases no memory here
	static void operator delete(void* ptr) {}

	static GC pool; ///< Allocator shared by all Objects

	mutable bool used; ///< Starts false; presumably the collector's mark flag — confirm in GC::collect
};


/***************************************************************************
 * Abstract Syntax Tree                                                    *
 ***************************************************************************/

struct Constraint;  ///< Type Constraint
struct TEnv;        ///< Type-Time Environment
struct Constraints; ///< List of type constraints
struct Subst;       ///< Type substitution
struct AST;         ///< Base class for all AST nodes

/// Write an AST node to `out` (declaration only); AST::str() is built on this
extern ostream& operator<<(ostream& out, const AST* ast);

/// Base class for all AST nodes
struct AST : public Object {
	AST(Cursor c=Cursor()) : loc(c) {}
	virtual bool   operator==(const AST& o) const = 0;
	virtual bool   contains(const AST* child) const { return false; }
	virtual void   constrain(TEnv& tenv, Constraints& c) const {}
	virtual void   lift(CEnv& cenv) {}
	virtual CValue compile(CEnv& cenv) = 0;
	string str() const { ostringstream ss; ss << this; return ss.str(); }
	template<typename T> T to()       { return dynamic_cast<T>(this); }
	template<typename T> T to() const { return dynamic_cast<T>(this); }
	template<typename T> T as() {
		T t = dynamic_cast<T>(this);
		if (!t) throw Error(loc, "internal error: bad cast");
		return t;
	}
	Cursor loc;
};

/// Literal constant of value type VT
template<typename VT>
struct ALiteral : public AST {
	ALiteral(VT v, Cursor c)
		: AST(c), val(v)
	{}

	/// Equal only to another ALiteral of the same VT holding an equal value
	bool operator==(const AST& rhs) const {
		const ALiteral<VT>* other = dynamic_cast<const ALiteral<VT>*>(&rhs);
		if (!other)
			return false;
		return val == other->val;
	}

	// Declared here only; defined per value type elsewhere.
	void   constrain(TEnv& tenv, Constraints& c) const;
	CValue compile(CEnv& cenv);

	const VT val; ///< The literal's value
};

/// Symbol, e.g. "a"
struct ASymbol : public AST {
	bool   operator==(const AST& rhs) const { return this == &rhs; }
	void   constrain(TEnv& tenv, Constraints& c) const;
	CValue compile(CEnv& cenv);
	mutable LAddr  addr;
	const   string cppstr;
private:
	friend class PEnv;
	ASymbol(const string& s, Cursor c) : AST(c), cppstr(s) {}
};

/// Tuple (heterogeneous sequence of fixed length), e.g. "(a b c)"
///
/// The elements are the AST pointers held in the inherited vector.
struct ATuple : public AST, public vector<AST*> {
	/// Build a tuple at `c` holding a copy of the pointers in `v` (empty by default)
	ATuple(Cursor c, const vector<AST*>& v=vector<AST*>()) : AST(c), vector<AST*>(v) {}

	/// Build a tuple from a NULL-terminated argument list.
	/// Every variadic argument is read with va_arg(args, AST*), so callers
	/// must pass each one (including the terminator) as an AST*.
	ATuple(Cursor c, AST* ast, ...) : AST(c) {
		if (!ast) return;
		va_list args; va_start(args, ast);
		push_back(ast);
		for (AST* a = va_arg(args, AST*); a; a = va_arg(args, AST*))
			push_back(a);
		va_end(args);
	}

	/// `delete` every element.  The pointers stay in the vector afterwards.
	void free() {
		FOREACH(const_iterator, p, *this)
			delete *p;
	}

	/// Equal to another tuple of the same size whose elements are pairwise equal
	bool operator==(const AST& rhs) const {
		const ATuple* rt = rhs.to<const ATuple*>();
		if (!rt || rt->size() != size()) return false;
		const_iterator l = begin();
		FOREACH(const_iterator, r, *rt)
			if (!(*(*l++) == *(*r)))
				return false;
		return true;
	}

	/// Lift every element in order
	void lift(CEnv& cenv) { FOREACH(iterator, t, *this) (*t)->lift(cenv); }

	/// True if `child` equals this tuple, equals one of its elements, or is
	/// contained in one of its elements.
	///
	/// The parameter is `const AST*` so that this overrides AST::contains();
	/// with a different parameter type it would only hide the base version,
	/// and the recursive call below (made through AST*) would never reach it.
	bool contains(const AST* child) const {
		if (!child) return false;
		if (*this == *child) return true;
		FOREACH(const_iterator, p, *this)
			if (**p == *child || (*p)->contains(child))
				return true;
		return false;
	}

	void   constrain(TEnv& tenv, Constraints& c) const;

	/// A plain tuple cannot be compiled; always throws
	CValue compile(CEnv& cenv) { throw Error(loc, "tuple compiled"); }
};

/// Type Expression, e.g. "Int", "(Fn (Int Int) Float)"
///
/// There are three kinds:
///  - PRIM: a primitive type, stored as a one-element tuple holding its symbol
///  - VAR:  a type variable, identified by `id`
///  - EXPR: a compound type whose elements are the tuple's elements
struct AType : public ATuple {
	/// Primitive type named by symbol `s`
	AType(ASymbol* s) : ATuple(s->loc), kind(PRIM), id(0) { push_back(s); }

	/// Type variable number `i` (the address argument is not used)
	AType(Cursor c, unsigned i, LAddr a) : ATuple(c), kind(VAR), id(i) {}

	/// Compound type from a NULL-terminated argument list; each variadic
	/// argument is read with va_arg(args, AST*).  A NULL `ast` gives an
	/// empty EXPR type.
	AType(Cursor c, AST* ast, ...) : ATuple(c), kind(EXPR), id(0) {
		if (!ast) return;
		va_list args; va_start(args, ast);
		push_back(ast);
		for (AST* a = va_arg(args, AST*); a; a = va_arg(args, AST*))
			push_back(a);
		va_end(args);
	}

	/// Copy a type: AType elements are copied recursively, ATuple elements
	/// are copied one level deep, and any other element pointer is shared.
	AType(const AType& copy) : ATuple(copy.loc), kind(copy.kind), id(copy.id) {
		for (AType::const_iterator i = copy.begin(); i != copy.end(); ++i) {
			AType* typ = (*i)->to<AType*>();
			if (typ) {
				push_back(new AType(*typ));
				continue;
			}
			ATuple* tup = (*i)->to<ATuple*>();
			if (tup)
				push_back(new ATuple(*tup));
			else
				push_back(*i);
		}
	}

	/// Types have no compiled value
	CValue compile(CEnv& cenv) { return NULL; }

	bool   var() const { return kind == VAR; }

	/// True if no type variable occurs anywhere in this type
	bool concrete() const {
		switch (kind) {
		case VAR:  return false;
		case PRIM: return true;
		case EXPR:
			FOREACH(const_iterator, t, *this) {
				AType* kid = (*t)->to<AType*>();
				if (kid && !kid->concrete())
					return false;
			}
		}
		return true;
	}

	/// Order by kind first, then by id.
	///
	/// This is a strict weak ordering, which std::map requires of the
	/// comparator used by Subst.  Comparing each field independently
	/// (`kind < rhs.kind || id < rhs.id`) is not: two types can then each
	/// compare less than the other.
	bool operator<(const AType& rhs) const {
		if (kind != rhs.kind)
			return kind < rhs.kind;
		return id < rhs.id;
	}

	/// Equal to another AType of the same kind with the same id (VAR), the
	/// same printed symbol (PRIM), or pairwise-equal elements (EXPR)
	bool operator==(const AST& rhs) const {
		const AType* rt = rhs.to<const AType*>();
		if (!rt || kind != rt->kind)
			return false;
		else
			switch (kind) {
			case VAR:  return id == rt->id;
			case PRIM: return at(0)->str() == rt->at(0)->str();
			case EXPR: return ATuple::operator==(rhs);
			}
		return false; // never reached
	}

	enum { VAR, PRIM, EXPR } kind; ///< Which form of type this is
	unsigned                 id;   ///< Variable number for VAR; 0 for the other kinds
};

struct typeLessThan {
	inline bool operator()(const AType* a, const AType* b) const { return *a < *b; }
};

/// Type substitution: a map from types to their replacements, keyed through
/// typeLessThan (i.e. by the pointed-to type, not by pointer value)
struct Subst : public map<const AType*,AType*,typeLessThan> {
	/// Empty substitution, or the single mapping s -> t when both are given
	Subst(AType* s=0, AType* t=0) { if (s && t) { assert(s != t); insert(make_pair(s, t)); } }

	static Subst compose(const Subst& delta, const Subst& gamma);

	/// Apply this substitution to `ast`.
	///
	/// Anything that is not an AType (including NULL) is returned unchanged.
	/// An EXPR type is rebuilt as a new AType with the substitution applied
	/// to each element.  Any other type is replaced by its mapping if one
	/// exists, and returned unchanged otherwise.
	AST* apply(AST* ast) const {
		// Guard before calling a member: `ast` may legitimately be NULL
		AType* in = ast ? ast->to<AType*>() : NULL;
		if (!in) return ast;
		if (in->kind == AType::EXPR) {
			AType* out = new AType(in->loc, NULL);
			for (size_t i = 0; i < in->size(); ++i)
				out->push_back(apply(in->at(i)));
			return out;
		} else {
			const_iterator i = find(in);
			if (i != end()) {
				return i->second;
			} else {
				return in;
			}
		}
	}
};

/// Lifted system functions (of various types) for a single Tuplr function
struct Funcs : public list< pair<AType*, CFunction> > {
	/// Return the function stored for the first entry whose type equals
	/// `type`, or NULL if no entry matches
	CFunction find(AType* type) const {
		const_iterator entry = begin();
		while (entry != end()) {
			if (*entry->first == *type)
				return entry->second;
			++entry;
		}
		return NULL;
	}
};

/// Closure (first-class function with captured lexical bindings)
///
/// Stored as a tuple: element 0 is the symbol passed as `fn`, element 1 is
/// the prototype tuple `p` (see prot()).
struct AClosure : public ATuple {
	AClosure(Cursor c, ASymbol* fn, ATuple* p, const string& n="")
		// ATuple's variadic constructor reads each trailing argument with
		// va_arg(args, AST*), so the prototype and the terminator are
		// converted to AST* explicitly rather than passed as ATuple* / NULL.
		: ATuple(c, fn, static_cast<AST*>(p), static_cast<AST*>(NULL)), name(n) {}

	/// Closures compare by identity
	bool    operator==(const AST& rhs) const { return this == &rhs; }
	void    constrain(TEnv& tenv, Constraints& c) const;
	void    lift(CEnv& cenv);
	void    liftCall(CEnv& cenv, const AType& argsT);
	CValue  compile(CEnv& cenv);

	/// The prototype: element 1 as a tuple, or NULL if it is not an ATuple
	ATuple* prot() const { return at(1)->to<ATuple*>(); }

	Funcs         funcs; ///< Compiled functions for this closure, keyed by type
	mutable Subst subst; ///< Type substitution associated with this closure
	string        name;  ///< Name given at construction (empty by default)
};

/// Function call/application, e.g. "(func arg1 arg2)"
struct ACall : public ATuple {
	/// Takes its location from `e` and its elements from `t`
	ACall(const SExp& e, const ATuple& t)
		: ATuple(e.loc, t)
	{}

	// Declared here only; the definitions are not part of this header.
	void   constrain(TEnv& tenv, Constraints& c) const;
	void   lift(CEnv& cenv);
	CValue compile(CEnv& cenv);
};

/// Definition special form, e.g. "(def x 2)"
struct ADefinition : public ACall {
	ADefinition(const SExp& e, const ATuple& t)
		: ACall(e, t)
	{}

	/// The symbol being defined: element 1 itself when it is a symbol,
	/// otherwise the first element of element 1 when that is a non-empty
	/// tuple (NULL if that element is not a symbol), otherwise NULL
	ASymbol* sym() const {
		AST* const target = at(1);
		if (ASymbol* direct = target->to<ASymbol*>())
			return direct;
		ATuple* tup = target->to<ATuple*>();
		if (!tup || tup->empty())
			return NULL;
		return tup->at(0)->to<ASymbol*>();
	}

	// Declared here only; the definitions are not part of this header.
	void   constrain(TEnv& tenv, Constraints& c) const;
	void   lift(CEnv& cenv);
	CValue compile(CEnv& cenv);
};

/// Conditional special form, e.g. "(if cond thenexp elseexp)"
struct AIf : public ACall {
	/// Takes its location from `e` and its elements from `t`
	AIf(const SExp& e, const ATuple& t)
		: ACall(e, t)
	{}

	// Declared here only; the definitions are not part of this header.
	void   constrain(TEnv& tenv, Constraints& c) const;
	CValue compile(CEnv& cenv);
};

/// Primitive (builtin arithmetic function), e.g. "(+ 2 3)"
struct APrimitive : public ACall {
	/// Takes its location from `e` and its elements from `t`
	APrimitive(const SExp& e, const ATuple& t)
		: ACall(e, t)
	{}

	// Declared here only; the definitions are not part of this header.
	void   constrain(TEnv& tenv, Constraints& c) const;
	CValue compile(CEnv& cenv);
};


/***************************************************************************
 * Parser: S-Expressions (SExp) -> AST Nodes (AST)                         *
 ***************************************************************************/

/// Parse Time Environment (really just a symbol table)
///
/// The inherited (private) map interns symbols by name.  Separate tables
/// hold the parse handlers for atoms and lists, and the macros.
struct PEnv : private map<const string, ASymbol*> {
	typedef AST* (*PF)(PEnv&, const SExp&, void*); ///< Parse Function
	typedef SExp (*MF)(PEnv&, const SExp&); ///< Macro Function
	struct Handler { Handler(PF f, void* a=0) : func(f), arg(a) {} PF func; void* arg; };
	map<const string, Handler> aHandlers; ///< Atom parse functions
	map<const string, Handler> lHandlers; ///< List parse functions
	map<const string, MF>      macros; ///< Macro functions

	/// Register handler `h` for `s` in the list table (`list` true) or the
	/// atom table; an existing registration for the same key is kept
	void reg(bool list, const string& s, const Handler& h) {
		(list ? lHandlers : aHandlers).insert(make_pair(sym(s)->str(), h));
	}

	/// Look up the handler registered for `s`, or NULL if there is none
	const Handler* handler(bool list, const string& s) const {
		const map<const string, Handler>& handlers = list ? lHandlers : aHandlers;
		// The iterator type must come from the same map instantiation
		map<const string, Handler>::const_iterator i = handlers.find(s);
		return (i != handlers.end()) ? &i->second : NULL;
	}

	/// Define macro `f` for `s`; an existing macro for `s` is kept
	void defmac(const string& s, const MF f) {
		macros.insert(make_pair(s, f));
	}

	/// Look up the macro defined for `s`, or NULL if there is none
	MF mac(const string& s) const {
		map<const string, MF>::const_iterator i = macros.find(s);
		return (i != macros.end()) ? i->second : NULL;
	}

	/// Return the unique symbol named `s`, creating it at location `c` on
	/// first use (`c` is ignored when the symbol already exists)
	ASymbol* sym(const string& s, Cursor c=Cursor()) {
		const const_iterator i = find(s);
		if (i != end()) {
			return i->second;
		} else {
			ASymbol* sym = new ASymbol(s, c);
			insert(make_pair(s, sym));
			return sym;
		}
	}

	/// Parse every sub-expression of `e` into a tuple located at `e.loc`
	ATuple parseTuple(const SExp& e) {
		ATuple ret(e.loc, vector<AST*>(e.size()));
		size_t n = 0;
		FOREACH(SExp::const_iterator, i, e)
			ret[n++] = parse(*i);
		return ret;
	}

	/// Parse an S-expression into an AST node.
	///
	/// Lists: if the head is an atom with a macro, the macro is applied
	/// first; the resulting list is then given to the list handler
	/// registered for its head, or parsed as a regular call if there is
	/// none.  Atoms: text starting with a digit becomes an integer literal
	/// (or a float literal if it contains '.'); otherwise an atom handler is
	/// used if one is registered, and the atom becomes a symbol if not.
	///
	/// @throws Error on an empty list
	AST* parse(const SExp& exp) {
		if (exp.type == SExp::LIST) {
			if (exp.empty()) throw Error(exp.loc, "call to empty list");
			if (exp.front().type == SExp::ATOM) {
				MF mf = mac(exp.front().atom);
				if (mf) {
					const SExp expanded = mf(*this, exp);
					// A macro may produce an atom or an empty list; hand
					// those back to parse() instead of taking front() of them
					if (expanded.type != SExp::LIST || expanded.empty())
						return parse(expanded);
					// Parse the expansion, not the original expression, even
					// when no handler matches it
					return parseList(expanded);
				}
			}
			return parseList(exp);
		} else if (!exp.atom.empty() && isdigit(static_cast<unsigned char>(exp.atom[0]))) {
			// isdigit() requires a value representable as unsigned char
			if (exp.atom.find('.') == string::npos)
				return new ALiteral<int32_t>(strtol(exp.atom.c_str(), NULL, 10), exp.loc);
			else
				return new ALiteral<float>(strtod(exp.atom.c_str(), NULL), exp.loc);
		} else {
			const PEnv::Handler* h = handler(false, exp.atom);
			if (h)
				return h->func(*this, exp, h->arg);
		}
		return sym(exp.atom, exp.loc);
	}

private:
	/// Parse a non-empty list with no further macro expansion: use the list
	/// handler for its head atom if one is registered, else a regular call
	AST* parseList(const SExp& e) {
		if (e.front().type == SExp::ATOM) {
			const PEnv::Handler* h = handler(true, e.front().atom);
			if (h)
				return h->func(*this, e, h->arg);
		}
		return new ACall(e, parseTuple(e)); // Parse as regular call
	}
};


/***************************************************************************
 * Typing                                                                  *
 ***************************************************************************/

/// A pair of types that are constrained together, plus the location the
/// constraint came from
struct Constraint : public pair<AType*,AType*> {
	Constraint(AType* a, AType* b, Cursor c)
		: pair<AType*,AType*>(a, b)
		, loc(c)
	{}

	Cursor loc; ///< Location given at construction
};

/// Ordered list of type constraints
struct Constraints : public list<Constraint> {
	// Declared here only; the definition is not part of this header.
	void constrain(TEnv& tenv, const AST* o, AType* t);
};

/// Write each constraint as "first : second" on a line of its own
inline ostream& operator<<(ostream& out, const Constraints& c) {
	Constraints::const_iterator i = c.begin();
	while (i != c.end()) {
		const Constraint& con = *i;
		out << con.first << " : " << con.second << endl;
		++i;
	}
	return out;
}

/// Type-Time Environment
///
/// Binds each symbol to a pair (AST node, type).  Types of nodes that are
/// not symbols are kept separately in `vars`.
struct TEnv : public Env< const ASymbol*, pair<AST*, AType*> > {
	TEnv(PEnv& p) : penv(p), varID(1) {}

	/// Create a new type variable and bind `sym` to (NULL, that variable)
	/// in the innermost frame
	AType* fresh(const ASymbol* sym) {
		assert(sym);
		AType* ret = new AType(sym->loc, varID++, LAddr());
		def(sym, make_pair((AST*)NULL, ret));
		return ret;
	}

	/// Return the type variable for `ast`.
	///
	/// For a symbol this is the type it is bound to in the environment (the
	/// symbol must be bound).  For any other node it is the variable
	/// recorded in `vars`, created on first request.  With no argument
	/// (NULL) a new, unrecorded variable is returned on every call.
	AType* var(const AST* ast=0) {
		if (ast) {
			// Only inspect the node when there is one: NULL is the default
			const ASymbol* sym = ast->to<const ASymbol*>();
			if (sym)
				return deref(lookup(sym)).second;

			Vars::iterator v = vars.find(ast);
			if (v != vars.end())
				return v->second;
		}

		AType* ret = new AType(ast ? ast->loc : Cursor(), varID++, LAddr());
		if (ast)
			vars[ast] = ret;

		return ret;
	}

	/// Return the type bound to the symbol called `name`.
	/// @throws Error if no such symbol is bound in this environment
	AType* named(const string& name) {
		pair<AST*, AType*>* binding = ref(penv.sym(name));
		if (!binding)
			throw Error(Cursor(), "undefined type: " + name);
		return binding->second;
	}

	/// For a symbol with a lexical address, return the AST node it is bound
	/// to; any other node is returned unchanged
	AST* resolve(AST* ast) {
		ASymbol* sym = ast->to<ASymbol*>();
		return (sym && sym->addr) ? ref(sym)->first : ast;
	}

	static Subst unify(const Constraints& c);

	typedef map<const AST*, AType*>            Vars;
	typedef map<const AClosure*, const AType*> GenericTypes;
	Vars         vars;         ///< Type variables of non-symbol nodes
	GenericTypes genericTypes; ///< Types recorded per closure
	PEnv&        penv;         ///< Parse environment, used to intern names
	unsigned     varID;        ///< Number the next type variable will get (starts at 1)
};


/***************************************************************************
 * Code Generation                                                         *
 ***************************************************************************/

/// Compile-Time Environment
///
/// Ties together the parse and type environments, the table of compiled
/// values, and the backend (hidden behind PImpl).
struct CEnv {
	/// Output and error streams default to std::cout and std::cerr
	CEnv(PEnv& p, TEnv& t, CEngine e, ostream& os=std::cout, ostream& es=std::cerr);
	~CEnv();

	typedef Env<const ASymbol*, AST*> Code;
	typedef Env<const AST*, CValue>   Vals;

	CEngine engine();

	/// Return a new name: prefix `s` followed by the next value of symID
	string  gensym(const char* s="_") { return (format("%s%d") % s % symID++).str(); }

	/// Enter / leave a frame in both the type and the value environment
	void    push() { tenv.push(); vals.push(); }
	void    pop()  { tenv.pop();  vals.pop();  }

	/// Record `value` as the compiled value of `obj`
	void    precompile(AST* obj, CValue value) { vals.def(obj, value); }

	CValue  compile(AST* obj);
	void    optimise(CFunction f);
	void    write(std::ostream& os);

	/// Register `ast` and its type (if it has one) as GC roots
	void    lock(AST* ast) { Object::pool.addRoot(ast); Object::pool.addRoot(type(ast)); }

	/// Return the type of `ast`, or NULL if none is known.
	///
	/// A symbol's type is read at its lexical address.  Any other node's
	/// type is looked up in tenv.vars and has `subst` and then `tsubst`
	/// applied.  The lookup does not modify tenv.vars: inserting a NULL
	/// entry for an unknown node would make TEnv::var() return NULL for it
	/// afterwards.
	AType*  type(AST* ast, const Subst& subst = Subst()) const {
		ASymbol* sym = ast->to<ASymbol*>();
		if (sym)
			return sym->addr ? tenv.deref(sym->addr).second : NULL;
		TEnv::Vars::const_iterator v = tenv.vars.find(ast);
		if (v == tenv.vars.end() || !v->second)
			return NULL;
		AST* t = tsubst.apply(subst.apply(v->second));
		return t ? t->to<AType*>() : NULL;
	}

	/// Bind `sym` to (c, t) in the type environment and to `v` in the
	/// value environment
	void def(ASymbol* sym, AST* c, AType* t, CValue v) {
		tenv.def(sym, make_pair(c, t));
		vals.def(sym, v);
	}

	ostream&  out;
	ostream&  err;
	PEnv&     penv;
	TEnv&     tenv;
	Vals      vals;

	unsigned  symID;
	CFunction alloc;
	Subst     tsubst;

	map<string,string> args;

private:
	struct PImpl; ///< Private Implementation
	PImpl* _pimpl;
};


/***************************************************************************
 * EVAL/REPL/MAIN                                                          *
 ***************************************************************************/

// Front-end entry points; all are only declared in this header.
// NOTE(review): the per-function notes below are inferred from names and
// signatures — confirm against the definitions.
void  pprint(std::ostream& out, const AST* ast); // presumably pretty-prints `ast` to `out`
void  initLang(PEnv& penv, TEnv& tenv);          // presumably installs the built-in language definitions
CEnv* newCenv(PEnv& penv, TEnv& tenv);           // presumably paired with freeCenv()
void  freeCenv(CEnv* cenv);
int   eval(CEnv& cenv, const string& name, istream& is); // presumably evaluates the text read from `is`
int   repl(CEnv& cenv);                          // presumably runs an interactive read-eval-print loop

#endif // TUPLR_HPP