diff options
Diffstat (limited to 'll.cpp')
-rw-r--r-- | ll.cpp | 431 |
1 files changed, 245 insertions, 186 deletions
@@ -48,7 +48,7 @@ struct Error : public std::exception { /*************************************************************************** - * S-Expression Lexer - Read text and output nested lists of strings * + * S-Expression Lexer :: text -> S-Expressions (SExp) * ***************************************************************************/ struct SyntaxError : public Error { SyntaxError(const char* m) : Error(m) {} }; @@ -117,6 +117,7 @@ readExpression(std::istream& in) case 1: return stk.top(); break; default: throw SyntaxError("Missing ')'"); } + return SExp(); } @@ -133,17 +134,30 @@ struct CEnv; ///< Compile Time Environment struct AST { virtual ~AST() {} virtual const Type* type(CEnv& cenv) const = 0; - virtual Value* compile(CEnv& cenv) = 0; + virtual Value* compile(CEnv& cenv) = 0; }; -/// Numeric literal, e.g. "1.0" -struct ASTNumber : public AST { - ASTNumber(double v) : val(v) {} - virtual const Type* type(CEnv& cenv) const { return Type::DoubleTy; } - virtual Value* compile(CEnv& cenv); - const double val; +/// Literal +template<typename VT> +struct ASTLiteral : public AST { + ASTLiteral(VT v) : val(v) {} + const Type* type(CEnv& cenv) const; + Value* compile(CEnv& cenv); + const VT val; }; +#define LITERAL(CT, VT, COMPILED) \ +template<> const Type* \ +ASTLiteral<CT>::type(CEnv& cenv) const { return VT; } \ + \ +template<> Value* \ +ASTLiteral<CT>::compile(CEnv& cenv) { return (COMPILED); } + +/// Literal template specialisations +LITERAL(int32_t, Type::Int32Ty, ConstantInt::get(type(cenv), val, true)); +LITERAL(float, Type::FloatTy, ConstantFP::get(type(cenv), val)); +LITERAL(bool, Type::Int1Ty, ConstantInt::get(type(cenv), val, false)); + /// Symbol, e.g. "a" struct ASTSymbol : public AST { ASTSymbol(const string& n) : name(n) {} @@ -154,60 +168,60 @@ struct ASTSymbol : public AST { /// Function call/application, e.g. "(func arg1 arg2)" struct ASTCall : public AST { - ASTCall(AST* f, vector<AST*>& a) : func(f), args(a) {} + ASTCall(const vector<AST*>& c) : code(c) {} virtual const Type* type(CEnv& cenv) const { + AST* func = code[0]; const FunctionType* ftype = dynamic_cast<const FunctionType*>(func->type(cenv)); if (!ftype) throw TypeError(string("Call to non-function type :: ") .append(func->type(cenv)->getDescription()).c_str()); return ftype->getReturnType(); } virtual Value* compile(CEnv& cenv); - AST* const func; - const vector<AST*> args; + const vector<AST*> code; }; /// Definition special form, e.g. "(def x 2)" or "(def (next y) (+ y 1))" struct ASTDefinition : public ASTCall { - ASTDefinition(vector<AST*> a) : ASTCall(NULL, a) {} - virtual const Type* type(CEnv& cenv) const { return args[1]->type(cenv); } + ASTDefinition(const vector<AST*>& c) : ASTCall(c) {} + virtual const Type* type(CEnv& cenv) const { return code[2]->type(cenv); } virtual Value* compile(CEnv& cenv); }; /// Conditional special form, e.g. "(if cond thenexp elseexp)" struct ASTIf : public ASTCall { - ASTIf(vector<AST*>& a) : ASTCall(NULL, a) {} + ASTIf(const vector<AST*>& c) : ASTCall(c) {} virtual const Type* type(CEnv& cenv) const { - const Type* ctype = args[0]->type(cenv); - const Type* ttype = args[1]->type(cenv); - const Type* etype = args[2]->type(cenv); - if (ctype != Type::DoubleTy) throw TypeError("If condition is not a number"); - if (ttype != etype) throw TypeError("If branches have different types"); - return ttype; + const Type* cT = code[1]->type(cenv); + const Type* tT = code[2]->type(cenv); + const Type* eT = code[3]->type(cenv); + if (cT != Type::Int1Ty) throw TypeError("If condition is not a boolean"); + if (tT != eT) throw TypeError("If branches have different types"); + return tT; } virtual Value* compile(CEnv& cenv); }; /// Primitive (builtin arithmetic function), e.g. "(+ 2 3)" struct ASTPrimitive : public ASTCall { - ASTPrimitive(char n, vector<AST*>& a) : ASTCall(NULL, a), name(n) {} - virtual const Type* type(CEnv& cenv) const { return Type::DoubleTy; } + ASTPrimitive(const vector<AST*>& c, Instruction::BinaryOps o) : ASTCall(c), op(o) {} + virtual const Type* type(CEnv& cenv) const { return Type::FloatTy; } virtual Value* compile(CEnv& cenv); - const char name; + Instruction::BinaryOps op; }; /// Function prototype (actual LLVM IR function prototype) struct ASTPrototype { - ASTPrototype(vector<ASTSymbol*> p=vector<ASTSymbol*>()) : params(p) {} + ASTPrototype(vector<AST*> p=vector<AST*>()) : params(p) {} vector<const Type*> argsType(CEnv& cenv) { vector<const Type*> types; - FOREACH(vector<ASTSymbol*>::const_iterator, p, params) + FOREACH(vector<AST*>::const_iterator, p, params) types.push_back((*p)->type(cenv)); return types; } virtual const Type* type(CEnv& cenv) const { return NULL; } Function* compile(CEnv& cenv, FunctionType* type, const string& name); string name; - const vector<ASTSymbol*> params; + vector<AST*> params; }; /// Closure (first-class function with captured lexical bindings) @@ -218,13 +232,13 @@ struct ASTClosure : public AST { } virtual Value* compile(CEnv& cenv); virtual void lift(CEnv& cenv); - ASTPrototype* const prot; - AST* const body; + ASTPrototype* const prot; + AST* const body; + vector<const ASTSymbol*> bindings; private: Function* func; }; - /// Function definition (actual LLVM IR function) struct ASTFunction { ASTFunction(ASTPrototype* p, AST* b) : prot(p), body(b) {} @@ -236,124 +250,97 @@ struct ASTFunction { /*************************************************************************** - * Parser - Transform S-Expressions into AST nodes * + * Parser - S-Expressions (SExp) -> AST Nodes (AST) * ***************************************************************************/ -/// Parse Time Environment -/*struct PEnv : public list< map<string, ASTSymbol*> > { - PEnv() : list< map<string, ASTSymbol*> >(1) {} - ASTSymbol* lookup(const string& name) const { - FOREACH(const_iterator, i, *this) { - map<string, ASTSymbol*>::const_iterator s = i->find(name); - if (s != i->end()) - return s->second; - } - return NULL; +typedef unsigned UD; // User Data passed to registered parse functions + +// Parse Time Environment (just a symbol table) +struct PEnv : private map<const string, ASTSymbol*> { + typedef AST* (*PF)(PEnv&, const list<SExp>&, UD); // Parse Function + struct Parser { Parser(PF f, UD d) : pf(f), ud(d) {} PF pf; UD ud; }; + map<const ASTSymbol*, Parser> parsers; + void reg(const ASTSymbol* s, const Parser& p) { + parsers.insert(make_pair(s, p)); } -};*/ -struct PEnv : private map<string, ASTSymbol*> { - ASTSymbol* sym(const string& name) { - const_iterator i = find(name); - if (i != end()) { - return i->second; - } else { - ASTSymbol* sym = new ASTSymbol(name); - insert(make_pair(name, sym)); - return sym; - } + const Parser* parser(const ASTSymbol* s) const { + map<const ASTSymbol*, Parser>::const_iterator i = parsers.find(s); + return (i != parsers.end()) ? &i->second : NULL; + } + ASTSymbol* sym(const string& s) { + const const_iterator i = find(s); + return ((i != end()) + ? i->second + : insert(make_pair(s, new ASTSymbol(s))).first->second); } }; -static AST* parseExpression(PEnv& penv, const SExp& exp); - -/// numberexpr ::= number +/// The fundamental parser method static AST* -parseNumber(PEnv& penv, const SExp& exp) +parseExpression(PEnv& penv, const SExp& exp) { - assert(exp.type == SExp::ATOM); - return new ASTNumber(strtod(exp.atom.c_str(), NULL)); -} + if (exp.type == SExp::LIST) { + // Parse head of list + if (exp.list.empty()) throw SyntaxError("Call to empty list"); + vector<AST*> code(exp.list.size()); + code[0] = parseExpression(penv, exp.list.front()); + + // Dispatch to parse function if possible + ASTSymbol* sym = dynamic_cast<ASTSymbol*>(code[0]); + const PEnv::Parser* handler = sym ? penv.parser(sym) : NULL; + if (handler) + return handler->pf(penv, exp.list, handler->ud); + + // Parse as a regular call + list<SExp>::const_iterator i = exp.list.begin(); ++i; + for (size_t n = 1; i != exp.list.end(); ++i) + code[n++] = parseExpression(penv, *i); + return new ASTCall(code); -/// identifierexpr ::= identifier -static ASTSymbol* -parseSymbol(PEnv& penv, const SExp& exp) -{ - if (exp.type != SExp::ATOM) throw SyntaxError("Expected symbol"); + } else if (isdigit(exp.atom[0])) { + if (exp.atom.find('.') == string::npos) + return new ASTLiteral<int32_t>(strtol(exp.atom.c_str(), NULL, 10)); + else + return new ASTLiteral<float>(strtod(exp.atom.c_str(), NULL)); + } return penv.sym(exp.atom); } -/// prototypeexpr ::= ([arg*]) -static ASTPrototype* -parsePrototype(PEnv& penv, const SExp& exp) -{ - vector<ASTSymbol*> params; - FOREACH(list<SExp>::const_iterator, i, exp.list) - if (i->type == SExp::ATOM) - params.push_back(penv.sym(i->atom)); - else - throw SyntaxError("Expected symbol"); - - return new ASTPrototype(params); -} +// Special forms -/// callexpr ::= (expression [...]) -static AST* -parseCall(PEnv& penv, const SExp& exp) +static vector<AST*> +pmap(PEnv& penv, const list<SExp>& l) { - const list<SExp>& l = exp.list; - if (l.empty()) - return NULL; - list<SExp>::const_iterator i = l.begin(); + vector<AST*> code(l.size()); + for (size_t n = 0; i != l.end(); ++i) + code[n++] = parseExpression(penv, *i); + return code; +} - if (i->type == SExp::LIST) - throw SyntaxError("Call to list"); +static AST* +parseIf(PEnv& penv, const list<SExp>& c, UD) + { return new ASTIf(pmap(penv, c)); } - const string& name = i++->atom; +static AST* +parseDef(PEnv& penv, const list<SExp>& c, UD) + { return new ASTDefinition(pmap(penv, c)); } - if (name == "fn") { // Lambda special form - ASTPrototype* proto = parsePrototype(penv, *i++); - AST* body = parseExpression(penv, *i++); - return new ASTClosure(proto, body); - - //} else if (name == "foreign") { // Foreign special form - // return parsePrototype(penv, *i++); - } - - vector<AST*> args(l.size() - 1); - for (size_t n = 0; i != l.end(); ++i, ++n) - args[n] = parseExpression(penv, *i); - - if (name == "def") { // Define special form - return new ASTDefinition(args); +static AST* +parsePrim(PEnv& penv, const list<SExp>& c, UD data) + { return new ASTPrimitive(pmap(penv, c), (Instruction::BinaryOps)data); } - } else if (name == "if") { // If special form - return new ASTIf(args); - - } else { // Generic application - if (name.length() == 1) { - switch (name[0]) { - case '+': case '-': case '*': case '/': - case '%': case '&': case '|': case '^': - return new ASTPrimitive(name[0], args); - } - } - return new ASTCall(parseExpression(penv, name), args); - } -} +static ASTPrototype* +parsePrototype(PEnv& penv, const SExp& e, UD) + { return new ASTPrototype(pmap(penv, e.list)); } static AST* -parseExpression(PEnv& penv, const SExp& exp) +parseFn(PEnv& penv, const list<SExp>& c, UD) { - if (exp.type == SExp::LIST) { - return parseCall(penv, exp); - } else if (isalpha(exp.atom[0])) { - return parseSymbol(penv, exp); - } else if (isdigit(exp.atom[0])) { - return parseNumber(penv, exp); - } else { - throw SyntaxError("Illegal atom"); - } + list<SExp>::const_iterator a = c.begin(); ++a; + return new ASTClosure( + parsePrototype(penv, *a++, 0), + parseExpression(penv, *a++)); } @@ -361,6 +348,33 @@ parseExpression(PEnv& penv, const SExp& exp) * Code Generation * ***************************************************************************/ +/// Generic Recursive Environment (stack of key:value dictionaries) +template<typename K, typename V> +struct Env : public list< map<K,V> > { + Env() : list< map<K, V> >(1) {} + void push() { this->push_front(map<K,V>()); } + void push(const map<K,V>& frame) { this->push_front(frame); } + map<K,V>& pop() { + map<K,V>& front = this->front(); + this->pop_front(); + return front; + } + void def(const K& k, const V& v) { + if (this->front().find(k) != this->front().end()) + std::cerr << "WARNING: Redefinition: " << k << endl; + this->front()[k] = v; + } + V* ref(const K& name) { + typename Env::iterator i = this->begin(); + for (; i != this->end(); ++i) { + typename map<K,V>::iterator s = i->find(name); + if (s != i->end()) + return &s->second; + } + return 0; + } +}; + /// Compile-time environment struct CEnv { CEnv(Module* m, const TargetData* target) @@ -381,15 +395,22 @@ struct CEnv { string gensym(const char* base="_") { ostringstream s; s << base << symID++; return s.str(); } - typedef map<const ASTSymbol*, Value*> Vals; - typedef map<const ASTSymbol*, const Type*> Types; + void def(const ASTSymbol* sym, AST* expr) { + types.def(sym, expr->type(*this)); + code.def(sym, expr); + } + typedef Env<const ASTSymbol*, const Type*> Types; + typedef Env<const ASTSymbol*, AST*> Code; + typedef Env<const ASTSymbol*, Value*> Vals; + IRBuilder<> builder; Module* module; ExistingModuleProvider provider; FunctionPassManager fpm; size_t symID; - Vals vals; Types types; + Code code; + Vals vals; }; static void @@ -399,62 +420,71 @@ lambdaLift(CEnv& env, AST* ast) lambdaLift(env, closure->body); closure->lift(env); } else if (ASTCall* call = dynamic_cast<ASTCall*>(ast)) { - FOREACH(vector<AST*>::const_iterator, a, call->args) + FOREACH(vector<AST*>::const_iterator, a, call->code) lambdaLift(env, *a); } } -Value* -ASTNumber::compile(CEnv& cenv) -{ - return ConstantFP::get(APFloat(val)); -} - const Type* ASTSymbol::type(CEnv& cenv) const { - CEnv::Types::const_iterator t = cenv.types.find(this); - if (t != cenv.types.end()) { - return t->second; + const Type** t = cenv.types.ref(this); + if (t) { + return *t; } else { //std::cerr << "WARNING: Untyped symbol: " << name << endl; - return Type::DoubleTy; + return Type::FloatTy; } } Value* ASTSymbol::compile(CEnv& cenv) { - CEnv::Vals::const_iterator v = cenv.vals.find(this); - if (v == cenv.vals.end()) - throw SyntaxError((string("Undefined symbol '") + name + "'").c_str()); - return v->second; + Value*const* v = cenv.vals.ref(this); + if (v) + return *v; + + AST*const* c = cenv.code.ref(this); + if (c) { + Value* v = (*c)->compile(cenv); + cenv.vals.def(this, v); + return v; + } + + throw SyntaxError((string("Undefined symbol '") + name + "'").c_str()); } Value* ASTDefinition::compile(CEnv& cenv) { - if (args.size() != 2) throw SyntaxError("\"def\" takes exactly 2 arguments"); - ASTSymbol* sym = dynamic_cast<ASTSymbol*>(args[0]); + if (code.size() != 3) throw SyntaxError("\"def\" takes exactly 2 arguments"); + const ASTSymbol* sym = dynamic_cast<const ASTSymbol*>(code[1]); if (!sym) throw SyntaxError("Definition name is not a symbol"); - CEnv::Vals::const_iterator v = cenv.vals.find(sym); - if (v != cenv.vals.end()) throw SyntaxError("Symbol redefinition"); - Value* val = args[1]->compile(cenv); - cenv.types[sym] = args[1]->type(cenv); - cenv.vals[sym] = val; + Value* val = code[2]->compile(cenv); + cenv.types.def(sym, code[2]->type(cenv)); + cenv.code.def(sym, code[2]); + cenv.vals.def(sym, val); return val; } Value* ASTCall::compile(CEnv& cenv) { + AST* func = code[0]; + AST** closure = cenv.code.ref((ASTSymbol*)func); + assert(closure); + ASTClosure* c = dynamic_cast<ASTClosure*>(*closure); + assert(c); Function* f = dynamic_cast<Function*>(func->compile(cenv)); if (!f) throw SyntaxError("Call to non-function"); vector<Value*> params; - for (size_t i = 0; i != args.size(); ++i) - params.push_back(args[i]->compile(cenv)); + for (size_t i = 1; i < code.size(); ++i) + params.push_back(code[i]->compile(cenv)); + + for (size_t i = 0; i < c->bindings.size(); ++i) + std::cout << "BINDING: " << c->bindings[i]->name << endl; return cenv.builder.CreateCall(f, params.begin(), params.end(), "calltmp"); } @@ -462,25 +492,20 @@ ASTCall::compile(CEnv& cenv) Value* ASTIf::compile(CEnv& cenv) { - Value* condV = args[0]->compile(cenv); - - // Convert condition to a bool by comparing equal to 0.0. - condV = cenv.builder.CreateFCmpONE( - condV, ConstantFP::get(APFloat(0.0)), "ifcond"); - + Value* condV = code[1]->compile(cenv); Function* parent = cenv.builder.GetInsertBlock()->getParent(); // Create blocks for the then and else cases. // Insert the 'then' block at the end of the function. - BasicBlock* thenBB = BasicBlock::Create("then", parent); - BasicBlock* elseBB = BasicBlock::Create("else"); + BasicBlock* thenBB = BasicBlock::Create("then", parent); + BasicBlock* elseBB = BasicBlock::Create("else"); BasicBlock* mergeBB = BasicBlock::Create("ifcont"); cenv.builder.CreateCondBr(condV, thenBB, elseBB); // Emit then value. cenv.builder.SetInsertPoint(thenBB); - Value* thenV = args[1]->compile(cenv); + Value* thenV = code[2]->compile(cenv); cenv.builder.CreateBr(mergeBB); // compile of 'Then' can change the current block, update thenBB @@ -489,7 +514,7 @@ ASTIf::compile(CEnv& cenv) // Emit else block. parent->getBasicBlockList().push_back(elseBB); cenv.builder.SetInsertPoint(elseBB); - Value* elseV = args[2]->compile(cenv); + Value* elseV = code[3]->compile(cenv); cenv.builder.CreateBr(mergeBB); // compile of 'Else' can change the current block, update elseBB @@ -509,6 +534,34 @@ void ASTClosure::lift(CEnv& cenv) { assert(!func); + + //set<const ASTSymbol*> unbound; + const ASTCall* call = dynamic_cast<const ASTCall*>(body); + if (call) { + std::cout << "LIFT CALL BODY\n"; + } + +#if 0 + Env<const ASTSymbol*, const AST*> paramsEnv; + for (vector<const ASTSymbol*>::const_iterator p = prot->params.begin(); + p != prot->params.end(); ++p) { + //paramsEnv[*p] = NULL; + std::cout << "PARAM: " << (*p)->name << endl; + } +#endif + + cenv.code.push(); + + ASTSymbol* sym = dynamic_cast<ASTSymbol*>(body); + if (sym) { + AST** obj = cenv.code.ref(sym); + if (!obj) { + std::cout << "UNDEFINED SYMBOL BODY\n"; + prot->params.push_back(sym); + bindings.push_back(sym); + } + } + // Write function declaration Function* f = prot->compile(cenv, FunctionType::get(body->type(cenv), prot->argsType(cenv), false), @@ -518,9 +571,9 @@ ASTClosure::lift(CEnv& cenv) // Bind argument values in CEnv vector<Value*> args; - vector<ASTSymbol*>::const_iterator p = prot->params.begin(); + vector<AST*>::const_iterator p = prot->params.begin(); for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p) - cenv.vals[*p] = &*a; + cenv.vals.def(dynamic_cast<ASTSymbol*>(*p), &*a); // Write function body try { @@ -533,6 +586,8 @@ ASTClosure::lift(CEnv& cenv) f->eraseFromParent(); // Error reading body, remove function throw e; } + + cenv.code.pop(); } Value* @@ -565,11 +620,12 @@ ASTPrototype::compile(CEnv& cenv, FunctionType* FT, const std::string& n) throw SyntaxError("Function redefined with mismatched arguments"); } + // Set argument names in generated code Function::arg_iterator a = f->arg_begin(); for (size_t i = 0; i != params.size(); ++a, ++i) { assert(params[i]); - a->setName(params[i]->name); // Set name in generated code - //cenv.vals[params[i]] = a; // Add to environment + ASTSymbol* sym = dynamic_cast<ASTSymbol*>(params[i]); + a->setName(sym ? sym->name : cenv.gensym("_a")); } return f; @@ -605,22 +661,11 @@ ASTFunction::compile(CEnv& cenv, const string& name) Value* ASTPrimitive::compile(CEnv& cenv) { - Instruction::BinaryOps op; - switch (name) { - case '+': op = Instruction::Add; break; - case '-': op = Instruction::Sub; break; - case '*': op = Instruction::Mul; break; - case '/': op = Instruction::FDiv; break; - case '%': op = Instruction::FRem; break; - case '&': op = Instruction::And; break; - case '|': op = Instruction::Or; break; - case '^': op = Instruction::Xor; break; - default: throw SyntaxError("Unknown primitive"); - } - - vector<Value*> params; - FOREACH(vector<AST*>::const_iterator, a, args) - params.push_back((*a)->compile(cenv)); + size_t np = 0; + vector<Value*> params(code.size() - 1); + vector<AST*>::const_iterator a = code.begin(); + for (++a; a != code.end(); ++a) + params[np++] = (*a)->compile(cenv); switch (params.size()) { case 0: @@ -644,9 +689,24 @@ ASTPrimitive::compile(CEnv& cenv) static void repl(CEnv& cenv, ExecutionEngine* engine) { + // Set up our language PEnv penv; + penv.reg(penv.sym("if"), PEnv::Parser(parseIf, 0)); + penv.reg(penv.sym("def"), PEnv::Parser(parseDef, 0)); + penv.reg(penv.sym("fn"), PEnv::Parser(parseFn, 0)); + penv.reg(penv.sym("+"), PEnv::Parser(parsePrim, Instruction::Add)); + penv.reg(penv.sym("-"), PEnv::Parser(parsePrim, Instruction::Sub)); + penv.reg(penv.sym("*"), PEnv::Parser(parsePrim, Instruction::Mul)); + penv.reg(penv.sym("/"), PEnv::Parser(parsePrim, Instruction::FDiv)); + penv.reg(penv.sym("%"), PEnv::Parser(parsePrim, Instruction::FRem)); + penv.reg(penv.sym("&"), PEnv::Parser(parsePrim, Instruction::And)); + penv.reg(penv.sym("|"), PEnv::Parser(parsePrim, Instruction::Or)); + penv.reg(penv.sym("^"), PEnv::Parser(parsePrim, Instruction::Xor)); + cenv.def(penv.sym("true"), new ASTLiteral<bool>(true)); + cenv.def(penv.sym("false"), new ASTLiteral<bool>(false)); + while (1) { - std::cout << "> "; + std::cout << "(=>) "; std::cout.flush(); SExp exp = readExpression(std::cin); if (exp.type == SExp::LIST && exp.list.empty()) @@ -654,12 +714,12 @@ repl(CEnv& cenv, ExecutionEngine* engine) try { AST* ast = parseExpression(penv, exp); - //const Type* type = ast->type(cenv); lambdaLift(cenv, ast); ASTPrototype* proto = new ASTPrototype(); ASTFunction* func = new ASTFunction(proto, ast); - Function* code = func->compile(cenv, cenv.gensym("repl")); + Function* code = func->compile(cenv, cenv.gensym("_repl")); void* fp = engine->getPointerToFunction(code); + code->dump(); double (*f)() = (double (*)())fp; std::cout << f() << " :: "; func->body->type(cenv)->print(std::cout); @@ -690,4 +750,3 @@ main() module->dump(); return 0; } - |