/* A Trivial LLVM LISP
 * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net>
 *
 * Parts from the Kaleidoscope tutorial <http://llvm.org/docs/tutorial/>
 * by Chris Lattner and Erick Tryzelaar
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with This program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <iostream>
#include <list>
#include <map>
#include <stack>
#include <string>
#include <vector>
#include <sstream>
#include "llvm/Analysis/Verifier.h"
#include "llvm/DerivedTypes.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/Module.h"
#include "llvm/ModuleProvider.h"
#include "llvm/PassManager.h"
#include "llvm/Support/IRBuilder.h"
#include "llvm/Target/TargetData.h"
#include "llvm/Transforms/Scalar.h"

using namespace llvm;
using namespace std;


/***************************************************************************
 * S-Expression Lexer - Read text and output nested lists of strings       *
 ***************************************************************************/

/// A node of a lexed S-expression: either a single atom or a list of nodes.
struct SExp {
	/// Which of the two payload members below is meaningful
	enum { ATOM, LIST } type;

	std::string     atom; ///< Token text (used when type is ATOM)
	std::list<SExp> list; ///< Child expressions (used when type is LIST)

	/// Construct an empty list
	SExp()
		: type(LIST)
	{}

	/// Construct a list holding a copy of @a l
	SExp(const std::list<SExp>& l)
		: type(LIST)
		, list(l)
	{}

	/// Construct an atom holding a copy of @a s
	SExp(const std::string& s)
		: type(ATOM)
		, atom(s)
	{}
};

/// Exception thrown for malformed or semantically invalid input.
///
/// The message text is copied into the exception, so callers may pass a
/// pointer into a temporary buffer (e.g. the c_str() of a temporary string)
/// without what() dangling once that buffer has been destroyed.
struct SyntaxError : public std::exception {
	/// @param m Human readable description; a null pointer yields an empty message
	SyntaxError(const char* m) : msg(m ? m : "") {}

	// Declared explicitly so the destructor matches the non-throwing
	// destructor of std::exception
	~SyntaxError() throw() {}

	/// @return The message given at construction (valid while this object lives)
	const char* what() const throw() { return msg.c_str(); }

	std::string msg; ///< Owned copy of the message text
};

/// Read a single S-expression from @a in.
///
/// Atoms are runs of characters delimited by whitespace or parentheses.  A
/// double-quoted run is kept in one atom, including both quote characters.
/// Lists are delimited by '(' and ')' and may nest.
///
/// @return The expression read.  At end of input a pending top-level atom is
///         returned; otherwise an empty list is returned at end of input.
/// @throw SyntaxError on an unmatched ')', on a string that is still open at
///        end of input, or on lists left open when a NUL character ends the
///        input.
static SExp
readExpression(std::istream& in)
{
	stack<SExp> stk; // Lists that are currently open, innermost on top
	string      tok; // Atom currently being accumulated

	for (;;) {
		// Keep the result as an int: narrowing to char before testing makes
		// EOF indistinguishable from a 0xFF byte, or undetectable where
		// char is unsigned
		const int c = in.get();
		if (c == EOF) {
			// A top-level atom that ends exactly at end of input is still
			// an atom; the next call reports the end of input
			if (stk.empty() && !tok.empty())
				return tok;
			return SExp();
		}
		if (c == 0)
			break; // A NUL character ends the input early

		const char ch = static_cast<char>(c);
		switch (ch) {
		case ' ': case '\t': case '\n': case '\r':
			if (tok.empty())
				continue;
			if (stk.empty())
				return tok;
			stk.top().list.push_back(SExp(tok));
			tok.clear();
			break;
		case '"': {
			// Copy everything up to and including the closing quote
			tok.push_back(ch);
			int s;
			while ((s = in.get()) != '"') {
				if (s == EOF)
					throw SyntaxError("Unterminated string");
				tok.push_back(static_cast<char>(s));
			}
			tok.push_back('"');
			if (stk.empty())
				return tok;
			stk.top().list.push_back(SExp(tok));
			tok.clear();
			break;
		}
		case '(':
			// An atom directly followed by '(' ends here rather than
			// running on into the first atom of the new list
			if (!tok.empty()) {
				if (stk.empty()) {
					in.unget(); // Leave the '(' for the next call
					return tok;
				}
				stk.top().list.push_back(SExp(tok));
				tok.clear();
			}
			stk.push(SExp());
			break;
		case ')': {
			if (stk.empty())
				throw SyntaxError("Missing '('");
			if (!tok.empty()) {
				stk.top().list.push_back(SExp(tok));
				tok.clear();
			}
			if (stk.size() == 1)
				return stk.top();

			// Close a nested list and append it to its parent
			const SExp closed = stk.top();
			stk.pop();
			stk.top().list.push_back(closed);
			break;
		}
		default:
			tok.push_back(ch);
		}
	}

	// Only reached when a NUL character cut the input short
	switch (stk.size()) {
	case 0:
		return tok;
	case 1:
		if (!tok.empty())
			stk.top().list.push_back(SExp(tok));
		return stk.top();
	default:
		throw SyntaxError("Missing ')'");
	}
}



