/* A Trivial LLVM LISP * Copyright (C) 2008-2009 David Robillard <dave@drobilla.net> * * Parts from the Kaleidoscope tutorial <http://llvm.org/docs/tutorial/> * by Chris Lattner and Erick Tryzelaar * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with This program. If not, see <http://www.gnu.org/licenses/>. */ #include <iostream> #include <list> #include <map> #include <stack> #include <string> #include <vector> #include <sstream> #include "llvm/Analysis/Verifier.h" #include "llvm/DerivedTypes.h" #include "llvm/ExecutionEngine/ExecutionEngine.h" #include "llvm/Module.h" #include "llvm/ModuleProvider.h" #include "llvm/PassManager.h" #include "llvm/Support/IRBuilder.h" #include "llvm/Target/TargetData.h" #include "llvm/Transforms/Scalar.h" using namespace llvm; using namespace std; /*************************************************************************** * S-Expression Lexer - Read text and output nested lists of strings * ***************************************************************************/ struct SExp { SExp() : type(LIST) {} SExp(const std::list<SExp>& l) : type(LIST), list(l) {} SExp(const std::string& s) : type(ATOM), atom(s) {} enum { ATOM, LIST } type; std::string atom; std::list<SExp> list; }; struct SyntaxError : public std::exception { SyntaxError(const char* m) : msg(m) {} const char* what() const throw() { return msg; } const char* msg; }; static SExp readExpression(std::istream& in) { stack<SExp> stk; string tok; #define APPEND_TOK() \ if (stk.empty()) return tok; else stk.top().list.push_back(SExp(tok)) while (char ch = in.get()) { switch (ch) { case EOF: return SExp(); case ' ': case '\t': case '\n': if (tok == "") continue; else APPEND_TOK(); tok = ""; break; case '"': do { tok.push_back(ch); } while ((ch = in.get()) != '"'); tok.push_back('"'); APPEND_TOK(); tok = ""; break; case '(': stk.push(SExp()); break; case ')': switch (stk.size()) { case 0: throw SyntaxError("Missing '('"); break; case 1: if (tok != "") stk.top().list.push_back(SExp(tok)); return stk.top(); default: if (tok != "") stk.top().list.push_back(SExp(tok)); SExp l = stk.top(); stk.pop(); stk.top().list.push_back(l); } tok = ""; break; default: tok.push_back(ch); } } switch (stk.size()) { case 0: return tok; break; case 1: return stk.top(); break; default: throw SyntaxError("Missing ')'"); } } /*************************************************************************** * Abstract Syntax Tree * ***************************************************************************/ struct CEnv; ///< Compile Time Environment /// Base class for all AST nodes struct AST { virtual ~AST() {} virtual Value* Codegen(CEnv& cenv) = 0; virtual bool evaluatable() const { return true; } }; /// Numeric literal, e.g. "1.0" struct ASTNumber : public AST { ASTNumber(double val) : _val(val) {} virtual Value* Codegen(CEnv& cenv); private: double _val; }; /// Symbol, e.g. "a" struct ASTSymbol : public AST { ASTSymbol(const string& name) : _name(name) {} virtual Value* Codegen(CEnv& cenv); private: string _name; }; /// Function call/application, e.g. "(func arg1 arg2)" struct ASTCall : public AST { ASTCall(const string& n, vector<AST*>& a) : _name(n), _args(a) {} virtual Value* Codegen(CEnv& cenv); protected: string _name; vector<AST*> _args; }; /// Definition special form, e.g. "(def x 2)" or "(def (next y) (+ y 1))" struct ASTDefinition : public ASTCall { ASTDefinition(const string& n, vector<AST*> a) : ASTCall(n, a) {} virtual Value* Codegen(CEnv& cenv); }; /// Conditional special form, e.g. "(if cond thenexp elseexp)" struct ASTIf : public ASTCall { ASTIf(const string& n, vector<AST*>& a) : ASTCall(n, a) {} virtual Value* Codegen(CEnv& cenv); }; /// Primitive (builtin arithmetic function) struct ASTPrimitive : public ASTCall { ASTPrimitive(const string& n, vector<AST*>& a) : ASTCall(n, a) {} virtual Value* Codegen(CEnv& cenv); }; /// Function prototype struct ASTPrototype : public AST { ASTPrototype(const string& n, const vector<string>& p=vector<string>()) : _name(n), _params(p) {} virtual bool evaluatable() const { return false; } Value* Codegen(CEnv& cenv) { return Funcgen(cenv); } Function* Funcgen(CEnv& cenv); private: string _name; vector<string> _params; }; /// Function definition struct ASTFunction : public AST { ASTFunction(ASTPrototype* p, AST* b) : _proto(p), _body(b) {} virtual bool evaluatable() const { return false; } Value* Codegen(CEnv& cenv) { return Funcgen(cenv); } Function* Funcgen(CEnv& cenv); private: ASTPrototype* _proto; AST* _body; }; /*************************************************************************** * Parser - Transform S-Expressions into AST nodes * ***************************************************************************/ static AST* parseExpression(const SExp& exp); /// numberexpr ::= number static AST* parseNumber(const SExp& exp) { assert(exp.type == SExp::ATOM); return new ASTNumber(strtod(exp.atom.c_str(), NULL)); } /// identifierexpr ::= identifier static AST* parseSymbol(const SExp& exp) { assert(exp.type == SExp::ATOM); return new ASTSymbol(exp.atom); } /// prototype ::= (name [arg*]) static ASTPrototype* parsePrototype(const SExp& exp) { list<SExp>::const_iterator i = exp.list.begin(); const string& name = i->atom; vector<string> args; for (++i; i != exp.list.end(); ++i) if (i->type == SExp::ATOM) args.push_back(i->atom); else throw SyntaxError("Expected parameter name, found list"); return new ASTPrototype(name, args); } /// callexpr ::= (expression [...]) static AST* parseCall(const SExp& exp) { if (exp.list.empty()) return NULL; list<SExp>::const_iterator i = exp.list.begin(); const string& name = i->atom; if (name == "def" && (++i)->type == SExp::LIST) { ASTPrototype* proto = parsePrototype(*i++); AST* body = parseExpression(*i++); return new ASTFunction(proto, body); } vector<AST*> args; for (++i; i != exp.list.end(); ++i) args.push_back(parseExpression(*i)); if (name.length() == 1) { switch (name[0]) { case '+': case '-': case '*': case '/': case '%': case '&': case '|': case '^': return new ASTPrimitive(name, args); } } else if (name == "if") { return new ASTIf(name, args); } else if (name == "def") { return new ASTDefinition(name, args); } else if (name == "foreign") { return parsePrototype(*++i++); } return new ASTCall(name, args); } static AST* parseExpression(const SExp& exp) { if (exp.type == SExp::LIST) { return parseCall(exp); } else if (isalpha(exp.atom[0])) { return parseSymbol(exp); } else if (isdigit(exp.atom[0])) { return parseNumber(exp); } else { throw SyntaxError("Illegal atom"); } } /*************************************************************************** * Code Generation * ***************************************************************************/ /// Compile-time environment struct CEnv { CEnv(Module* m, const TargetData* target) : module(m), provider(module), fpm(&provider), id(0) { // Set up the optimizer pipeline. // Register info about how the target lays out data structures. fpm.add(new TargetData(*target)); // Do simple "peephole" and bit-twiddling optimizations. fpm.add(createInstructionCombiningPass()); // Reassociate expressions. fpm.add(createReassociatePass()); // Eliminate Common SubExpressions. fpm.add(createGVNPass()); // Simplify control flow graph (delete unreachable blocks, etc). fpm.add(createCFGSimplificationPass()); } string gensym(const char* base="_") { ostringstream s; s << base << id++; return s.str(); } IRBuilder<> builder; Module* module; ExistingModuleProvider provider; FunctionPassManager fpm; map<string, Value*> env; size_t id; }; Value* ASTNumber::Codegen(CEnv& cenv) { return ConstantFP::get(APFloat(_val)); } Value* ASTSymbol::Codegen(CEnv& cenv) { map<string, Value*>::const_iterator v = cenv.env.find(_name); if (v == cenv.env.end()) throw SyntaxError((string("Undefined symbol '") + _name + "'").c_str()); return v->second; } Value* ASTDefinition::Codegen(CEnv& cenv) { map<string, Value*>::const_iterator v = cenv.env.find(_name); if (v != cenv.env.end()) throw SyntaxError("Symbol redefinition"); if (_args.empty()) throw SyntaxError("Empty definition"); Value* valCode = _args[0]->Codegen(cenv); cenv.env[_name] = valCode; return valCode; } Value* ASTCall::Codegen(CEnv& cenv) { Function* f = cenv.module->getFunction(_name); if (!f) throw SyntaxError("Undefined function"); if (f->arg_size() != _args.size()) throw SyntaxError("Illegal arguments"); vector<Value*> params; for (size_t i = 0; i != _args.size(); ++i) params.push_back(_args[i]->Codegen(cenv)); return cenv.builder.CreateCall(f, params.begin(), params.end(), "calltmp"); } Value* ASTIf::Codegen(CEnv& cenv) { Value* condV = _args[0]->Codegen(cenv); // Convert condition to a bool by comparing equal to 0.0. condV = cenv.builder.CreateFCmpONE( condV, ConstantFP::get(APFloat(0.0)), "ifcond"); Function* parent = cenv.builder.GetInsertBlock()->getParent(); // Create blocks for the then and else cases. // Insert the 'then' block at the end of the function. BasicBlock* thenBB = BasicBlock::Create("then", parent); BasicBlock* elseBB = BasicBlock::Create("else"); BasicBlock* mergeBB = BasicBlock::Create("ifcont"); cenv.builder.CreateCondBr(condV, thenBB, elseBB); // Emit then value. cenv.builder.SetInsertPoint(thenBB); Value* thenV = _args[1]->Codegen(cenv); cenv.builder.CreateBr(mergeBB); // Codegen of 'Then' can change the current block, update thenBB thenBB = cenv.builder.GetInsertBlock(); // Emit else block. parent->getBasicBlockList().push_back(elseBB); cenv.builder.SetInsertPoint(elseBB); Value* elseV = _args[2]->Codegen(cenv); cenv.builder.CreateBr(mergeBB); // Codegen of 'Else' can change the current block, update elseBB elseBB = cenv.builder.GetInsertBlock(); // Emit merge block. parent->getBasicBlockList().push_back(mergeBB); cenv.builder.SetInsertPoint(mergeBB); PHINode* pn = cenv.builder.CreatePHI(Type::DoubleTy, "iftmp"); pn->addIncoming(thenV, thenBB); pn->addIncoming(elseV, elseBB); return pn; } Function* ASTPrototype::Funcgen(CEnv& cenv) { // Make the function type, e.g. double(double,double) vector<const Type*> argsT(_params.size(), Type::DoubleTy); FunctionType* FT = FunctionType::get(Type::DoubleTy, argsT, false); Function* f = Function::Create( FT, Function::ExternalLinkage, _name, cenv.module); // If F conflicted, there was already something named 'Name'. // If it has a body, don't allow redefinition. if (f->getName() != _name) { // Delete the one we just made and get the existing one. f->eraseFromParent(); f = cenv.module->getFunction(_name); // If F already has a body, reject this. if (!f->empty()) throw SyntaxError("Function redefined"); // If F took a different number of args, reject. if (f->arg_size() != _params.size()) throw SyntaxError("Function redefined with mismatched arguments"); } Function::arg_iterator a = f->arg_begin(); for (size_t i = 0; i != _params.size(); ++a, ++i) { a->setName(_params[i]); // Set name in generated code cenv.env[_params[i]] = a; // Add to environment } return f; } Function* ASTFunction::Funcgen(CEnv& cenv) { Function* f = _proto->Funcgen(cenv); // Create a new basic block to start insertion into. BasicBlock* bb = BasicBlock::Create("entry", f); cenv.builder.SetInsertPoint(bb); try { Value* retVal = _body->Codegen(cenv); cenv.builder.CreateRet(retVal); // Finish function verifyFunction(*f); // Validate generated code cenv.fpm.run(*f); // Optimize function return f; } catch (SyntaxError e) { f->eraseFromParent(); // Error reading body, remove function throw e; } return 0; // Never reached } Value* ASTPrimitive::Codegen(CEnv& cenv) { Instruction::BinaryOps op; assert(_name.length() == 1); switch (_name[0]) { case '+': op = Instruction::Add; break; case '-': op = Instruction::Sub; break; case '*': op = Instruction::Mul; break; case '/': op = Instruction::FDiv; break; case '%': op = Instruction::FRem; break; case '&': op = Instruction::And; break; case '|': op = Instruction::Or; break; case '^': op = Instruction::Xor; break; default: throw SyntaxError("Unknown primitive"); } vector<Value*> params; for (vector<AST*>::const_iterator a = _args.begin(); a != _args.end(); ++a) params.push_back((*a)->Codegen(cenv)); switch (params.size()) { case 0: throw SyntaxError("Primitive expects at least 1 argument"); case 1: return params[0]; default: Value* val = cenv.builder.CreateBinOp(op, params[0], params[1]); for (size_t i = 2; i < params.size(); ++i) val = cenv.builder.CreateBinOp(op, val, params[i]); return val; } } /*************************************************************************** * REPL - Interactively compile, optimise, and execute code * ***************************************************************************/ /// Read-Eval-Print-Loop static void repl(CEnv& cenv, ExecutionEngine* engine) { while (1) { std::cout << "> "; std::cout.flush(); SExp exp = readExpression(std::cin); if (exp.type == SExp::LIST && exp.list.empty()) break; try { AST* ast = parseExpression(exp); if (!ast) continue; if (ast->evaluatable()) { ASTPrototype* proto = new ASTPrototype(cenv.gensym("repl")); ASTFunction* func = new ASTFunction(proto, ast); Function* code = func->Funcgen(cenv); void* fp = engine->getPointerToFunction(code); double (*f)() = (double (*)())fp; std::cout << f() << endl; //code->eraseFromParent(); } else { Value* code = ast->Codegen(cenv); std::cout << "Generated code:" << endl; code->dump(); } } catch (SyntaxError e) { std::cerr << "Syntax error: " << e.what() << endl; } } } /*************************************************************************** * Main driver code. ***************************************************************************/ int main() { Module* module = new Module("interactive"); ExecutionEngine* engine = ExecutionEngine::create(module); CEnv cenv(module, engine->getTargetData()); repl(cenv, engine); std::cout << "Generated code:" << endl; module->dump(); return 0; }