/* Tuplr: A programming language
 * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net>
 *
 * Tuplr is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 *
 * Tuplr is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Tuplr.  If not, see <http://www.gnu.org/licenses/>.
 */

/** @file
 * @brief Interface and type definitions
 */

#ifndef TUPLR_HPP
#define TUPLR_HPP

#include <assert.h>
#include <ctype.h>
#include <stdarg.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <iostream>
#include <list>
#include <map>
#include <new>
#include <set>
#include <sstream>
#include <string>
#include <vector>
#include <boost/format.hpp>

/// Iterate `i` (declared with iterator type IT) from (c).begin() to (c).end().
/// Note: (c) is re-evaluated on every iteration to obtain end().
#define FOREACH(IT, i, c) for (IT i = (c).begin(); i != (c).end(); ++i)
/// Throw Error(error, ...) when `cond` holds.
/// Expands to a plain braced block (not `do { } while (0)`), so a trailing
/// semicolon after it is an empty statement: avoid it as the body of an
/// unbraced `if` that has an `else`.
#define THROW_IF(cond, error, ...) { if (cond) throw Error(error, __VA_ARGS__); }

using namespace std;
using boost::format;


/***************************************************************************
 * Basic Utility Classes                                                   *
 ***************************************************************************/

/// Location in textual code: source name, line and column.
/// The default position (line 1, column 0) is treated as "no location".
struct Cursor {
	Cursor(const std::string& n="", unsigned l=1, unsigned c=0) : name(n), line(l), col(c) {}

	/// True unless this cursor is at the default position (line 1, column 0).
	operator bool() const { return line != 1 || col != 0; }

	/// Render as "name:line:col".
	std::string str() const {
		std::ostringstream text;
		text << name << ':' << line << ':' << col;
		return text.str();
	}

	std::string name;
	unsigned    line;
	unsigned    col;
};

/// Compiler error: a message, optionally associated with a source location.
struct Error {
	Error(Cursor c, const string& m) : loc(c), msg(m) {}

	/// Full diagnostic text; prefixed with "<loc>: " only when loc is set.
	const string what() const {
		string prefix;
		if (loc)
			prefix = loc.str() + ": ";
		return prefix + "error: " + msg;
	}

	Cursor loc;
	string msg;
};

/// Generic Lexical Environment
///
/// A list of frames, innermost scope first; each frame is an ordered vector
/// of (key, value) bindings. Lookup searches frames from innermost outward.
template<typename K, typename V>
struct Env : public std::list< std::vector< std::pair<K,V> > > {
	typedef std::vector< std::pair<K,V> > Frame;
	Env() : std::list<Frame>(1) {}
	/// Virtual because this class has virtual members and is subclassed.
	virtual ~Env() {}
	/// Enter a new innermost scope, optionally pre-populated with bindings.
	virtual void push(Frame f=Frame()) { std::list<Frame>::push_front(f); }
	/// Leave the innermost scope.
	virtual void pop() { assert(!this->empty()); std::list<Frame>::pop_front(); }
	/// Bind @a k to @a v in the innermost frame, overwriting an existing
	/// binding of @a k in that frame.
	/// @return a reference to the stored value. It stays valid until the next
	///         binding is appended to this frame. Previously this returned
	///         @a v itself, which dangles when @a v is a temporary.
	const V& def(const K& k, const V& v) {
		Frame& frame = this->front();
		for (typename Frame::iterator b = frame.begin(); b != frame.end(); ++b)
			if (b->first == k)
				return (b->second = v);
		frame.push_back(std::make_pair(k, v));
		return frame.back().second;
	}
	/// Pointer to the innermost binding of @a key, or NULL if unbound.
	V* ref(const K& key) {
		for (typename Env::iterator f = this->begin(); f != this->end(); ++f)
			for (typename Frame::iterator b = f->begin(); b != f->end(); ++b)
				if (b->first == key)
					return &b->second;
		return NULL;
	}
};


/***************************************************************************
 * Lexer: Text (istream) -> S-Expressions (SExp)                           *
 ***************************************************************************/