/***************************************************************************
 * Abstract Syntax Tree                                                    *
 ***************************************************************************/

struct CEnv; ///< Compile Time Environment

/// Abstract root of the AST node hierarchy
struct AST {
	virtual ~AST() {}

	/// Generate code for this node in the given compile time environment
	virtual Value* Codegen(CEnv& cenv) = 0;

	/// Whether the REPL should wrap this node in an anonymous function and
	/// run it (true unless a subclass overrides this)
	virtual bool evaluatable() const
	{
		return true;
	}
};

/// Floating point literal, e.g. "1.0"
struct ASTNumber : public AST {
private:
	double _val; ///< Value of the literal

public:
	ASTNumber(double val)
		: _val(val)
	{}

	/// Generate a floating point constant holding the literal's value
	virtual Value* Codegen(CEnv& cenv);
};

/// Reference to a named value, e.g. "a"
struct ASTSymbol : public AST {
private:
	string _name; ///< Name looked up in the compile time environment

public:
	ASTSymbol(const string& name)
		: _name(name)
	{}

	/// Look the symbol up, throwing SyntaxError if it is undefined
	virtual Value* Codegen(CEnv& cenv);
};

/// Function call/application, e.g. "(func arg1 arg2)"
///
/// Also serves as the base of the special forms, which reuse the name and
/// argument storage but generate different code.
struct ASTCall : public AST {
	/// @param n Name of the function being called
	/// @param a Argument expressions; the vector of pointers is copied, so
	///          it is taken by const reference and may be a temporary
	ASTCall(const string& n, const vector<AST*>& a) : _name(n), _args(a) {}

	/// Generate a call to the function named _name in the module
	virtual Value* Codegen(CEnv& cenv);

protected:
	string       _name; ///< Head of the form
	vector<AST*> _args; ///< Operand expressions, in source order
};

/// Value definition special form, e.g. "(def x 2)"
///
/// Codegen binds _name to the value of the first argument.  The function
/// form "(def (next y) (+ y 1))" is parsed into an ASTFunction instead.
struct ASTDefinition : public ASTCall {
	ASTDefinition(const string& n, vector<AST*> a)
		: ASTCall(n, a)
	{}

	virtual Value* Codegen(CEnv& cenv);
};

/// Conditional special form, e.g. "(if cond thenexp elseexp)"
///
/// The three operands are stored, in that order, in the inherited _args.
struct ASTIf : public ASTCall {
	ASTIf(const string& n, vector<AST*>& a)
		: ASTCall(n, a)
	{}

	virtual Value* Codegen(CEnv& cenv);
};

/// Builtin operator applied to its operands, e.g. "(+ 1 2 3)"
///
/// The inherited _name holds the single character operator.
struct ASTPrimitive : public ASTCall {
	ASTPrimitive(const string& n, vector<AST*>& a)
		: ASTCall(n, a)
	{}

