diff options
author | David Robillard <d@drobilla.net> | 2009-01-24 04:01:15 +0000 |
---|---|---|
committer | David Robillard <d@drobilla.net> | 2009-01-24 04:01:15 +0000 |
commit | 0a88fdb91fcf631ecae7f8b9a55970822536d9ef (patch) | |
tree | 8b285a70f23d2274d8a651e7ce7942b9ea655aef | |
parent | a49206460b255e6006be698addb7f6fac5d57c43 (diff) | |
download | resp-0a88fdb91fcf631ecae7f8b9a55970822536d9ef.tar.gz resp-0a88fdb91fcf631ecae7f8b9a55970822536d9ef.tar.bz2 resp-0a88fdb91fcf631ecae7f8b9a55970822536d9ef.zip |
Rudimentary type system.
git-svn-id: http://svn.drobilla.net/resp/llvm-lisp@9 ad02d1e2-f140-0410-9f75-f8b11f17cedd
-rw-r--r-- | ll.cpp | 453 |
1 files changed, 294 insertions, 159 deletions
@@ -35,14 +35,24 @@ #include "llvm/Target/TargetData.h" #include "llvm/Transforms/Scalar.h" +#define FOREACH(IT, i, c) for (IT i = (c).begin(); i != (c).end(); ++i) + using namespace llvm; using namespace std; +struct Error : public std::exception { + Error(const char* m) : msg(m) {} + const char* what() const throw() { return msg; } + const char* msg; +}; + /*************************************************************************** * S-Expression Lexer - Read text and output nested lists of strings * ***************************************************************************/ +struct SyntaxError : public Error { SyntaxError(const char* m) : Error(m) {} }; + struct SExp { SExp() : type(LIST) {} SExp(const std::list<SExp>& l) : type(LIST), list(l) {} @@ -52,12 +62,6 @@ struct SExp { std::list<SExp> list; }; -struct SyntaxError : public std::exception { - SyntaxError(const char* m) : msg(m) {} - const char* what() const throw() { return msg; } - const char* msg; -}; - static SExp readExpression(std::istream& in) { @@ -121,79 +125,112 @@ readExpression(std::istream& in) * Abstract Syntax Tree * ***************************************************************************/ +struct TypeError : public Error { TypeError (const char* m) : Error(m) {} }; + struct CEnv; ///< Compile Time Environment /// Base class for all AST nodes struct AST { virtual ~AST() {} - virtual Value* Codegen(CEnv& cenv) = 0; - virtual bool evaluatable() const { return true; } + virtual const Type* type(CEnv& cenv) const = 0; + virtual Value* compile(CEnv& cenv) = 0; }; /// Numeric literal, e.g. "1.0" struct ASTNumber : public AST { - ASTNumber(double val) : _val(val) {} - virtual Value* Codegen(CEnv& cenv); -private: - double _val; + ASTNumber(double v) : val(v) {} + virtual const Type* type(CEnv& cenv) const { return Type::DoubleTy; } + virtual Value* compile(CEnv& cenv); + const double val; }; /// Symbol, e.g. "a" struct ASTSymbol : public AST { - ASTSymbol(const string& name) : _name(name) {} - virtual Value* Codegen(CEnv& cenv); -private: - string _name; + ASTSymbol(const string& n) : name(n) {} + virtual const Type* type(CEnv& cenv) const; + virtual Value* compile(CEnv& cenv); + const string name; }; /// Function call/application, e.g. "(func arg1 arg2)" struct ASTCall : public AST { - ASTCall(const string& n, vector<AST*>& a) : _name(n), _args(a) {} - virtual Value* Codegen(CEnv& cenv); -protected: - string _name; - vector<AST*> _args; + ASTCall(AST* f, vector<AST*>& a) : func(f), args(a) {} + virtual const Type* type(CEnv& cenv) const { + const FunctionType* ftype = dynamic_cast<const FunctionType*>(func->type(cenv)); + if (!ftype) throw TypeError(string("Call to non-function type :: ") + .append(func->type(cenv)->getDescription()).c_str()); + return ftype->getReturnType(); + } + virtual Value* compile(CEnv& cenv); + AST* const func; + const vector<AST*> args; }; /// Definition special form, e.g. "(def x 2)" or "(def (next y) (+ y 1))" struct ASTDefinition : public ASTCall { - ASTDefinition(const string& n, vector<AST*> a) : ASTCall(n, a) {} - virtual Value* Codegen(CEnv& cenv); + ASTDefinition(vector<AST*> a) : ASTCall(NULL, a) {} + virtual const Type* type(CEnv& cenv) const { return args[1]->type(cenv); } + virtual Value* compile(CEnv& cenv); }; /// Conditional special form, e.g. "(if cond thenexp elseexp)" struct ASTIf : public ASTCall { - ASTIf(const string& n, vector<AST*>& a) : ASTCall(n, a) {} - virtual Value* Codegen(CEnv& cenv); + ASTIf(vector<AST*>& a) : ASTCall(NULL, a) {} + virtual const Type* type(CEnv& cenv) const { + const Type* ctype = args[0]->type(cenv); + const Type* ttype = args[1]->type(cenv); + const Type* etype = args[2]->type(cenv); + if (ctype != Type::DoubleTy) throw TypeError("If condition is not a number"); + if (ttype != etype) throw TypeError("If branches have different types"); + return ttype; + } + virtual Value* compile(CEnv& cenv); }; -/// Primitive (builtin arithmetic function) +/// Primitive (builtin arithmetic function), e.g. "(+ 2 3)" struct ASTPrimitive : public ASTCall { - ASTPrimitive(const string& n, vector<AST*>& a) : ASTCall(n, a) {} - virtual Value* Codegen(CEnv& cenv); + ASTPrimitive(char n, vector<AST*>& a) : ASTCall(NULL, a), name(n) {} + virtual const Type* type(CEnv& cenv) const { return Type::DoubleTy; } + virtual Value* compile(CEnv& cenv); + const char name; }; -/// Function prototype -struct ASTPrototype : public AST { - ASTPrototype(const string& n, const vector<string>& p=vector<string>()) - : _name(n), _params(p) {} - virtual bool evaluatable() const { return false; } - Value* Codegen(CEnv& cenv) { return Funcgen(cenv); } - Function* Funcgen(CEnv& cenv); -private: - string _name; - vector<string> _params; +/// Function prototype (actual LLVM IR function prototype) +struct ASTPrototype { + ASTPrototype(vector<ASTSymbol*> p=vector<ASTSymbol*>()) : params(p) {} + vector<const Type*> argsType(CEnv& cenv) { + vector<const Type*> types; + FOREACH(vector<ASTSymbol*>::const_iterator, p, params) + types.push_back((*p)->type(cenv)); + return types; + } + virtual const Type* type(CEnv& cenv) const { return NULL; } + Function* compile(CEnv& cenv, FunctionType* type, const string& name); + string name; + const vector<ASTSymbol*> params; }; -/// Function definition -struct ASTFunction : public AST { - ASTFunction(ASTPrototype* p, AST* b) : _proto(p), _body(b) {} - virtual bool evaluatable() const { return false; } - Value* Codegen(CEnv& cenv) { return Funcgen(cenv); } - Function* Funcgen(CEnv& cenv); +/// Closure (first-class function with captured lexical bindings) +struct ASTClosure : public AST { + ASTClosure(ASTPrototype* p, AST* b) : prot(p), body(b), func(0) {} + virtual const Type* type(CEnv& cenv) const { + return FunctionType::get(body->type(cenv), prot->argsType(cenv), false); + } + virtual Value* compile(CEnv& cenv); + virtual void lift(CEnv& cenv); + ASTPrototype* const prot; + AST* const body; private: - ASTPrototype* _proto; - AST* _body; + Function* func; +}; + + +/// Function definition (actual LLVM IR function) +struct ASTFunction { + ASTFunction(ASTPrototype* p, AST* b) : prot(p), body(b) {} + Function* compile(CEnv& cenv, const string& name); + ASTPrototype* const prot; + AST* const body; }; @@ -202,87 +239,118 @@ private: * Parser - Transform S-Expressions into AST nodes * ***************************************************************************/ -static AST* parseExpression(const SExp& exp); +/// Parse Time Environment +/*struct PEnv : public list< map<string, ASTSymbol*> > { + PEnv() : list< map<string, ASTSymbol*> >(1) {} + ASTSymbol* lookup(const string& name) const { + FOREACH(const_iterator, i, *this) { + map<string, ASTSymbol*>::const_iterator s = i->find(name); + if (s != i->end()) + return s->second; + } + return NULL; + } +};*/ +struct PEnv : private map<string, ASTSymbol*> { + ASTSymbol* sym(const string& name) { + const_iterator i = find(name); + if (i != end()) { + return i->second; + } else { + ASTSymbol* sym = new ASTSymbol(name); + insert(make_pair(name, sym)); + return sym; + } + } +}; + +static AST* parseExpression(PEnv& penv, const SExp& exp); /// numberexpr ::= number static AST* -parseNumber(const SExp& exp) +parseNumber(PEnv& penv, const SExp& exp) { assert(exp.type == SExp::ATOM); return new ASTNumber(strtod(exp.atom.c_str(), NULL)); } /// identifierexpr ::= identifier -static AST* -parseSymbol(const SExp& exp) +static ASTSymbol* +parseSymbol(PEnv& penv, const SExp& exp) { - assert(exp.type == SExp::ATOM); - return new ASTSymbol(exp.atom); + if (exp.type != SExp::ATOM) throw SyntaxError("Expected symbol"); + return penv.sym(exp.atom); } -/// prototype ::= (name [arg*]) +/// prototypeexpr ::= ([arg*]) static ASTPrototype* -parsePrototype(const SExp& exp) +parsePrototype(PEnv& penv, const SExp& exp) { - list<SExp>::const_iterator i = exp.list.begin(); - const string& name = i->atom; - - vector<string> args; - for (++i; i != exp.list.end(); ++i) + vector<ASTSymbol*> params; + FOREACH(list<SExp>::const_iterator, i, exp.list) if (i->type == SExp::ATOM) - args.push_back(i->atom); + params.push_back(penv.sym(i->atom)); else - throw SyntaxError("Expected parameter name, found list"); + throw SyntaxError("Expected symbol"); - return new ASTPrototype(name, args); + return new ASTPrototype(params); } /// callexpr ::= (expression [...]) static AST* -parseCall(const SExp& exp) +parseCall(PEnv& penv, const SExp& exp) { - if (exp.list.empty()) + const list<SExp>& l = exp.list; + if (l.empty()) return NULL; + + list<SExp>::const_iterator i = l.begin(); - list<SExp>::const_iterator i = exp.list.begin(); - const string& name = i->atom; + if (i->type == SExp::LIST) + throw SyntaxError("Call to list"); + + const string& name = i++->atom; + + if (name == "fn") { // Lambda special form + ASTPrototype* proto = parsePrototype(penv, *i++); + AST* body = parseExpression(penv, *i++); + return new ASTClosure(proto, body); - if (name == "def" && (++i)->type == SExp::LIST) { - ASTPrototype* proto = parsePrototype(*i++); - AST* body = parseExpression(*i++); - return new ASTFunction(proto, body); + //} else if (name == "foreign") { // Foreign special form + // return parsePrototype(penv, *i++); } + + vector<AST*> args(l.size() - 1); + for (size_t n = 0; i != l.end(); ++i, ++n) + args[n] = parseExpression(penv, *i); + + if (name == "def") { // Define special form + return new ASTDefinition(args); - vector<AST*> args; - for (++i; i != exp.list.end(); ++i) - args.push_back(parseExpression(*i)); + } else if (name == "if") { // If special form + return new ASTIf(args); - if (name.length() == 1) { - switch (name[0]) { - case '+': case '-': case '*': case '/': - case '%': case '&': case '|': case '^': - return new ASTPrimitive(name, args); + } else { // Generic application + if (name.length() == 1) { + switch (name[0]) { + case '+': case '-': case '*': case '/': + case '%': case '&': case '|': case '^': + return new ASTPrimitive(name[0], args); + } } - } else if (name == "if") { - return new ASTIf(name, args); - } else if (name == "def") { - return new ASTDefinition(name, args); - } else if (name == "foreign") { - return parsePrototype(*++i++); + return new ASTCall(parseExpression(penv, name), args); } - - return new ASTCall(name, args); } static AST* -parseExpression(const SExp& exp) +parseExpression(PEnv& penv, const SExp& exp) { if (exp.type == SExp::LIST) { - return parseCall(exp); + return parseCall(penv, exp); } else if (isalpha(exp.atom[0])) { - return parseSymbol(exp); + return parseSymbol(penv, exp); } else if (isdigit(exp.atom[0])) { - return parseNumber(exp); + return parseNumber(penv, exp); } else { throw SyntaxError("Illegal atom"); } @@ -296,7 +364,7 @@ parseExpression(const SExp& exp) /// Compile-time environment struct CEnv { CEnv(Module* m, const TargetData* target) - : module(m), provider(module), fpm(&provider), id(0) + : module(m), provider(module), fpm(&provider), symID(0) { // Set up the optimizer pipeline. // Register info about how the target lays out data structures. @@ -311,60 +379,90 @@ struct CEnv { fpm.add(createCFGSimplificationPass()); } string gensym(const char* base="_") { - ostringstream s; s << base << id++; return s.str(); + ostringstream s; s << base << symID++; return s.str(); } + typedef map<const ASTSymbol*, Value*> Vals; + typedef map<const ASTSymbol*, const Type*> Types; IRBuilder<> builder; Module* module; ExistingModuleProvider provider; FunctionPassManager fpm; - map<string, Value*> env; - size_t id; + size_t symID; + Vals vals; + Types types; }; +static void +lambdaLift(CEnv& env, AST* ast) +{ + if (ASTClosure* closure = dynamic_cast<ASTClosure*>(ast)) { + lambdaLift(env, closure->body); + closure->lift(env); + } else if (ASTCall* call = dynamic_cast<ASTCall*>(ast)) { + FOREACH(vector<AST*>::const_iterator, a, call->args) + lambdaLift(env, *a); + } +} + Value* -ASTNumber::Codegen(CEnv& cenv) +ASTNumber::compile(CEnv& cenv) +{ + return ConstantFP::get(APFloat(val)); +} + +const Type* +ASTSymbol::type(CEnv& cenv) const { - return ConstantFP::get(APFloat(_val)); + CEnv::Types::const_iterator t = cenv.types.find(this); + if (t != cenv.types.end()) { + return t->second; + } else { + //std::cerr << "WARNING: Untyped symbol: " << name << endl; + return Type::DoubleTy; + } } Value* -ASTSymbol::Codegen(CEnv& cenv) +ASTSymbol::compile(CEnv& cenv) { - map<string, Value*>::const_iterator v = cenv.env.find(_name); - if (v == cenv.env.end()) - throw SyntaxError((string("Undefined symbol '") + _name + "'").c_str()); + CEnv::Vals::const_iterator v = cenv.vals.find(this); + if (v == cenv.vals.end()) + throw SyntaxError((string("Undefined symbol '") + name + "'").c_str()); return v->second; } Value* -ASTDefinition::Codegen(CEnv& cenv) +ASTDefinition::compile(CEnv& cenv) { - map<string, Value*>::const_iterator v = cenv.env.find(_name); - if (v != cenv.env.end()) throw SyntaxError("Symbol redefinition"); - if (_args.empty()) throw SyntaxError("Empty definition"); - Value* valCode = _args[0]->Codegen(cenv); - cenv.env[_name] = valCode; - return valCode; + if (args.size() != 2) throw SyntaxError("\"def\" takes exactly 2 arguments"); + ASTSymbol* sym = dynamic_cast<ASTSymbol*>(args[0]); + if (!sym) throw SyntaxError("Definition name is not a symbol"); + CEnv::Vals::const_iterator v = cenv.vals.find(sym); + if (v != cenv.vals.end()) throw SyntaxError("Symbol redefinition"); + + Value* val = args[1]->compile(cenv); + cenv.types[sym] = args[1]->type(cenv); + cenv.vals[sym] = val; + return val; } Value* -ASTCall::Codegen(CEnv& cenv) +ASTCall::compile(CEnv& cenv) { - Function* f = cenv.module->getFunction(_name); - if (!f) throw SyntaxError("Undefined function"); - if (f->arg_size() != _args.size()) throw SyntaxError("Illegal arguments"); + Function* f = dynamic_cast<Function*>(func->compile(cenv)); + if (!f) throw SyntaxError("Call to non-function"); vector<Value*> params; - for (size_t i = 0; i != _args.size(); ++i) - params.push_back(_args[i]->Codegen(cenv)); + for (size_t i = 0; i != args.size(); ++i) + params.push_back(args[i]->compile(cenv)); return cenv.builder.CreateCall(f, params.begin(), params.end(), "calltmp"); } Value* -ASTIf::Codegen(CEnv& cenv) +ASTIf::compile(CEnv& cenv) { - Value* condV = _args[0]->Codegen(cenv); + Value* condV = args[0]->compile(cenv); // Convert condition to a bool by comparing equal to 0.0. condV = cenv.builder.CreateFCmpONE( @@ -382,76 +480,118 @@ ASTIf::Codegen(CEnv& cenv) // Emit then value. cenv.builder.SetInsertPoint(thenBB); - Value* thenV = _args[1]->Codegen(cenv); + Value* thenV = args[1]->compile(cenv); cenv.builder.CreateBr(mergeBB); - // Codegen of 'Then' can change the current block, update thenBB + // compile of 'Then' can change the current block, update thenBB thenBB = cenv.builder.GetInsertBlock(); // Emit else block. parent->getBasicBlockList().push_back(elseBB); cenv.builder.SetInsertPoint(elseBB); - Value* elseV = _args[2]->Codegen(cenv); + Value* elseV = args[2]->compile(cenv); cenv.builder.CreateBr(mergeBB); - // Codegen of 'Else' can change the current block, update elseBB + // compile of 'Else' can change the current block, update elseBB elseBB = cenv.builder.GetInsertBlock(); // Emit merge block. parent->getBasicBlockList().push_back(mergeBB); cenv.builder.SetInsertPoint(mergeBB); - PHINode* pn = cenv.builder.CreatePHI(Type::DoubleTy, "iftmp"); + PHINode* pn = cenv.builder.CreatePHI(type(cenv), "iftmp"); pn->addIncoming(thenV, thenBB); pn->addIncoming(elseV, elseBB); return pn; } -Function* -ASTPrototype::Funcgen(CEnv& cenv) +void +ASTClosure::lift(CEnv& cenv) +{ + assert(!func); + // Write function declaration + Function* f = prot->compile(cenv, + FunctionType::get(body->type(cenv), prot->argsType(cenv), false), + cenv.gensym("_fn")); + BasicBlock* bb = BasicBlock::Create("entry", f); + cenv.builder.SetInsertPoint(bb); + + // Bind argument values in CEnv + vector<Value*> args; + vector<ASTSymbol*>::const_iterator p = prot->params.begin(); + for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p) + cenv.vals[*p] = &*a; + + // Write function body + try { + Value* retVal = body->compile(cenv); + cenv.builder.CreateRet(retVal); // Finish function + verifyFunction(*f); // Validate generated code + cenv.fpm.run(*f); // Optimize function + func = f; + } catch (SyntaxError e) { + f->eraseFromParent(); // Error reading body, remove function + throw e; + } +} + +Value* +ASTClosure::compile(CEnv& cenv) { - // Make the function type, e.g. double(double,double) - vector<const Type*> argsT(_params.size(), Type::DoubleTy); - FunctionType* FT = FunctionType::get(Type::DoubleTy, argsT, false); + // Function was already compiled in the lifting pass + assert(func); + return func; +} - Function* f = Function::Create( - FT, Function::ExternalLinkage, _name, cenv.module); +Function* +ASTPrototype::compile(CEnv& cenv, FunctionType* FT, const std::string& n) +{ + name = n; + Function::LinkageTypes linkage = Function::ExternalLinkage; + Function* f = Function::Create(FT, linkage, name, cenv.module); // If F conflicted, there was already something named 'Name'. // If it has a body, don't allow redefinition. - if (f->getName() != _name) { + if (f->getName() != name) { // Delete the one we just made and get the existing one. f->eraseFromParent(); - f = cenv.module->getFunction(_name); + f = cenv.module->getFunction(name); // If F already has a body, reject this. if (!f->empty()) throw SyntaxError("Function redefined"); // If F took a different number of args, reject. - if (f->arg_size() != _params.size()) + if (f->arg_size() != params.size()) throw SyntaxError("Function redefined with mismatched arguments"); } Function::arg_iterator a = f->arg_begin(); - for (size_t i = 0; i != _params.size(); ++a, ++i) { - a->setName(_params[i]); // Set name in generated code - cenv.env[_params[i]] = a; // Add to environment + for (size_t i = 0; i != params.size(); ++a, ++i) { + assert(params[i]); + a->setName(params[i]->name); // Set name in generated code + //cenv.vals[params[i]] = a; // Add to environment } return f; } Function* -ASTFunction::Funcgen(CEnv& cenv) +ASTFunction::compile(CEnv& cenv, const string& name) { - Function* f = _proto->Funcgen(cenv); + const Type* bodyT = body->type(cenv); + if (dynamic_cast<const FunctionType*>(bodyT)) { + std::cout << "First class function alert" << endl; + bodyT = PointerType::get(bodyT, 0); + } + FunctionType* fT = FunctionType::get(bodyT, prot->argsType(cenv), false); + Function* f = prot->compile(cenv, fT, name); // Create a new basic block to start insertion into. BasicBlock* bb = BasicBlock::Create("entry", f); cenv.builder.SetInsertPoint(bb); try { - Value* retVal = _body->Codegen(cenv); + Value* retVal = body->compile(cenv); cenv.builder.CreateRet(retVal); // Finish function verifyFunction(*f); // Validate generated code cenv.fpm.run(*f); // Optimize function @@ -460,16 +600,13 @@ ASTFunction::Funcgen(CEnv& cenv) f->eraseFromParent(); // Error reading body, remove function throw e; } - - return 0; // Never reached } Value* -ASTPrimitive::Codegen(CEnv& cenv) +ASTPrimitive::compile(CEnv& cenv) { Instruction::BinaryOps op; - assert(_name.length() == 1); - switch (_name[0]) { + switch (name) { case '+': op = Instruction::Add; break; case '-': op = Instruction::Sub; break; case '*': op = Instruction::Mul; break; @@ -482,8 +619,8 @@ ASTPrimitive::Codegen(CEnv& cenv) } vector<Value*> params; - for (vector<AST*>::const_iterator a = _args.begin(); a != _args.end(); ++a) - params.push_back((*a)->Codegen(cenv)); + FOREACH(vector<AST*>::const_iterator, a, args) + params.push_back((*a)->compile(cenv)); switch (params.size()) { case 0: @@ -507,6 +644,7 @@ ASTPrimitive::Codegen(CEnv& cenv) static void repl(CEnv& cenv, ExecutionEngine* engine) { + PEnv penv; while (1) { std::cout << "> "; std::cout.flush(); @@ -515,24 +653,21 @@ repl(CEnv& cenv, ExecutionEngine* engine) break; try { - AST* ast = parseExpression(exp); - if (!ast) - continue; - if (ast->evaluatable()) { - ASTPrototype* proto = new ASTPrototype(cenv.gensym("repl")); - ASTFunction* func = new ASTFunction(proto, ast); - Function* code = func->Funcgen(cenv); - void* fp = engine->getPointerToFunction(code); - double (*f)() = (double (*)())fp; - std::cout << f() << endl; - //code->eraseFromParent(); - } else { - Value* code = ast->Codegen(cenv); - std::cout << "Generated code:" << endl; - code->dump(); - } + AST* ast = parseExpression(penv, exp); + //const Type* type = ast->type(cenv); + lambdaLift(cenv, ast); + ASTPrototype* proto = new ASTPrototype(); + ASTFunction* func = new ASTFunction(proto, ast); + Function* code = func->compile(cenv, cenv.gensym("repl")); + void* fp = engine->getPointerToFunction(code); + double (*f)() = (double (*)())fp; + std::cout << f() << " :: "; + func->body->type(cenv)->print(std::cout); + std::cout << endl; } catch (SyntaxError e) { std::cerr << "Syntax error: " << e.what() << endl; + } catch (TypeError e) { + std::cerr << "Type error: " << e.what() << endl; } } } |