struct AST; // class-key matches the later `struct AST` definition
/// Read one S-expression from @a in.
/// @a cur is passed by non-const reference, presumably so the reader can
/// advance the current source position (TODO confirm in the lexer).
AST* readExpression(Cursor& cur, std::istream& in);


/***************************************************************************
 * Backend Types                                                           *
 ***************************************************************************/

// Backend values are untyped handles at this level; each Engine decides
// what they actually point to.
typedef void* CVal;  ///< Compiled value (opaque)
typedef void* CFunc; ///< Compiled function (opaque)


/***************************************************************************
 * Garbage Collection                                                      *
 ***************************************************************************/

struct Object; // forward declaration; GC refers to Object before its definition

/// Garbage collector
struct GC {
	enum Tag {
		TAG_AST   = 1, ///< Abstract syntax tree node
		TAG_FRAME = 2  ///< Stack frame
	};
	typedef std::list<const Object*> Roots;
	typedef std::list<Object*>       Heap;
	GC(size_t pool_size);
	~GC();
	void* alloc(size_t size, Tag tag);
	void  collect(const Roots& roots);
	void  addRoot(const Object* obj) { assert(obj); _roots.push_back(obj); }
	void  lock() { _roots.insert(_roots.end(), _heap.begin(), _heap.end()); }
	const Roots& roots() const { return _roots; }
private:
	void* _pool;
	Heap  _heap;
	Roots _roots;
};

/// Garbage collected object (including AST and runtime data)
///
/// Objects hold no data of their own; GC bookkeeping (mark bit and tag) lives
/// in a Header placed immediately before the object in memory.
struct Object {
	struct Header {
		uint8_t mark;
		uint8_t tag;
	};

	/// Always allocated with pool.alloc, so this - sizeof(Header) is a valid Header*.
	inline Header* header() const { return (Header*)((char*)this - sizeof(Header)); }

	inline bool    marked()     const { return header()->mark != 0; }
	/// Set (b == true) or clear (b == false) the mark bit.
	/// The argument was previously ignored and the bit was always set.
	inline void    mark(bool b) const { header()->mark = b ? 1 : 0; }
	inline GC::Tag tag()        const { return (GC::Tag)header()->tag; }

	/// All Objects come from the global pool, tagged as AST nodes.
	static void* operator new(size_t size) { return pool.alloc(size, GC::TAG_AST); }
	/// Deliberately a no-op: the memory belongs to the GC pool.
	static void  operator delete(void* ptr) {}
	static GC pool;
};


/***************************************************************************
 * Abstract Syntax Tree                                                    *
 ***************************************************************************/

// Forward declarations used by the AST interface below.
struct TEnv;        ///< Type-Time Environment
struct Constraints; ///< Type Constraints
struct Subst;       ///< Type substitutions
struct CEnv;        ///< Compile-Time Environment

struct AST;
/// Print an AST (AST::str() is implemented in terms of this); defined out of line.
extern ostream& operator<<(ostream& out, const AST* ast);

/// Base class for all AST nodes
struct AST : public Object {
	AST(Cursor c=Cursor()) : loc(c) {}
	virtual ~AST() {}
	virtual bool   value() const { return true; }
	virtual bool   operator==(const AST& o) const = 0;
	virtual bool   contains(const AST* child) const { return false; }
	virtual void   constrain(TEnv& tenv, Constraints& c) const {}
	virtual AST*   cps(TEnv& tenv, AST* cont);
	virtual void   lift(CEnv& cenv) {}
	virtual CVal   compile(CEnv& cenv) = 0;
	string str() const { ostringstream ss; ss << this; return ss.str(); }
	template<typename T> T       to()       { return dynamic_cast<T>(this); }
	template<typename T> T const to() const { return dynamic_cast<T const>(this); }
	template<typename T> T as() {
		T t = dynamic_cast<T>(this);
		return t ? t : throw Error(loc, "internal error: bad cast");
	}
	template<typename T> T const as() const {
		T const t = dynamic_cast<T const>(this);
		return t ? t : throw Error(loc, "internal error: bad cast");
	}
	Cursor loc;
};