	virtual Value* Codegen(CEnv& cenv);
};

/// Function prototype: a function name plus the names of its parameters
struct ASTPrototype : public AST {
	ASTPrototype(const string& n, const vector<string>& p=vector<string>())
		: _name(n)
		, _params(p)
	{}

	/// A prototype is not run by the REPL, only compiled
	virtual bool evaluatable() const
	{
		return false;
	}

	/// Same as Funcgen(), with the result viewed as a plain Value
	Value* Codegen(CEnv& cenv)
	{
		return Funcgen(cenv);
	}

	/// Declare the function in the module and return it
	Function* Funcgen(CEnv& cenv);

private:
	string         _name;   ///< Function name
	vector<string> _params; ///< Parameter names, in order
};

/// Function definition: a prototype together with a body expression
struct ASTFunction : public AST {
	ASTFunction(ASTPrototype* p, AST* b)
		: _proto(p)
		, _body(b)
	{}

	/// A definition is not run by the REPL, only compiled
	virtual bool evaluatable() const
	{
		return false;
	}

	/// Same as Funcgen(), with the result viewed as a plain Value
	Value* Codegen(CEnv& cenv)
	{
		return Funcgen(cenv);
	}

	/// Compile the function and return it
	Function* Funcgen(CEnv& cenv);

private:
	ASTPrototype* _proto; ///< Name and parameters
	AST*          _body;  ///< Expression whose value is returned
};



/***************************************************************************
 * Parser - Transform S-Expressions into AST nodes                         *
 ***************************************************************************/

static AST* parseExpression(const SExp& exp);

/// numberexpr ::= number
///
/// Convert an atom to a numeric literal node using strtod.  No end pointer
/// is passed, so how much of the text was consumed is not checked.
static AST*
parseNumber(const SExp& exp)
{
	assert(exp.type == SExp::ATOM);

	const char* const text  = exp.atom.c_str();
	const double      value = strtod(text, NULL);
	return new ASTNumber(value);
}

/// identifierexpr ::= identifier
///
/// Wrap the text of an atom in a new symbol node.
static AST*
parseSymbol(const SExp& exp)
{
	assert(exp.type == SExp::ATOM);

	const string& identifier = exp.atom;
	return new ASTSymbol(identifier);
}

/// prototype ::= (name [arg*])
///
/// @param exp List whose first element is the function name and whose
///            remaining elements are the parameter names.
/// @return A newly allocated prototype node.
/// @throw SyntaxError if @a exp is not a non-empty list, or if the name or
///        any parameter is a list rather than an atom.
static ASTPrototype*
parsePrototype(const SExp& exp)
{
	// Without this check the first element would be read from an empty list
	if (exp.type != SExp::LIST || exp.list.empty())
		throw SyntaxError("Expected prototype of the form (name [arg*])");

	list<SExp>::const_iterator i = exp.list.begin();
	if (i->type != SExp::ATOM)
		throw SyntaxError("Expected function name, found list");
	const string& name = i->atom;

	vector<string> args;
	for (++i; i != exp.list.end(); ++i) {
		if (i->type != SExp::ATOM)
			throw SyntaxError("Expected parameter name, found list");
		args.push_back(i->atom);
	}

	return new ASTPrototype(name, args);
}

