/* A Trivial LLVM LISP * Copyright (C) 2008-2009 David Robillard * * Parts from the Kaleidoscope tutorial * by Chris Lattner and Erick Tryzelaar * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with This program. If not, see . */ #include #include #include #include #include #include #include #include "llvm/Analysis/Verifier.h" #include "llvm/DerivedTypes.h" #include "llvm/ExecutionEngine/ExecutionEngine.h" #include "llvm/Module.h" #include "llvm/ModuleProvider.h" #include "llvm/PassManager.h" #include "llvm/Support/IRBuilder.h" #include "llvm/Target/TargetData.h" #include "llvm/Transforms/Scalar.h" #define FOREACH(IT, i, c) for (IT i = (c).begin(); i != (c).end(); ++i) using namespace llvm; using namespace std; struct Error : public std::exception { Error(const char* m) : msg(m) {} const char* what() const throw() { return msg; } const char* msg; }; /*************************************************************************** * S-Expression Lexer :: text -> S-Expressions (SExp) * ***************************************************************************/ struct SyntaxError : public Error { SyntaxError(const char* m) : Error(m) {} }; struct SExp { SExp() : type(LIST) {} SExp(const std::list& l) : type(LIST), list(l) {} SExp(const std::string& s) : type(ATOM), atom(s) {} enum { ATOM, LIST } type; std::string atom; std::list list; }; static SExp readExpression(std::istream& in) { stack stk; string tok; #define APPEND_TOK() \ if (stk.empty()) return tok; else stk.top().list.push_back(SExp(tok)) while (char ch = in.get()) { switch (ch) { case EOF: return SExp(); case ' ': case '\t': case '\n': if (tok == "") continue; else APPEND_TOK(); tok = ""; break; case '"': do { tok.push_back(ch); } while ((ch = in.get()) != '"'); tok.push_back('"'); APPEND_TOK(); tok = ""; break; case '(': stk.push(SExp()); break; case ')': switch (stk.size()) { case 0: throw SyntaxError("Missing '('"); break; case 1: if (tok != "") stk.top().list.push_back(SExp(tok)); return stk.top(); default: if (tok != "") stk.top().list.push_back(SExp(tok)); SExp l = stk.top(); stk.pop(); stk.top().list.push_back(l); } tok = ""; break; default: tok.push_back(ch); } } switch (stk.size()) { case 0: return tok; break; case 1: return stk.top(); break; default: throw SyntaxError("Missing ')'"); } return SExp(); } /*************************************************************************** * Environment * ***************************************************************************/ class AST; class ASTSymbol; /// Generic Recursive Environment (stack of key:value dictionaries) template struct Env : public list< map > { Env() : list< map >(1) {} void push() { this->push_front(map()); } void push(const map& frame) { this->push_front(frame); } map& pop() { map& front = this->front(); this->pop_front(); return front; } void def(const K& k, const V& v) { if (this->front().find(k) != this->front().end()) throw SyntaxError("Redefinition"); this->front()[k] = v; } V* ref(const K& name) { typename Env::iterator i = this->begin(); for (; i != this->end(); ++i) { typename map::iterator s = i->find(name); if (s != i->end()) return &s->second; } return 0; } }; class PEnv; /// Compile-time environment struct CEnv { CEnv(PEnv& p, Module* m, const TargetData* target) : penv(p), module(m), emp(module), fpm(&emp), symID(0), tID(0) { // Set up the optimizer pipeline. // Register info about how the target lays out data structures. fpm.add(new TargetData(*target)); // Do simple "peephole" and bit-twiddling optimizations. fpm.add(createInstructionCombiningPass()); // Reassociate expressions. fpm.add(createReassociatePass()); // Eliminate Common SubExpressions. fpm.add(createGVNPass()); // Simplify control flow graph (delete unreachable blocks, etc). fpm.add(createCFGSimplificationPass()); } string gensym(const char* base="_") { ostringstream s; s << base << symID++; return s.str(); } typedef Env Code; typedef Env Vals; PEnv& penv; IRBuilder<> builder; Module* module; ExistingModuleProvider emp; FunctionPassManager fpm; unsigned symID; unsigned tID; Code code; Vals vals; }; /*************************************************************************** * Abstract Syntax Tree * ***************************************************************************/ struct AType; struct TypeError : public Error { TypeError (const char* m) : Error(m) {} }; struct CEnv; ///< Compile Time Environment /// Base class for all AST nodes struct AST { virtual ~AST() {} virtual string str(CEnv& cenv) const = 0; virtual AType* type(CEnv& cenv) = 0; virtual Value* compile(CEnv& cenv) = 0; virtual void lift(CEnv& cenv) {} }; /// Symbol, e.g. "a" struct ASTSymbol : public AST { ASTSymbol(const string& s) : cppstr(s) {} std::string str(CEnv&) const { return cppstr; } AType* type(CEnv& cenv); Value* compile(CEnv& cenv); private: const string cppstr; }; /// Tuple (heterogeneous sequence of known length), e.g. "(a b c)" struct ASTTuple : public AST { ASTTuple(vector t=vector()) : tup(t) {} string str(CEnv& cenv) const { string ret = "("; for (size_t i = 0; i != tup.size(); ++i) ret += tup[i]->str(cenv) + ((i != tup.size() - 1) ? " " : ""); ret.append(")"); return ret; } void lift(CEnv& cenv) { FOREACH(vector::iterator, t, tup) (*t)->lift(cenv); } AType* type(CEnv& cenv); Value* compile(CEnv& cenv) { return NULL; } vector tup; }; /// TExpr ::= (TName TExpr*) | ?Num struct AType : public ASTTuple { AType(unsigned i) : var(true), ctype(0), id(id) {} AType(const string& n, const Type* t) : var(false), ctype(t) { tup.push_back(new ASTSymbol(n)); } AType(const vector& t) : ASTTuple(t), var(false), ctype(0) {} inline bool operator==(const AType& t) const { return tup[0] == t.tup[0]; } inline bool operator!=(const AType& t) const { return tup[0] != t.tup[0]; } string str(CEnv& cenv) const { return var ? "?" : ASTTuple::str(cenv); } AType* type(CEnv& cenv) { return this; } Value* compile(CEnv& cenv) { return NULL; } bool var; const Type* ctype; unsigned id; }; /// Literal template struct ASTLiteral : public AST { ASTLiteral(VT v) : val(v) {} string str(CEnv& env) const { return "(Literal)"; } AType* type(CEnv& cenv); Value* compile(CEnv& cenv); const VT val; }; #define LITERAL(CT, VT, NAME, COMPILED) \ template<> string \ ASTLiteral::str(CEnv& cenv) const { return NAME; } \ template<> AType* \ ASTLiteral::type(CEnv& cenv) { return new AType(NAME, VT); } \ template<> Value* \ ASTLiteral::compile(CEnv& cenv) { return (COMPILED); } /// Literal template instantiations LITERAL(int32_t, Type::Int32Ty, "Int", ConstantInt::get(Type::Int32Ty, val, true)); LITERAL(float, Type::FloatTy, "Float", ConstantFP::get(Type::FloatTy, val)); LITERAL(bool, Type::Int1Ty, "Bool", ConstantInt::get(Type::Int1Ty, val, false)); typedef unsigned UD; // User Data passed to registered parse functions // Parse Time Environment (symbol table) struct PEnv : private map { typedef AST* (*PF)(PEnv&, const list&, UD); // Parse Function struct Parser { Parser(PF f, UD d) : pf(f), ud(d) {} PF pf; UD ud; }; map parsers; void reg(const ASTSymbol* s, const Parser& p) { parsers.insert(make_pair(s, p)); } const Parser* parser(const ASTSymbol* s) const { map::const_iterator i = parsers.find(s); return (i != parsers.end()) ? &i->second : NULL; } ASTSymbol* sym(const string& s) { const const_iterator i = find(s); return ((i != end()) ? i->second : insert(make_pair(s, new ASTSymbol(s))).first->second); } }; /// Closure (first-class function with captured lexical bindings) struct ASTClosure : public AST { ASTClosure(ASTTuple* p, AST* b) : prot(p), body(b), func(0) {} string str(CEnv& env) const { return "(fn)"; } AType* type(CEnv& cenv) { vector texp(3); texp[0] = cenv.penv.sym("Fn"); texp[1] = prot; texp[2] = body; return new AType(texp); } Value* compile(CEnv& cenv); void lift(CEnv& cenv); ASTTuple* const prot; AST* const body; vector bindings; private: Function* func; }; /// Function call/application, e.g. "(func arg1 arg2)" struct ASTCall : public ASTTuple { ASTCall(const vector& t) : ASTTuple(t) {} AType* type(CEnv& cenv) { AST* callee = tup[0]; ASTSymbol* sym = dynamic_cast(tup[0]); if (sym) { AST** val = cenv.code.ref(sym); if (val) callee = *val; } ASTClosure* c = dynamic_cast(callee); if (!c) throw TypeError("Call to non-closure"); return c->body->type(cenv); } void lift(CEnv& cenv); Value* compile(CEnv& cenv); }; /// Definition special form, e.g. "(def x 2)" or "(def (next y) (+ y 1))" struct ASTDefinition : public ASTCall { ASTDefinition(const vector& c) : ASTCall(c) {} AType* type(CEnv& cenv) { return tup[2]->type(cenv); } Value* compile(CEnv& cenv); }; /// Conditional special form, e.g. "(if cond thenexp elseexp)" struct ASTIf : public ASTCall { ASTIf(const vector& c) : ASTCall(c) {} AType* type(CEnv& cenv) { AType* cT = tup[1]->type(cenv); AType* tT = tup[2]->type(cenv); AType* eT = tup[3]->type(cenv); if (cT->ctype != Type::Int1Ty) throw TypeError("If condition is not a boolean"); if (*tT != *eT) throw TypeError("If branches have different types"); return tT; } Value* compile(CEnv& cenv); }; /// Primitive (builtin arithmetic function), e.g. "(+ 2 3)" struct ASTPrimitive : public ASTCall { ASTPrimitive(const vector& c, Instruction::BinaryOps o) : ASTCall(c), op(o) {} AType* type(CEnv& cenv) { if (tup.size() <= 1) throw SyntaxError("Primitive call with no arguments"); return tup[1]->type(cenv); // FIXME: Ensure argument types are equivalent } Value* compile(CEnv& cenv); Instruction::BinaryOps op; }; AType* ASTTuple::type(CEnv& cenv) { vector texp; FOREACH(vector::const_iterator, p, tup) texp.push_back((*p)->type(cenv)); return new AType(texp); } /*************************************************************************** * Parser - S-Expressions (SExp) -> AST Nodes (AST) * ***************************************************************************/ /// The fundamental parser method static AST* parseExpression(PEnv& penv, const SExp& exp) { if (exp.type == SExp::LIST) { // Parse head of list if (exp.list.empty()) throw SyntaxError("Call to empty list"); vector code(exp.list.size()); code[0] = parseExpression(penv, exp.list.front()); // Dispatch to parse function if possible ASTSymbol* sym = dynamic_cast(code[0]); const PEnv::Parser* handler = sym ? penv.parser(sym) : NULL; if (handler) return handler->pf(penv, exp.list, handler->ud); // Parse as a regular call list::const_iterator i = exp.list.begin(); ++i; for (size_t n = 1; i != exp.list.end(); ++i) code[n++] = parseExpression(penv, *i); return new ASTCall(code); } else if (isdigit(exp.atom[0])) { if (exp.atom.find('.') == string::npos) return new ASTLiteral(strtol(exp.atom.c_str(), NULL, 10)); else return new ASTLiteral(strtod(exp.atom.c_str(), NULL)); } return penv.sym(exp.atom); } // Special forms static vector pmap(PEnv& penv, const list& l) { vector code(l.size()); size_t n = 0; for (list::const_iterator i = l.begin(); i != l.end(); ++i) code[n++] = parseExpression(penv, *i); return code; } static AST* parseIf(PEnv& penv, const list& c, UD) { return new ASTIf(pmap(penv, c)); } static AST* parseDef(PEnv& penv, const list& c, UD) { return new ASTDefinition(pmap(penv, c)); } static AST* parsePrim(PEnv& penv, const list& c, UD data) { return new ASTPrimitive(pmap(penv, c), (Instruction::BinaryOps)data); } static ASTTuple* parsePrototype(PEnv& penv, const SExp& e, UD) { return new ASTTuple(pmap(penv, e.list)); } static AST* parseFn(PEnv& penv, const list& c, UD) { list::const_iterator a = c.begin(); ++a; return new ASTClosure( parsePrototype(penv, *a++, 0), parseExpression(penv, *a++)); } /*************************************************************************** * Code Generation * ***************************************************************************/ struct CompileError : public Error { CompileError(const char* m) : Error(m) {} }; static Function* compileFunction(CEnv& cenv, const std::string& name, ASTTuple& prot, const Type* retT) { Function::LinkageTypes linkage = Function::ExternalLinkage; const vector& texp = prot.type(cenv)->tup; vector cprot; for (size_t i = 0; i < texp.size(); ++i) { const Type* t = texp[i]->type(cenv)->ctype; if (!t) throw CompileError("Function prototype contains NULL"); cprot.push_back(t); } FunctionType* fT = FunctionType::get(retT, cprot, false); Function* f = Function::Create(fT, linkage, name, cenv.module); if (f->getName() != name) { f->eraseFromParent(); throw CompileError("Function redefined"); } // Set argument names in generated code Function::arg_iterator a = f->arg_begin(); for (size_t i = 0; i != prot.tup.size(); ++a, ++i) a->setName(prot.tup[i]->str(cenv)); return f; } AType* ASTSymbol::type(CEnv& cenv) { AST** t = cenv.code.ref(this); return t ? (*t)->type(cenv) : new AType(cenv.tID++); } Value* ASTSymbol::compile(CEnv& cenv) { Value** v = cenv.vals.ref(this); if (v) return *v; AST** c = cenv.code.ref(this); if (c) { Value* v = (*c)->compile(cenv); cenv.vals.def(this, v); return v; } throw SyntaxError((string("Undefined symbol '") + cppstr + "'").c_str()); } Value* ASTDefinition::compile(CEnv& cenv) { if (tup.size() != 3) throw SyntaxError("\"def\" takes exactly 2 arguments"); const ASTSymbol* sym = dynamic_cast(tup[1]); if (!sym) throw SyntaxError("Definition name is not a symbol"); Value* val = tup[2]->compile(cenv); cenv.code.def(sym, tup[2]); cenv.vals.def(sym, val); return val; } void ASTCall::lift(CEnv& cenv) { ASTClosure* c = dynamic_cast(tup[0]); if (!c) { AST** val = cenv.code.ref(tup[0]); c = (val) ? dynamic_cast(*val) : c; } if (!c) { ASTTuple::lift(cenv); return; } std::cout << "Lifting call to closure" << endl; // Lift arguments for (size_t i = 1; i < tup.size(); ++i) tup[i]->lift(cenv); // Extend environment with bound and typed parameters cenv.code.push(); if (c->prot->tup.size() != tup.size() - 1) throw CompileError("Call to closure with mismatched arguments"); for (size_t i = 1; i < tup.size(); ++i) cenv.code.def(c->prot->tup[i-1], tup[i]); // Lift callee closure tup[0]->lift(cenv); cenv.code.pop(); } Value* ASTCall::compile(CEnv& cenv) { ASTClosure* c = dynamic_cast(tup[0]); if (!c) { AST** val = cenv.code.ref(tup[0]); c = (val) ? dynamic_cast(*val) : c; } if (!c) throw CompileError("Call to non-closure"); Value* v = c->compile(cenv); if (!v) throw SyntaxError("Callee failed to compile"); Function* f = dynamic_cast(c->compile(cenv)); if (!f) throw SyntaxError("Callee compiled to non-function"); vector params; for (size_t i = 1; i < tup.size(); ++i) params.push_back(tup[i]->compile(cenv)); return cenv.builder.CreateCall(f, params.begin(), params.end(), "calltmp"); } Value* ASTIf::compile(CEnv& cenv) { Value* condV = tup[1]->compile(cenv); Function* parent = cenv.builder.GetInsertBlock()->getParent(); // Create blocks for the then and else cases. // Insert the 'then' block at the end of the function. BasicBlock* thenBB = BasicBlock::Create("then", parent); BasicBlock* elseBB = BasicBlock::Create("else"); BasicBlock* mergeBB = BasicBlock::Create("ifcont"); cenv.builder.CreateCondBr(condV, thenBB, elseBB); // Emit then value. cenv.builder.SetInsertPoint(thenBB); Value* thenV = tup[2]->compile(cenv); cenv.builder.CreateBr(mergeBB); // compile of 'Then' can change the current block, update thenBB thenBB = cenv.builder.GetInsertBlock(); // Emit else block. parent->getBasicBlockList().push_back(elseBB); cenv.builder.SetInsertPoint(elseBB); Value* elseV = tup[3]->compile(cenv); cenv.builder.CreateBr(mergeBB); // compile of 'Else' can change the current block, update elseBB elseBB = cenv.builder.GetInsertBlock(); // Emit merge block. parent->getBasicBlockList().push_back(mergeBB); cenv.builder.SetInsertPoint(mergeBB); PHINode* pn = cenv.builder.CreatePHI(type(cenv)->ctype, "iftmp"); pn->addIncoming(thenV, thenBB); pn->addIncoming(elseV, elseBB); return pn; } void ASTClosure::lift(CEnv& cenv) { assert(!func); // Can't lift a closure with variable types (lift later when called) for (size_t i = 0; i < prot->tup.size(); ++i) if (prot->tup[i]->type(cenv)->var) return; cenv.code.push(); ASTSymbol* sym = dynamic_cast(body); if (sym) { AST** obj = cenv.code.ref(sym); if (!obj) { std::cout << "UNDEFINED SYMBOL BODY\n"; prot->tup.push_back(sym); bindings.push_back(sym); } } // Write function declaration Function* f = compileFunction(cenv, cenv.gensym("_fn"), *prot, body->type(cenv)->ctype); BasicBlock* bb = BasicBlock::Create("entry", f); cenv.builder.SetInsertPoint(bb); // Bind argument values in CEnv vector args; vector::const_iterator p = prot->tup.begin(); for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p) cenv.vals.def(dynamic_cast(*p), &*a); // Write function body try { Value* retVal = body->compile(cenv); cenv.builder.CreateRet(retVal); // Finish function verifyFunction(*f); // Validate generated code cenv.fpm.run(*f); // Optimize function func = f; } catch (exception e) { f->eraseFromParent(); // Error reading body, remove function throw e; } cenv.code.pop(); } Value* ASTClosure::compile(CEnv& cenv) { // Function was already compiled in the lifting pass return func; } Value* ASTPrimitive::compile(CEnv& cenv) { size_t np = 0; vector params(tup.size() - 1); vector::const_iterator a = tup.begin(); for (++a; a != tup.end(); ++a) params[np++] = (*a)->compile(cenv); switch (params.size()) { case 0: throw SyntaxError("Primitive expects at least 1 argument"); case 1: return params[0]; default: Value* val = cenv.builder.CreateBinOp(op, params[0], params[1]); for (size_t i = 2; i < params.size(); ++i) val = cenv.builder.CreateBinOp(op, val, params[i]); return val; } } /*************************************************************************** * REPL * ***************************************************************************/ int main() { PEnv penv; penv.reg(penv.sym("fn"), PEnv::Parser(parseFn, 0)); penv.reg(penv.sym("if"), PEnv::Parser(parseIf, 0)); penv.reg(penv.sym("def"), PEnv::Parser(parseDef, 0)); penv.reg(penv.sym("+"), PEnv::Parser(parsePrim, Instruction::Add)); penv.reg(penv.sym("-"), PEnv::Parser(parsePrim, Instruction::Sub)); penv.reg(penv.sym("*"), PEnv::Parser(parsePrim, Instruction::Mul)); penv.reg(penv.sym("/"), PEnv::Parser(parsePrim, Instruction::FDiv)); penv.reg(penv.sym("%"), PEnv::Parser(parsePrim, Instruction::FRem)); penv.reg(penv.sym("&"), PEnv::Parser(parsePrim, Instruction::And)); penv.reg(penv.sym("|"), PEnv::Parser(parsePrim, Instruction::Or)); penv.reg(penv.sym("^"), PEnv::Parser(parsePrim, Instruction::Xor)); Module* module = new Module("repl"); ExecutionEngine* engine = ExecutionEngine::create(module); CEnv cenv(penv, module, engine->getTargetData()); cenv.code.def(penv.sym("true"), new ASTLiteral(true)); cenv.code.def(penv.sym("false"), new ASTLiteral(false)); while (1) { std::cout << "(=>) "; std::cout.flush(); SExp exp = readExpression(std::cin); if (exp.type == SExp::LIST && exp.list.empty()) break; try { AST* body = parseExpression(penv, exp); ASTTuple* prot = new ASTTuple(); AType* bodyT = body->type(cenv); if (!bodyT) throw TypeError("REPL call to untyped body"); if (bodyT->var) throw TypeError("REPL call to variable typed body"); body->lift(cenv); if (bodyT->ctype) { // Create anonymous function to insert code into. Function* f = compileFunction(cenv, cenv.gensym("_repl"), *prot, bodyT->ctype); BasicBlock* bb = BasicBlock::Create("entry", f); cenv.builder.SetInsertPoint(bb); try { Value* retVal = body->compile(cenv); cenv.builder.CreateRet(retVal); // Finish function verifyFunction(*f); // Validate generated code cenv.fpm.run(*f); // Optimize function } catch (SyntaxError e) { f->eraseFromParent(); // Error reading body, remove function throw e; } void* fp = engine->getPointerToFunction(f); double (*cfunc)() = (double (*)())fp; std::cout << cfunc(); } else { Value* val = body->compile(cenv); std::cout << val; } std::cout << " :: " << body->type(cenv)->str(cenv) << endl; } catch (SyntaxError e) { std::cerr << "Syntax error: " << e.what() << endl; } catch (TypeError e) { std::cerr << "Type error: " << e.what() << endl; } catch (CompileError e) { std::cerr << "Compile error: " << e.what() << endl; } } std::cout << "Generated code:" << endl; module->dump(); return 0; }