/// Construct a T from @a ast followed by a NULL-terminated list of AST*
/// arguments. T must provide a (Cursor, AST*, va_list) constructor.
template<typename T>
static T* tup(Cursor c, AST* ast, ...)
{
	va_list rest;
	va_start(rest, ast);
	T* const node = new T(c, ast, rest);
	va_end(rest);
	return node;
}

/// Literal value of C++ type T (the parser creates int32_t and float literals).
template<typename T>
struct ALiteral : public AST {
	ALiteral(T v, Cursor c) : AST(c), val(v) {}
	/// Equal iff @a rhs is an ALiteral<T> holding an equal value.
	bool operator==(const AST& rhs) const {
		const ALiteral<T>* other = rhs.to<const ALiteral<T>*>();
		if (!other)
			return false;
		return val == other->val;
	}
	void constrain(TEnv& tenv, Constraints& c) const;
	CVal compile(CEnv& cenv);
	const T val;
};

/// String, e.g. ""a""
struct AString : public AST, public std::string {
	AString(Cursor c, const string& s) : AST(c), std::string(s) {}
	bool operator==(const AST& rhs) const { return this == &rhs; }
	void constrain(TEnv& tenv, Constraints& c) const;
	CVal compile(CEnv& cenv) { return NULL; }
};

/// Symbol, e.g. "a"
struct ASymbol : public AST {
	bool operator==(const AST& rhs) const { return this == &rhs; }
	void constrain(TEnv& tenv, Constraints& c) const;
	CVal compile(CEnv& cenv);
	const string cppstr;
private:
	friend class PEnv;
	ASymbol(const string& s, Cursor c) : AST(c), cppstr(s) {}
};

/// Tuple (heterogeneous sequence of fixed length), e.g. "(a b c)"
struct ATuple : public AST {
	ATuple(Cursor c) : AST(c), _len(0), _vec(0) {}
	ATuple(const ATuple& exp) : AST(exp.loc), _len(exp._len) {
		_vec = (AST**)malloc(sizeof(AST*) * _len);
		memcpy(_vec, exp._vec, sizeof(AST*) * _len);
	}
	ATuple(Cursor c, AST* ast, va_list args) : AST(c), _len(0), _vec(0) {
		if (!ast) return;
		push_back(ast);
		for (AST* a = va_arg(args, AST*); a; a = va_arg(args, AST*))
			push_back(a);
	}
	~ATuple() { free(_vec); }
	void push_back(AST* ast) {
		AST** newvec = (AST**)realloc(_vec, sizeof(AST*) * (_len + 1));
		newvec[_len++] = ast;
		_vec = newvec;
	}
	const AST* front()      const { assert(_len > 0); return _vec[0]; }
	const AST* at(size_t i) const { assert(i < _len); return _vec[i]; }
	AST*&      at(size_t i)       { assert(i < _len); return _vec[i]; }
	size_t     size()       const { return _len; }
	bool       empty()      const { return _len == 0; }

	typedef AST**        iterator;
	typedef AST* const * const_iterator;
	const_iterator begin() const { return _vec; }
	iterator       begin()       { return _vec; }
	const_iterator end()   const { return _vec + _len; }
	iterator       end()         { return _vec + _len; }
	bool value() const { return false; }
	bool operator==(const AST& rhs) const {
		const ATuple* rt = rhs.to<const ATuple*>();
		if (!rt || rt->size() != size()) return false;
		const_iterator l = begin();
		FOREACH(const_iterator, r, *rt)
			if (!(*(*l++) == *(*r)))
				return false;
		return true;
	}
	bool contains(AST* child) const {
		if (*this == *child) return true;
		FOREACH(const_iterator, p, *this)
			if (**p == *child || (*p)->contains(child))
				return true;
		return false;
	}
	void constrain(TEnv& tenv, Constraints& c) const;
	void lift(CEnv& cenv) { FOREACH(iterator, t, *this) (*t)->lift(cenv); }

	CVal compile(CEnv& cenv) { throw Error(loc, "tuple compiled"); }

private:
	size_t _len;
	AST**  _vec;
};