/// callexpr ::= (expression [...])
///
/// Parse a list expression into the matching node:
///  - (def (name [param*]) body)   function definition (ASTFunction)
///  - (def name value)             value definition (ASTDefinition for name)
///  - (foreign (name [param*]))    prototype only (ASTPrototype)
///  - (if cond thenexp elseexp)    conditional (ASTIf)
///  - (op arg...)                  op one of + - * / % & | ^ (ASTPrimitive)
///  - (name arg...)                anything else, a call (ASTCall)
///
/// @return The new node, or NULL if @a exp is an empty list.
/// @throw SyntaxError if the head of the list is itself a list, if a special
///        form is missing a required part, or if an operand is an empty list.
static AST*
parseCall(const SExp& exp)
{
	if (exp.list.empty())
		return NULL;

	const list<SExp>::const_iterator end = exp.list.end();
	list<SExp>::const_iterator       i   = exp.list.begin();
	if (i->type != SExp::ATOM)
		throw SyntaxError("Expected function name, found list");

	const string& name = i->atom;
	++i; // Step past the head; i is now the first operand, or end

	const string* symbol = NULL; // Name bound by (def name value), if any
	if (name == "def") {
		if (i == end)
			throw SyntaxError("Empty definition");

		if (i->type == SExp::LIST) {
			// Function definition: (def (name [param*]) body)
			const SExp& protoExp = *i;
			if (protoExp.list.empty())
				throw SyntaxError("Expected function name in definition");
			if (++i == end || (i->type == SExp::LIST && i->list.empty()))
				throw SyntaxError("Function definition has no body");

			ASTPrototype* proto = parsePrototype(protoExp);
			AST*          body  = parseExpression(*i);
			return new ASTFunction(proto, body);
		}

		// Value definition: remember the symbol so the node is named after
		// it, rather than after the "def" keyword
		symbol = &i->atom;
		++i;
	} else if (name == "foreign") {
		// Handled before operands are parsed: the operand is a prototype,
		// not an expression to evaluate
		if (i == end || i->type != SExp::LIST || i->list.empty())
			throw SyntaxError("Expected prototype after 'foreign'");
		return parsePrototype(*i);
	}

	vector<AST*> args;
	for (; i != end; ++i) {
		AST* arg = parseExpression(*i);
		if (!arg) // An empty list parses to NULL, which cannot be compiled
			throw SyntaxError("Empty expression used as an argument");
		args.push_back(arg);
	}

	if (symbol)
		return new ASTDefinition(*symbol, args);

	if (name == "if")
		return new ASTIf(name, args);

	if (name.length() == 1) {
		switch (name[0]) {
		case '+': case '-': case '*': case '/':
		case '%': case '&': case '|': case '^':
			return new ASTPrimitive(name, args);
		}
	}

	return new ASTCall(name, args);
}

/// expression ::= callexpr | identifierexpr | numberexpr
///
/// Dispatch on the kind of expression: lists are parsed as calls, atoms that
/// start with a letter as symbols, and atoms that start with a digit (or
/// with '-' followed by a digit) as numbers.
///
/// @return The new node, or NULL if @a exp is an empty list.
/// @throw SyntaxError for any other atom, including an empty one.
static AST*
parseExpression(const SExp& exp)
{
	if (exp.type == SExp::LIST)
		return parseCall(exp);

	const string& atom = exp.atom;
	if (atom.empty())
		throw SyntaxError("Illegal atom");

	// The character classification functions require a value representable
	// as unsigned char; a plain char may be negative
	const unsigned char first = static_cast<unsigned char>(atom[0]);
	if (isalpha(first))
		return parseSymbol(exp);
	if (isdigit(first))
		return parseNumber(exp);

	// Negative literal such as "-1"
	if (first == '-' && atom.length() > 1
	    && isdigit(static_cast<unsigned char>(atom[1])))
		return parseNumber(exp);

	throw SyntaxError("Illegal atom");
}


/***************************************************************************
 * Code Generation                                                         *
 ***************************************************************************/

/// Compile-time environment: everything code generation needs to share
struct CEnv {
	IRBuilder<>            builder;  ///< Emits instructions at the current insert point
	Module*                module;   ///< Module that receives generated functions
	ExistingModuleProvider provider; ///< Provider wrapping module, used by fpm
	FunctionPassManager    fpm;      ///< Per-function optimisation pipeline
	map<string, Value*>    env;      ///< Symbol name to generated value
	size_t                 id;       ///< Counter used by gensym()

