/* A Trivial LLVM LISP * Copyright (C) 2008-2009 David Robillard * * Parts from the Kaleidoscope tutorial * by Chris Lattner and Erick Tryzelaar * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with This program. If not, see . */ #include #include #include #include #include #include #include #include "llvm/Analysis/Verifier.h" #include "llvm/DerivedTypes.h" #include "llvm/ExecutionEngine/ExecutionEngine.h" #include "llvm/Module.h" #include "llvm/ModuleProvider.h" #include "llvm/PassManager.h" #include "llvm/Support/IRBuilder.h" #include "llvm/Target/TargetData.h" #include "llvm/Transforms/Scalar.h" #define FOREACH(IT, i, c) for (IT i = (c).begin(); i != (c).end(); ++i) using namespace llvm; using namespace std; struct Error : public std::exception { Error(const char* m) : msg(m) {} const char* what() const throw() { return msg; } const char* msg; }; /*************************************************************************** * S-Expression Lexer :: text -> S-Expressions (SExp) * ***************************************************************************/ struct SyntaxError : public Error { SyntaxError(const char* m) : Error(m) {} }; struct SExp { SExp() : type(LIST) {} SExp(const std::list& l) : type(LIST), list(l) {} SExp(const std::string& s) : type(ATOM), atom(s) {} enum { ATOM, LIST } type; std::string atom; std::list list; }; static SExp readExpression(std::istream& in) { stack stk; string tok; #define APPEND_TOK() \ if (stk.empty()) return tok; else stk.top().list.push_back(SExp(tok)) while (char ch = in.get()) { switch (ch) { case EOF: return SExp(); case ' ': case '\t': case '\n': if (tok == "") continue; else APPEND_TOK(); tok = ""; break; case '"': do { tok.push_back(ch); } while ((ch = in.get()) != '"'); tok.push_back('"'); APPEND_TOK(); tok = ""; break; case '(': stk.push(SExp()); break; case ')': switch (stk.size()) { case 0: throw SyntaxError("Missing '('"); break; case 1: if (tok != "") stk.top().list.push_back(SExp(tok)); return stk.top(); default: if (tok != "") stk.top().list.push_back(SExp(tok)); SExp l = stk.top(); stk.pop(); stk.top().list.push_back(l); } tok = ""; break; default: tok.push_back(ch); } } switch (stk.size()) { case 0: return tok; break; case 1: return stk.top(); break; default: throw SyntaxError("Missing ')'"); } return SExp(); } /*************************************************************************** * Abstract Syntax Tree * ***************************************************************************/ struct TypeError : public Error { TypeError (const char* m) : Error(m) {} }; struct CEnv; ///< Compile Time Environment /// Base class for all AST nodes struct AST { virtual ~AST() {} virtual const Type* type(CEnv& cenv) const = 0; virtual Value* compile(CEnv& cenv) = 0; }; /// Literal template struct ASTLiteral : public AST { ASTLiteral(VT v) : val(v) {} const Type* type(CEnv& cenv) const; Value* compile(CEnv& cenv); const VT val; }; #define LITERAL(CT, VT, COMPILED) \ template<> const Type* \ ASTLiteral::type(CEnv& cenv) const { return VT; } \ \ template<> Value* \ ASTLiteral::compile(CEnv& cenv) { return (COMPILED); } /// Literal template specialisations LITERAL(int32_t, Type::Int32Ty, ConstantInt::get(type(cenv), val, true)); LITERAL(float, Type::FloatTy, ConstantFP::get(type(cenv), val)); LITERAL(bool, Type::Int1Ty, ConstantInt::get(type(cenv), val, false)); /// Symbol, e.g. "a" struct ASTSymbol : public AST { ASTSymbol(const string& n) : name(n) {} virtual const Type* type(CEnv& cenv) const; virtual Value* compile(CEnv& cenv); const string name; }; /// Function call/application, e.g. "(func arg1 arg2)" struct ASTCall : public AST { ASTCall(const vector& c) : code(c) {} virtual const Type* type(CEnv& cenv) const { AST* func = code[0]; const FunctionType* ftype = dynamic_cast(func->type(cenv)); if (!ftype) throw TypeError(string("Call to non-function type :: ") .append(func->type(cenv)->getDescription()).c_str()); return ftype->getReturnType(); } virtual Value* compile(CEnv& cenv); const vector code; }; /// Definition special form, e.g. "(def x 2)" or "(def (next y) (+ y 1))" struct ASTDefinition : public ASTCall { ASTDefinition(const vector& c) : ASTCall(c) {} virtual const Type* type(CEnv& cenv) const { return code[2]->type(cenv); } virtual Value* compile(CEnv& cenv); }; /// Conditional special form, e.g. "(if cond thenexp elseexp)" struct ASTIf : public ASTCall { ASTIf(const vector& c) : ASTCall(c) {} virtual const Type* type(CEnv& cenv) const { const Type* cT = code[1]->type(cenv); const Type* tT = code[2]->type(cenv); const Type* eT = code[3]->type(cenv); if (cT != Type::Int1Ty) throw TypeError("If condition is not a boolean"); if (tT != eT) throw TypeError("If branches have different types"); return tT; } virtual Value* compile(CEnv& cenv); }; /// Primitive (builtin arithmetic function), e.g. "(+ 2 3)" struct ASTPrimitive : public ASTCall { ASTPrimitive(const vector& c, Instruction::BinaryOps o) : ASTCall(c), op(o) {} virtual const Type* type(CEnv& cenv) const { return Type::FloatTy; } virtual Value* compile(CEnv& cenv); Instruction::BinaryOps op; }; /// Function prototype (actual LLVM IR function prototype) struct ASTPrototype { ASTPrototype(vector p=vector()) : params(p) {} vector argsType(CEnv& cenv) { vector types; FOREACH(vector::const_iterator, p, params) types.push_back((*p)->type(cenv)); return types; } virtual const Type* type(CEnv& cenv) const { return NULL; } Function* compile(CEnv& cenv, FunctionType* type, const string& name); string name; vector params; }; /// Closure (first-class function with captured lexical bindings) struct ASTClosure : public AST { ASTClosure(ASTPrototype* p, AST* b) : prot(p), body(b), func(0) {} virtual const Type* type(CEnv& cenv) const { return FunctionType::get(body->type(cenv), prot->argsType(cenv), false); } virtual Value* compile(CEnv& cenv); virtual void lift(CEnv& cenv); ASTPrototype* const prot; AST* const body; vector bindings; private: Function* func; }; /// Function definition (actual LLVM IR function) struct ASTFunction { ASTFunction(ASTPrototype* p, AST* b) : prot(p), body(b) {} Function* compile(CEnv& cenv, const string& name); ASTPrototype* const prot; AST* const body; }; /*************************************************************************** * Parser - S-Expressions (SExp) -> AST Nodes (AST) * ***************************************************************************/ typedef unsigned UD; // User Data passed to registered parse functions // Parse Time Environment (just a symbol table) struct PEnv : private map { typedef AST* (*PF)(PEnv&, const list&, UD); // Parse Function struct Parser { Parser(PF f, UD d) : pf(f), ud(d) {} PF pf; UD ud; }; map parsers; void reg(const ASTSymbol* s, const Parser& p) { parsers.insert(make_pair(s, p)); } const Parser* parser(const ASTSymbol* s) const { map::const_iterator i = parsers.find(s); return (i != parsers.end()) ? &i->second : NULL; } ASTSymbol* sym(const string& s) { const const_iterator i = find(s); return ((i != end()) ? i->second : insert(make_pair(s, new ASTSymbol(s))).first->second); } }; /// The fundamental parser method static AST* parseExpression(PEnv& penv, const SExp& exp) { if (exp.type == SExp::LIST) { // Parse head of list if (exp.list.empty()) throw SyntaxError("Call to empty list"); vector code(exp.list.size()); code[0] = parseExpression(penv, exp.list.front()); // Dispatch to parse function if possible ASTSymbol* sym = dynamic_cast(code[0]); const PEnv::Parser* handler = sym ? penv.parser(sym) : NULL; if (handler) return handler->pf(penv, exp.list, handler->ud); // Parse as a regular call list::const_iterator i = exp.list.begin(); ++i; for (size_t n = 1; i != exp.list.end(); ++i) code[n++] = parseExpression(penv, *i); return new ASTCall(code); } else if (isdigit(exp.atom[0])) { if (exp.atom.find('.') == string::npos) return new ASTLiteral(strtol(exp.atom.c_str(), NULL, 10)); else return new ASTLiteral(strtod(exp.atom.c_str(), NULL)); } return penv.sym(exp.atom); } // Special forms static vector pmap(PEnv& penv, const list& l) { vector code(l.size()); size_t n = 0; for (list::const_iterator i = l.begin(); i != l.end(); ++i) code[n++] = parseExpression(penv, *i); return code; } static AST* parseIf(PEnv& penv, const list& c, UD) { return new ASTIf(pmap(penv, c)); } static AST* parseDef(PEnv& penv, const list& c, UD) { return new ASTDefinition(pmap(penv, c)); } static AST* parsePrim(PEnv& penv, const list& c, UD data) { return new ASTPrimitive(pmap(penv, c), (Instruction::BinaryOps)data); } static ASTPrototype* parsePrototype(PEnv& penv, const SExp& e, UD) { return new ASTPrototype(pmap(penv, e.list)); } static AST* parseFn(PEnv& penv, const list& c, UD) { list::const_iterator a = c.begin(); ++a; return new ASTClosure( parsePrototype(penv, *a++, 0), parseExpression(penv, *a++)); } /*************************************************************************** * Code Generation * ***************************************************************************/ /// Generic Recursive Environment (stack of key:value dictionaries) template struct Env : public list< map > { Env() : list< map >(1) {} void push() { this->push_front(map()); } void push(const map& frame) { this->push_front(frame); } map& pop() { map& front = this->front(); this->pop_front(); return front; } void def(const K& k, const V& v) { if (this->front().find(k) != this->front().end()) std::cerr << "WARNING: Redefinition: " << k << endl; this->front()[k] = v; } V* ref(const K& name) { typename Env::iterator i = this->begin(); for (; i != this->end(); ++i) { typename map::iterator s = i->find(name); if (s != i->end()) return &s->second; } return 0; } }; /// Compile-time environment struct CEnv { CEnv(Module* m, const TargetData* target) : module(m), provider(module), fpm(&provider), symID(0) { // Set up the optimizer pipeline. // Register info about how the target lays out data structures. fpm.add(new TargetData(*target)); // Do simple "peephole" and bit-twiddling optimizations. fpm.add(createInstructionCombiningPass()); // Reassociate expressions. fpm.add(createReassociatePass()); // Eliminate Common SubExpressions. fpm.add(createGVNPass()); // Simplify control flow graph (delete unreachable blocks, etc). fpm.add(createCFGSimplificationPass()); } string gensym(const char* base="_") { ostringstream s; s << base << symID++; return s.str(); } void def(const ASTSymbol* sym, AST* expr) { types.def(sym, expr->type(*this)); code.def(sym, expr); } typedef Env Types; typedef Env Code; typedef Env Vals; IRBuilder<> builder; Module* module; ExistingModuleProvider provider; FunctionPassManager fpm; size_t symID; Types types; Code code; Vals vals; }; static void lambdaLift(CEnv& env, AST* ast) { if (ASTClosure* closure = dynamic_cast(ast)) { lambdaLift(env, closure->body); closure->lift(env); } else if (ASTCall* call = dynamic_cast(ast)) { FOREACH(vector::const_iterator, a, call->code) lambdaLift(env, *a); } } const Type* ASTSymbol::type(CEnv& cenv) const { const Type** t = cenv.types.ref(this); if (t) { return *t; } else { //std::cerr << "WARNING: Untyped symbol: " << name << endl; return Type::FloatTy; } } Value* ASTSymbol::compile(CEnv& cenv) { Value*const* v = cenv.vals.ref(this); if (v) return *v; AST*const* c = cenv.code.ref(this); if (c) { Value* v = (*c)->compile(cenv); cenv.vals.def(this, v); return v; } throw SyntaxError((string("Undefined symbol '") + name + "'").c_str()); } Value* ASTDefinition::compile(CEnv& cenv) { if (code.size() != 3) throw SyntaxError("\"def\" takes exactly 2 arguments"); const ASTSymbol* sym = dynamic_cast(code[1]); if (!sym) throw SyntaxError("Definition name is not a symbol"); Value* val = code[2]->compile(cenv); cenv.types.def(sym, code[2]->type(cenv)); cenv.code.def(sym, code[2]); cenv.vals.def(sym, val); return val; } Value* ASTCall::compile(CEnv& cenv) { AST* func = code[0]; AST** closure = cenv.code.ref((ASTSymbol*)func); assert(closure); ASTClosure* c = dynamic_cast(*closure); assert(c); Function* f = dynamic_cast(func->compile(cenv)); if (!f) throw SyntaxError("Call to non-function"); vector params; for (size_t i = 1; i < code.size(); ++i) params.push_back(code[i]->compile(cenv)); for (size_t i = 0; i < c->bindings.size(); ++i) std::cout << "BINDING: " << c->bindings[i]->name << endl; return cenv.builder.CreateCall(f, params.begin(), params.end(), "calltmp"); } Value* ASTIf::compile(CEnv& cenv) { Value* condV = code[1]->compile(cenv); Function* parent = cenv.builder.GetInsertBlock()->getParent(); // Create blocks for the then and else cases. // Insert the 'then' block at the end of the function. BasicBlock* thenBB = BasicBlock::Create("then", parent); BasicBlock* elseBB = BasicBlock::Create("else"); BasicBlock* mergeBB = BasicBlock::Create("ifcont"); cenv.builder.CreateCondBr(condV, thenBB, elseBB); // Emit then value. cenv.builder.SetInsertPoint(thenBB); Value* thenV = code[2]->compile(cenv); cenv.builder.CreateBr(mergeBB); // compile of 'Then' can change the current block, update thenBB thenBB = cenv.builder.GetInsertBlock(); // Emit else block. parent->getBasicBlockList().push_back(elseBB); cenv.builder.SetInsertPoint(elseBB); Value* elseV = code[3]->compile(cenv); cenv.builder.CreateBr(mergeBB); // compile of 'Else' can change the current block, update elseBB elseBB = cenv.builder.GetInsertBlock(); // Emit merge block. parent->getBasicBlockList().push_back(mergeBB); cenv.builder.SetInsertPoint(mergeBB); PHINode* pn = cenv.builder.CreatePHI(type(cenv), "iftmp"); pn->addIncoming(thenV, thenBB); pn->addIncoming(elseV, elseBB); return pn; } void ASTClosure::lift(CEnv& cenv) { assert(!func); //set unbound; const ASTCall* call = dynamic_cast(body); if (call) { std::cout << "LIFT CALL BODY\n"; } #if 0 Env paramsEnv; for (vector::const_iterator p = prot->params.begin(); p != prot->params.end(); ++p) { //paramsEnv[*p] = NULL; std::cout << "PARAM: " << (*p)->name << endl; } #endif cenv.code.push(); ASTSymbol* sym = dynamic_cast(body); if (sym) { AST** obj = cenv.code.ref(sym); if (!obj) { std::cout << "UNDEFINED SYMBOL BODY\n"; prot->params.push_back(sym); bindings.push_back(sym); } } // Write function declaration Function* f = prot->compile(cenv, FunctionType::get(body->type(cenv), prot->argsType(cenv), false), cenv.gensym("_fn")); BasicBlock* bb = BasicBlock::Create("entry", f); cenv.builder.SetInsertPoint(bb); // Bind argument values in CEnv vector args; vector::const_iterator p = prot->params.begin(); for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p) cenv.vals.def(dynamic_cast(*p), &*a); // Write function body try { Value* retVal = body->compile(cenv); cenv.builder.CreateRet(retVal); // Finish function verifyFunction(*f); // Validate generated code cenv.fpm.run(*f); // Optimize function func = f; } catch (SyntaxError e) { f->eraseFromParent(); // Error reading body, remove function throw e; } cenv.code.pop(); } Value* ASTClosure::compile(CEnv& cenv) { // Function was already compiled in the lifting pass assert(func); return func; } Function* ASTPrototype::compile(CEnv& cenv, FunctionType* FT, const std::string& n) { name = n; Function::LinkageTypes linkage = Function::ExternalLinkage; Function* f = Function::Create(FT, linkage, name, cenv.module); // If F conflicted, there was already something named 'Name'. // If it has a body, don't allow redefinition. if (f->getName() != name) { // Delete the one we just made and get the existing one. f->eraseFromParent(); f = cenv.module->getFunction(name); // If F already has a body, reject this. if (!f->empty()) throw SyntaxError("Function redefined"); // If F took a different number of args, reject. if (f->arg_size() != params.size()) throw SyntaxError("Function redefined with mismatched arguments"); } // Set argument names in generated code Function::arg_iterator a = f->arg_begin(); for (size_t i = 0; i != params.size(); ++a, ++i) { assert(params[i]); ASTSymbol* sym = dynamic_cast(params[i]); a->setName(sym ? sym->name : cenv.gensym("_a")); } return f; } Function* ASTFunction::compile(CEnv& cenv, const string& name) { const Type* bodyT = body->type(cenv); if (dynamic_cast(bodyT)) { std::cout << "First class function alert" << endl; bodyT = PointerType::get(bodyT, 0); } FunctionType* fT = FunctionType::get(bodyT, prot->argsType(cenv), false); Function* f = prot->compile(cenv, fT, name); // Create a new basic block to start insertion into. BasicBlock* bb = BasicBlock::Create("entry", f); cenv.builder.SetInsertPoint(bb); try { Value* retVal = body->compile(cenv); cenv.builder.CreateRet(retVal); // Finish function verifyFunction(*f); // Validate generated code cenv.fpm.run(*f); // Optimize function return f; } catch (SyntaxError e) { f->eraseFromParent(); // Error reading body, remove function throw e; } } Value* ASTPrimitive::compile(CEnv& cenv) { size_t np = 0; vector params(code.size() - 1); vector::const_iterator a = code.begin(); for (++a; a != code.end(); ++a) params[np++] = (*a)->compile(cenv); switch (params.size()) { case 0: throw SyntaxError("Primitive expects at least 1 argument"); case 1: return params[0]; default: Value* val = cenv.builder.CreateBinOp(op, params[0], params[1]); for (size_t i = 2; i < params.size(); ++i) val = cenv.builder.CreateBinOp(op, val, params[i]); return val; } } /*************************************************************************** * REPL * ***************************************************************************/ int main() { Module* module = new Module("interactive"); ExecutionEngine* engine = ExecutionEngine::create(module); CEnv cenv(module, engine->getTargetData()); PEnv penv; penv.reg(penv.sym("if"), PEnv::Parser(parseIf, 0)); penv.reg(penv.sym("def"), PEnv::Parser(parseDef, 0)); penv.reg(penv.sym("fn"), PEnv::Parser(parseFn, 0)); penv.reg(penv.sym("+"), PEnv::Parser(parsePrim, Instruction::Add)); penv.reg(penv.sym("-"), PEnv::Parser(parsePrim, Instruction::Sub)); penv.reg(penv.sym("*"), PEnv::Parser(parsePrim, Instruction::Mul)); penv.reg(penv.sym("/"), PEnv::Parser(parsePrim, Instruction::FDiv)); penv.reg(penv.sym("%"), PEnv::Parser(parsePrim, Instruction::FRem)); penv.reg(penv.sym("&"), PEnv::Parser(parsePrim, Instruction::And)); penv.reg(penv.sym("|"), PEnv::Parser(parsePrim, Instruction::Or)); penv.reg(penv.sym("^"), PEnv::Parser(parsePrim, Instruction::Xor)); cenv.def(penv.sym("true"), new ASTLiteral(true)); cenv.def(penv.sym("false"), new ASTLiteral(false)); while (1) { std::cout << "(=>) "; std::cout.flush(); SExp exp = readExpression(std::cin); if (exp.type == SExp::LIST && exp.list.empty()) break; try { AST* ast = parseExpression(penv, exp); lambdaLift(cenv, ast); ASTPrototype* proto = new ASTPrototype(); ASTFunction* func = new ASTFunction(proto, ast); Function* code = func->compile(cenv, cenv.gensym("_repl")); void* fp = engine->getPointerToFunction(code); code->dump(); double (*f)() = (double (*)())fp; std::cout << f() << " :: "; func->body->type(cenv)->print(std::cout); std::cout << endl; } catch (SyntaxError e) { std::cerr << "Syntax error: " << e.what() << endl; } catch (TypeError e) { std::cerr << "Type error: " << e.what() << endl; } } std::cout << "Generated code:" << endl; module->dump(); return 0; }