/// Type Expression, e.g. "Int", "(Fn (Int Int) Float)"
struct AType : public ATuple {
	AType(ASymbol* s) : ATuple(s->loc), kind(PRIM), id(0) { push_back(s); }
	AType(Cursor c, unsigned i) : ATuple(c), kind(VAR), id(i) {}
	AType(Cursor c) : ATuple(c), kind(EXPR), id(0) {}
	AType(Cursor c, AST* ast, va_list args) : ATuple(c, ast, args), kind(EXPR), id(0) {}
	CVal compile(CEnv& cenv) { return NULL; }
	bool concrete() const {
		switch (kind) {
		case VAR:  return false;
		case PRIM: return at(0)->str() != "Nothing";
		case EXPR:
			FOREACH(const_iterator, t, *this) {
				AType* kid = (*t)->to<AType*>();
				if (kid && !kid->concrete())
					return false;
			}
		}
		return true;
	}
	bool operator==(const AST& rhs) const {
		const AType* rt = rhs.to<const AType*>();
		if (!rt || kind != rt->kind)
			return false;
		else
			switch (kind) {
			case VAR:  return id == rt->id;
			case PRIM: return at(0)->str() == rt->at(0)->str();
			case EXPR: return ATuple::operator==(rhs);
			}
		return false; // never reached
	}
	enum { VAR, PRIM, EXPR } kind;
	unsigned                 id;
};

/// Type substitution: an ordered list of (from, to) type mappings.
struct Subst : public list< pair<const AType*,AType*> > {
	/// Empty, or the single mapping s => t when both are given (they must differ).
	Subst(AType* s=0, AType* t=0) {
		if (!s || !t)
			return;
		assert(s != t);
		push_back(make_pair(s, t));
	}
	static Subst compose(const Subst& delta, const Subst& gamma);
	void add(const AType* from, AType* to) { push_back(value_type(from, to)); }
	/// First mapping whose source equals @a t (by AType::operator==), or end().
	const_iterator find(const AType* t) const {
		const_iterator j = begin();
		while (j != end() && !(*j->first == *t))
			++j;
		return j;
	}
	/// Apply this substitution to @a ast. Non-types are returned unchanged;
	/// compound types are rebuilt element by element. A mapped VAR/PRIM is
	/// replaced by its target, and the substitution is re-applied to a target
	/// that is a non-concrete compound type.
	AST* apply(AST* ast) const {
		AType* type = ast->to<AType*>();
		if (!type)
			return ast;

		if (type->kind == AType::EXPR) {
			AType* rebuilt = tup<AType>(type->loc, NULL);
			for (ATuple::iterator e = type->begin(); e != type->end(); ++e)
				rebuilt->push_back(apply(*e));
			return rebuilt;
		}

		const_iterator binding = find(type);
		if (binding == end())
			return type;

		AST*   target  = binding->second;
		AType* targetT = target->to<AType*>();
		if (targetT && targetT->kind == AType::EXPR && !targetT->concrete())
			return apply(target);
		return target;
	}
};

/// Debug dump of a substitution, one "from => to" mapping per line.
inline ostream& operator<<(ostream& out, const Subst& s) {
	Subst::const_iterator it = s.begin();
	while (it != s.end()) {
		out << it->first << " => " << it->second << endl;
		++it;
	}
	return out;
}

/// Fn (first-class function with captured lexical bindings)
struct AFn : public ATuple {
	AFn(Cursor c, AST* ast, va_list args) : ATuple(c, ast, args) {}
	bool operator==(const AST& rhs) const { return this == &rhs; }
	void constrain(TEnv& tenv, Constraints& c) const;
	AST* cps(TEnv& tenv, AST* cont);
	void lift(CEnv& cenv);
	CVal compile(CEnv& cenv);
	const ATuple* prot() const { return at(1)->to<const ATuple*>(); }
	ATuple*       prot()       { return at(1)->to<ATuple*>(); }
	/// System level implementations of this (polymorphic) fn
	struct Impls : public list< pair<AType*, CFunc> > {
		CFunc find(AType* type) const {
			for (const_iterator f = begin(); f != end(); ++f)
				if (*f->first == *type)
					return f->second;
			return NULL;
		}
	};
	Impls         impls;
	mutable Subst subst;
	string        name;
};