	/// @param m      Module to generate code into
	/// @param target Data layout description, copied into the pass pipeline
	CEnv(Module* m, const TargetData* target)
		: module(m)
		, provider(module)
		, fpm(&provider)
		, id(0)
	{
		// Build the optimizer pipeline, in the order the passes are to run
		fpm.add(new TargetData(*target));          // Target data layout info
		fpm.add(createInstructionCombiningPass()); // Peephole and bit-twiddling
		fpm.add(createReassociatePass());          // Reassociate expressions
		fpm.add(createGVNPass());                  // Common subexpressions
		fpm.add(createCFGSimplificationPass());    // Unreachable blocks, etc.
	}

	/// Return @a base followed by a number that increases on every call
	string gensym(const char* base="_") {
		ostringstream out;
		out << base << id++;
		return out.str();
	}
};

/// Generate a floating point constant holding this literal's value.
Value*
ASTNumber::Codegen(CEnv& cenv)
{
	const APFloat literal(_val);
	return ConstantFP::get(literal);
}

/// Return the value bound to this symbol in the compile time environment.
///
/// @throw SyntaxError if nothing is bound to the name.
Value*
ASTSymbol::Codegen(CEnv& cenv)
{
	const map<string, Value*>::const_iterator found = cenv.env.find(_name);
	if (found != cenv.env.end())
		return found->second;

	// NOTE(review): the message lives in a temporary string, so this is only
	// safe if SyntaxError copies the text it is given - confirm
	throw SyntaxError((string("Undefined symbol '") + _name + "'").c_str());
}

/// Generate code for the first argument and bind the result to _name.
///
/// @return The value that was bound.
/// @throw SyntaxError if _name is already bound, or if there is no argument.
Value*
ASTDefinition::Codegen(CEnv& cenv)
{
	if (cenv.env.find(_name) != cenv.env.end())
		throw SyntaxError("Symbol redefinition");

	if (_args.empty())
		throw SyntaxError("Empty definition");

	Value* const value = _args[0]->Codegen(cenv);
	cenv.env[_name] = value;
	return value;
}

/// Generate a call to the module function named _name.
///
/// Arguments are generated in source order before the call is emitted.
///
/// @throw SyntaxError if no such function exists in the module, or if the
///        number of arguments differs from the function's parameter count.
Value*
ASTCall::Codegen(CEnv& cenv)
{
	Function* const callee = cenv.module->getFunction(_name);
	if (!callee)
		throw SyntaxError("Undefined function");
	if (callee->arg_size() != _args.size())
		throw SyntaxError("Illegal arguments");

	vector<Value*> operands;
	for (vector<AST*>::const_iterator a = _args.begin(); a != _args.end(); ++a)
		operands.push_back((*a)->Codegen(cenv));

	return cenv.builder.CreateCall(
			callee, operands.begin(), operands.end(), "calltmp");
}

/// Generate code for "(if cond thenexp elseexp)".
///
/// Emits a conditional branch on the condition, one block for each arm, and
/// a merge block whose PHI node selects the value of the arm that ran.
///
/// @return The PHI node holding the value of the conditional.
/// @throw SyntaxError if the form does not have exactly three operands.
Value*
ASTIf::Codegen(CEnv& cenv)
{
	// The operands are accessed by index below, so check the count first
	if (_args.size() != 3)
		throw SyntaxError("'if' expects a condition and two expressions");

	Value* condV = _args[0]->Codegen(cenv);

	// Convert condition to a bool by comparing not-equal to 0.0.
	condV = cenv.builder.CreateFCmpONE(
			condV, ConstantFP::get(APFloat(0.0)), "ifcond");

	Function* parent = cenv.builder.GetInsertBlock()->getParent();

	// Create blocks for the then and else cases.
	// Insert the 'then' block at the end of the function.
	BasicBlock* thenBB = BasicBlock::Create("then", parent);
	BasicBlock* elseBB = BasicBlock::Create("else");
	BasicBlock* mergeBB = BasicBlock::Create("ifcont");

	cenv.builder.CreateCondBr(condV, thenBB, elseBB);

	// Emit then value.
	cenv.builder.SetInsertPoint(thenBB);
	Value* thenV = _args[1]->Codegen(cenv);

	cenv.builder.CreateBr(mergeBB);
	// Codegen of 'Then' can change the current block, update thenBB
	thenBB = cenv.builder.GetInsertBlock();

	// Emit else block.
	parent->getBasicBlockList().push_back(elseBB);
	cenv.builder.SetInsertPoint(elseBB);
	Value* elseV = _args[2]->Codegen(cenv);

	cenv.builder.CreateBr(mergeBB);
	// Codegen of 'Else' can change the current block, update elseBB
	elseBB = cenv.builder.GetInsertBlock();

	// Emit merge block.
	parent->getBasicBlockList().push_back(mergeBB);
	cenv.builder.SetInsertPoint(mergeBB);
	PHINode* pn = cenv.builder.CreatePHI(Type::DoubleTy, "iftmp");

	// The incoming blocks are the ones each arm ended in, not started in
	pn->addIncoming(thenV, thenBB);
	pn->addIncoming(elseV, elseBB);
	return pn;
}

/// Declare this prototype's function in the module.
///
/// Every parameter and the return value are doubles.  If a function of the
/// same name already exists without a body and with the same number of
/// parameters, that function is reused.  Each parameter is named in the
/// generated code and bound in the compile time environment.
///
/// @return The declared (or reused) function.
/// @throw SyntaxError if the existing function already has a body, or takes
///        a different number of parameters.
Function*
ASTPrototype::Funcgen(CEnv& cenv)
{
	// Build the function type, e.g. double(double,double)
	const vector<const Type*> paramTypes(_params.size(), Type::DoubleTy);
	FunctionType* const signature =
		FunctionType::get(Type::DoubleTy, paramTypes, false);

	Function* func = Function::Create(
			signature, Function::ExternalLinkage, _name, cenv.module);

	// A different resulting name means _name was already taken
	if (func->getName() != _name) {
		// Drop the function just created and use the existing one instead
		func->eraseFromParent();
		func = cenv.module->getFunction(_name);

		// An existing body means this would be a redefinition
		if (!func->empty())
			throw SyntaxError("Function redefined");

		// The existing declaration must agree on the parameter count
		if (func->arg_size() != _params.size())
			throw SyntaxError("Function redefined with mismatched arguments");
	}

	Function::arg_iterator arg = func->arg_begin();
	for (vector<string>::const_iterator p = _params.begin();
	     p != _params.end(); ++p, ++arg) {
		arg->setName(*p);   // Name used in the generated code
		cenv.env[*p] = arg; // Make the parameter visible to the body
	}

	return func;
}

/// Compile this function definition.
///
/// Declares the prototype, generates the body into a fresh entry block with
/// the body's value as the return value, then verifies and optimises the
/// resulting function.
///
/// @return The compiled function.
/// @throw SyntaxError if generating the body fails; the partially built
///        function is removed from the module before the error propagates.
Function*
ASTFunction::Funcgen(CEnv& cenv)
{
	Function* f = _proto->Funcgen(cenv);

	// Create a new basic block to start insertion into.
	BasicBlock* bb = BasicBlock::Create("entry", f);
	cenv.builder.SetInsertPoint(bb);

	try {
		Value* retVal = _body->Codegen(cenv);
		cenv.builder.CreateRet(retVal); // Finish function
		verifyFunction(*f); // Validate generated code
		cenv.fpm.run(*f); // Optimize function
		return f;
	} catch (const SyntaxError&) {
		f->eraseFromParent(); // Error reading body, remove function
		throw; // Rethrow the original exception rather than a copy
	}

	return 0; // Never reached
}