/// Function call/application, e.g. "(func arg1 arg2)"
/// Element 0 is the callee; the remaining elements are the arguments.
struct ACall : public ATuple {
	// Shallow copy of an already-built tuple's elements (ATuple copy constructor).
	ACall(const ATuple* exp) : ATuple(*exp) {}
	// Build from a NULL-terminated AST* list (see tup()).
	ACall(Cursor c, AST* ast, va_list args) : ATuple(c, ast, args) {}
	void constrain(TEnv& tenv, Constraints& c) const;
	AST* cps(TEnv& tenv, AST* cont);
	void lift(CEnv& cenv);
	CVal compile(CEnv& cenv);
};

/// Definition special form, e.g. "(def x 2)"
struct ADef : public ACall {
	ADef(Cursor c, AST* ast, va_list args) : ACall(c, ast, args) {}
	ADef(const ATuple* exp) : ACall(exp) {}
	/// The symbol being defined: element 1 when it is a symbol, otherwise the
	/// head of element 1 when that is a non-empty tuple, otherwise NULL.
	const ASymbol* sym() const {
		const AST* target = at(1);
		if (const ASymbol* name = target->to<const ASymbol*>())
			return name;
		const ATuple* form = target->to<const ATuple*>();
		if (form && !form->empty())
			return form->at(0)->to<const ASymbol*>();
		return NULL;
	}
	void constrain(TEnv& tenv, Constraints& c) const;
	AST* cps(TEnv& tenv, AST* cont);
	void lift(CEnv& cenv);
	CVal compile(CEnv& cenv);
};

/// Conditional special form, e.g. "(if cond thenexp elseexp)"
struct AIf : public ACall {
	// Shallow copy of an already-built tuple's elements.
	AIf(const ATuple* exp) : ACall(exp) {}
	// Build from a NULL-terminated AST* list (see tup()).
	AIf(Cursor c, AST* ast, va_list args) : ACall(c, ast, args) {}
	void constrain(TEnv& tenv, Constraints& c) const;
	AST* cps(TEnv& tenv, AST* cont);
	CVal compile(CEnv& cenv);
};

/// Primitive (builtin arithmetic function), e.g. "(+ 2 3)"
struct APrimitive : public ACall {
	APrimitive(const ATuple* exp) : ACall(exp) {}
	bool value() const {
		ATuple::const_iterator i = begin();
		for (++i; i != end(); ++i)
			if (!(*i)->value())
				return false;;
		return true;
	}
	void constrain(TEnv& tenv, Constraints& c) const;
	AST* cps(TEnv& tenv, AST* cont);
	CVal compile(CEnv& cenv);
};


/***************************************************************************
 * Parser: S-Expressions (SExp) -> AST Nodes (AST)                         *
 ***************************************************************************/