/// Generate code for a builtin operator applied to its operands.
///
/// With one operand the operand's value is returned unchanged.  With more,
/// the operator is applied from left to right: (op a b c) is ((a op b) op c).
///
/// NOTE(review): '&', '|' and '^' map to And/Or/Xor while the values
/// generated elsewhere in this file are doubles - confirm these are usable.
///
/// @throw SyntaxError for an unknown operator or an empty operand list.
Value*
ASTPrimitive::Codegen(CEnv& cenv)
{
	assert(_name.length() == 1);

	Instruction::BinaryOps op;
	switch (_name[0]) {
	case '+':
		op = Instruction::Add;
		break;
	case '-':
		op = Instruction::Sub;
		break;
	case '*':
		op = Instruction::Mul;
		break;
	case '/':
		op = Instruction::FDiv;
		break;
	case '%':
		op = Instruction::FRem;
		break;
	case '&':
		op = Instruction::And;
		break;
	case '|':
		op = Instruction::Or;
		break;
	case '^':
		op = Instruction::Xor;
		break;
	default:
		throw SyntaxError("Unknown primitive");
	}

	// Generate every operand first, in source order
	vector<Value*> operands;
	for (size_t i = 0; i != _args.size(); ++i)
		operands.push_back(_args[i]->Codegen(cenv));

	if (operands.empty())
		throw SyntaxError("Primitive expects at least 1 argument");

	// Fold the remaining operands into the first, left to right
	Value* result = operands[0];
	for (size_t i = 1; i < operands.size(); ++i)
		result = cenv.builder.CreateBinOp(op, result, operands[i]);

	return result;
}


/***************************************************************************
 * REPL - Interactively compile, optimise, and execute code                *
 ***************************************************************************/

/// Read-Eval-Print-Loop
///
/// Repeatedly prompts, reads one expression from standard input and parses
/// it.  Evaluatable expressions are wrapped in a uniquely named function,
/// compiled and run, and the result is printed; other nodes are compiled and
/// their generated code is dumped.  The loop ends when the reader returns an
/// empty list.  Syntax errors are reported on standard error and the loop
/// continues.
///
/// @param cenv   Compile time environment to generate code into
/// @param engine Execution engine used to obtain callable code
static void
repl(CEnv& cenv, ExecutionEngine* engine)
{
	while (1) {
		std::cout << "> ";
		std::cout.flush();

		try {
			// Read inside the try block: the reader reports malformed
			// input (e.g. a stray ')') by throwing SyntaxError as well
			SExp exp = readExpression(std::cin);
			if (exp.type == SExp::LIST && exp.list.empty())
				break;

			AST* ast = parseExpression(exp);
			if (!ast)
				continue;
			if (ast->evaluatable()) {
				ASTPrototype* proto = new ASTPrototype(cenv.gensym("repl"));
				ASTFunction*  func  = new ASTFunction(proto, ast);
				Function*     code  = func->Funcgen(cenv);
				void*         fp    = engine->getPointerToFunction(code);
				double (*f)() = (double (*)())fp;
				std::cout << f() << endl;
				//code->eraseFromParent();
			} else {
				Value* code = ast->Codegen(cenv);
				std::cout << "Generated code:" << endl;
				code->dump();
			}
		} catch (const SyntaxError& e) {
			std::cerr << "Syntax error: " << e.what() << endl;
		}
	}
}


/***************************************************************************
 * Main driver code.
 ***************************************************************************/

/// Create the module, execution engine and compile time environment, run the
/// REPL until input ends, then dump all generated code.
///
/// @return 0 on success, 1 if no execution engine could be created.
int
main()
{
	Module*          module = new Module("interactive");
	ExecutionEngine* engine = ExecutionEngine::create(module);
	if (!engine) {
		// Nothing can be compiled or run without an engine, and its target
		// data is needed below
		std::cerr << "Failed to create execution engine" << endl;
		return 1;
	}

	CEnv cenv(module, engine->getTargetData());

	repl(cenv, engine);

	std::cout << "Generated code:" << endl;
	module->dump();
	return 0;
}