/// Parse Time Environment (really just a symbol table)
///
/// Interns symbols by name and turns reader output (trees of ATuple and
/// AString atoms) into AST nodes. It expands registered macros and dispatches
/// to registered parse handlers.
struct PEnv : private map<const string, ASymbol*> {
	PEnv() : symID(0) {}
	typedef AST* (*PF)(PEnv&, const AST*, void*); ///< Parse Function
	typedef AST* (*MF)(PEnv&, const AST*); ///< Macro Function
	struct Handler { Handler(PF f, void* a=0) : func(f), arg(a) {} PF func; void* arg; };
	typedef map<const string, Handler> Handlers;
	typedef map<const string, MF>      Macros;
	Handlers aHandlers; ///< Atom parse functions
	Handlers lHandlers; ///< List parse functions
	Macros   macros;    ///< Macro functions
	/// Register @a h for lists (@a list true) or atoms (@a list false) named @a s.
	void reg(bool list, const string& s, const Handler& h) {
		(list ? lHandlers : aHandlers).insert(make_pair(sym(s)->str(), h));
	}
	/// Handler registered for @a s, or NULL if there is none.
	const Handler* handler(bool list, const string& s) const {
		const Handlers& handlers = list ? lHandlers : aHandlers;
		Handlers::const_iterator i = handlers.find(s); // iterator type matches the map
		return (i != handlers.end()) ? &i->second : NULL;
	}
	/// Register macro @a f as @a s (an existing macro of that name is kept).
	void defmac(const string& s, const MF f) {
		macros.insert(make_pair(s, f));
	}
	/// Macro registered for @a s, or NULL if there is none.
	MF mac(const AString& s) const {
		Macros::const_iterator i = macros.find(s);
		return (i != macros.end()) ? i->second : NULL;
	}
	/// A fresh name: prefix @a s followed by a unique counter value.
	string gensymstr(const char* s="_") { return (format("%s%d") % s % symID++).str(); }
	ASymbol* gensym(const char* s="_") { return sym(gensymstr(s)); }
	/// The unique symbol named @a s, created (at location @a c) on first use.
	ASymbol* sym(const string& s, Cursor c=Cursor()) {
		const const_iterator i = find(s);
		if (i != end()) {
			return i->second;
		} else {
			ASymbol* sym = new ASymbol(s, c);
			insert(make_pair(s, sym));
			return sym;
		}
	}
	/// Parse every element of @a e into a new tuple.
	ATuple* parseTuple(const ATuple* e) {
		ATuple* ret = new ATuple(e->loc);
		FOREACH(ATuple::const_iterator, i, *e)
			ret->push_back(parse(*i));
		return ret;
	}
	/// Parse one reader expression.
	///
	/// Lists: a string head is first macro-expanded. The expansion is then
	/// handed to the list handler for its head if one exists; otherwise the
	/// expansion (not the original form) is parsed as a regular call.
	/// Atoms: numbers become literals, quoted text becomes an AString, names
	/// with an atom handler are dispatched, and anything else becomes a symbol.
	/// @throw Error on an empty list.
	AST* parse(const AST* exp) {
		const ATuple* tup = exp->to<const ATuple*>();
		if (tup) {
			if (tup->empty()) throw Error(exp->loc, "call to empty list");
			const AString* head = tup->front()->to<const AString*>();
			if (head) {
				MF            mf           = mac(*head);
				const AST*    expanded     = (mf ? mf(*this, exp) : exp);
				const ATuple* expanded_tup = expanded->to<const ATuple*>();
				if (!expanded_tup)
					return parse(expanded); // macro expanded to an atom
				if (expanded_tup->empty())
					throw Error(expanded->loc, "call to empty list");
				const AString* expanded_head = expanded_tup->front()->to<const AString*>();
				if (expanded_head) {
					const PEnv::Handler* h = handler(true, *expanded_head);
					if (h)
						return h->func(*this, expanded, h->arg);
				}
				tup = expanded_tup; // keep the macro expansion
			}
			ATuple* parsed_tup = parseTuple(tup); // FIXME: leak
			return new ACall(parsed_tup); // Parse as regular call
		}
		const AString* str = exp->to<const AString*>();
		assert(str);
		const std::string& s = *str;
		// Cast to unsigned char: isdigit() is undefined for negative char values.
		if (isdigit((unsigned char)s[0])) {
			if (s.find('.') == string::npos)
				return new ALiteral<int32_t>(strtol(s.c_str(), NULL, 10), exp->loc);
			else
				return new ALiteral<float>(strtod(s.c_str(), NULL), exp->loc);
		} else if (s[0] == '\"') {
			return new AString(exp->loc, str->substr(1, str->length() - 2));
		} else {
			const PEnv::Handler* h = handler(false, s);
			if (h)
				return h->func(*this, exp, h->arg);
		}
		return sym(s, exp->loc);
	}
	unsigned symID;
};


/***************************************************************************
 * Typing                                                                  *
 ***************************************************************************/

/// Type constraint: a pair of types handed to unify(), plus the source
/// location it arose from.
struct Constraint : public pair<AType*,AType*> {
	typedef pair<AType*,AType*> Pair;
	Constraint(AType* lhs, AType* rhs, Cursor where) : Pair(lhs, rhs), loc(where) {}
	Cursor loc;
};

/// Type constraint set (input to unify()); members are defined out of line.
struct Constraints : public list<Constraint> {
	// Record a constraint relating node o to type t (see the out-of-line definition).
	void constrain(TEnv& tenv, const AST* o, AType* t);
	// Presumably rewrites occurrences of s to t in the set — TODO confirm.
	void replace(AType* s, AType* t);
};

/// Debug dump of a constraint set, one "first : second" pair per line.
inline ostream& operator<<(ostream& out, const Constraints& c) {
	Constraints::const_iterator it = c.begin();
	while (it != c.end()) {
		out << it->first << " : " << it->second << endl;
		++it;
	}
	return out;
}

/// Type-Time Environment
///
/// Lexically scoped map from symbols to (definition, type) pairs, plus a
/// table (vars) holding the type variable assigned to each non-symbol node.
struct TEnv : public Env< const ASymbol*, pair<AST*, AType*> > {
	TEnv(PEnv& p) : penv(p), varID(1) {}
	/// Bind @a sym in the innermost scope to a new type variable and return it.
	AType* fresh(const ASymbol* sym) {
		AType* ret = new AType(sym->loc, varID++);
		def(sym, make_pair((AST*)NULL, ret));
		return ret;
	}
	/// Type variable for @a ast:
	///  - NULL (the default): a new anonymous type variable, not recorded;
	///  - a symbol: the type it is bound to;
	///  - any other node: its entry in vars, created on first request.
	/// @throw Error if @a ast is an unbound symbol.
	AType* var(const AST* ast=0) {
		if (!ast) // never call a member function through a NULL pointer
			return new AType(Cursor(), varID++);

		const ASymbol* sym = ast->to<const ASymbol*>();
		if (sym)
			return binding(sym)->second;

		Vars::iterator v = vars.find(ast);
		if (v != vars.end())
			return v->second;

		AType* ret = new AType(ast->loc, varID++);
		vars[ast] = ret;
		return ret;
	}
	/// Type bound to the symbol named @a name.
	/// @throw Error if that symbol is unbound.
	AType* named(const string& name) {
		return binding(penv.sym(name))->second;
	}
	/// Definition bound to @a ast if it is a bound symbol, otherwise @a ast.
	AST* resolve(AST* ast) {
		const ASymbol*            sym = ast->to<const ASymbol*>();
		const pair<AST*, AType*>* rec = sym ? ref(sym) : NULL;
		return rec ? rec->first : ast;
	}
	/// Const overload, with the same semantics (unbound symbols resolve to themselves).
	const AST* resolve(const AST* ast) {
		const ASymbol*            sym = ast->to<const ASymbol*>();
		const pair<AST*, AType*>* rec = sym ? ref(sym) : NULL;
		return rec ? rec->first : ast;
	}
	static Subst buildSubst(AType* fnT, const AType& argsT);

	typedef map<const AST*, AType*>       Vars;
	typedef map<const AFn*, const AType*> GenericTypes;
	Vars         vars;
	GenericTypes genericTypes;
	PEnv&        penv;
	unsigned     varID;

private:
	/// Binding record for @a sym; throws Error instead of returning NULL so
	/// callers never dereference a missing binding.
	pair<AST*, AType*>* binding(const ASymbol* sym) {
		pair<AST*, AType*>* rec = ref(sym);
		if (!rec)
			throw Error(sym->loc, "undefined symbol `" + sym->cppstr + "'");
		return rec;
	}
};

/// Solve constraint set @a c, returning the resulting substitution (defined out of line).
Subst unify(const Constraints& c);


/***************************************************************************
 * Code Generation                                                         *
 ***************************************************************************/

/// Compiler backend
///
/// Abstract code-generation interface. Concrete engines are obtained from
/// the factory functions declared after this class.
struct Engine {
	/// Virtual so an Engine can be deleted through an Engine* (e.g. one
	/// returned by a factory) without undefined behaviour.
	virtual ~Engine() {}

	virtual CFunc startFunction(
			CEnv&                cenv,
			const std::string&   name,
			const AType*         retT,
			const ATuple&        argsT,
			const vector<string> argNames=vector<string>()) = 0;

	virtual void  finishFunction(CEnv& cenv, CFunc f, const AType* retT, CVal ret) = 0;
	virtual void  eraseFunction(CEnv& cenv, CFunc f)                               = 0;
	virtual CFunc compileFunction(CEnv& cenv, AFn* fn, const AType& argsT)         = 0;
	virtual CVal  compileLiteral(CEnv& cenv, AST* lit)                             = 0;
	virtual CVal  compileCall(CEnv& cenv, CFunc f, const vector<CVal>& args)       = 0;
	virtual CVal  compilePrimitive(CEnv& cenv, APrimitive* prim)                   = 0;
	virtual CVal  compileIf(CEnv& cenv, AIf* aif)                                  = 0;
	virtual void  writeModule(CEnv& cenv, std::ostream& os)                        = 0;

	virtual const string call(CEnv& cenv, CFunc f, AType* retT) = 0;
};

// Backend factories (defined in the respective backend sources).
Engine* tuplr_new_llvm_engine(); ///< LLVM backend
Engine* tuplr_new_c_engine();    ///< C backend

/// Compile-Time Environment
///
/// Bundles the parse and type environments with the backend engine, the
/// compile-time value bound to each symbol (vals), and the global type
/// substitution (tsubst) applied when looking up types.
struct CEnv {
	CEnv(PEnv& p, TEnv& t, Engine* e, ostream& os=std::cout, ostream& es=std::cerr)
		: out(os), err(es), penv(p), tenv(t), _engine(e)
	{}

	/// Runs a collection of the global pool with an empty root set.
	~CEnv() { Object::pool.collect(GC::Roots()); }

	typedef Env<const ASymbol*, CVal> Vals;

	Engine* engine() { return _engine; }
	void    push() { tenv.push(); vals.push(); }
	void    pop()  { tenv.pop();  vals.pop();  }
	/// Make @a ast and its type GC roots (GC::addRoot asserts both are non-NULL).
	void    lock(AST* ast) { Object::pool.addRoot(ast); Object::pool.addRoot(type(ast)); }
	/// Type of @a ast: a symbol's bound type, or a node's type variable with
	/// @a subst and then tsubst applied. Returns NULL if @a ast has no recorded type.
	AType*  type(AST* ast, const Subst& subst = Subst()) const {
		ASymbol* sym = ast->to<ASymbol*>();
		if (sym) {
			const pair<AST*, AType*>* rec = tenv.ref(sym);
			return rec ? rec->second : NULL;
		}
		// Look up with find() rather than operator[]. operator[] would insert a
		// NULL entry for untyped nodes, which apply() then dereferenced and
		// TEnv::var would later return as the node's type.
		TEnv::Vars::const_iterator v = tenv.vars.find(ast);
		if (v == tenv.vars.end() || !v->second)
			return NULL;
		AST* t = tsubst.apply(subst.apply(v->second));
		return t ? t->to<AType*>() : NULL;
	}
	/// Bind @a sym to definition @a c and type @a t, and to compiled value @a v.
	void def(const ASymbol* sym, AST* c, AType* t, CVal v) {
		tenv.def(sym, make_pair(c, t));
		vals.def(sym, v);
	}

	ostream& out;
	ostream& err;
	PEnv&    penv;
	TEnv&    tenv;
	Vals     vals;
	Subst    tsubst;

	map<string,string> args;

private:
	Engine* _engine;
};


/***************************************************************************
 * EVAL/REPL/MAIN                                                          *
 ***************************************************************************/

/// Pretty-print @a ast to @a out.
void pprint(std::ostream& out, const AST* ast);
/// Register the language's parse handlers/macros and initial type bindings
/// (presumably; defined out of line — TODO confirm).
void initLang(PEnv& penv, TEnv& tenv);
/// Process the program read from @a is (named @a name), running it when
/// @a execute is true; the int result is presumably a process exit status.
int  eval(CEnv& cenv, const string& name, istream& is, bool execute);
/// Interactive read-eval-print loop; returns a status code (TODO confirm meaning).
int  repl(CEnv& cenv);

#endif // TUPLR_HPP