diff options
author | David Robillard <d@drobilla.net> | 2009-01-26 04:39:31 +0000 |
---|---|---|
committer | David Robillard <d@drobilla.net> | 2009-01-26 04:39:31 +0000 |
commit | 022f55e2ab4da12ae45321c7f2cca71b66c417a4 (patch) | |
tree | e3a8fc7d33f63b467dc005eb9fdd749bfb680b0e /ll.cpp | |
parent | 57951dddc871bb8afd681f8205db29fb653b3a58 (diff) | |
download | resp-022f55e2ab4da12ae45321c7f2cca71b66c417a4.tar.gz resp-022f55e2ab4da12ae45321c7f2cca71b66c417a4.tar.bz2 resp-022f55e2ab4da12ae45321c7f2cca71b66c417a4.zip |
Type inference.
git-svn-id: http://svn.drobilla.net/resp/llvm-lisp@15 ad02d1e2-f140-0410-9f75-f8b11f17cedd
Diffstat (limited to 'll.cpp')
-rw-r--r-- | ll.cpp | 608 |
1 files changed, 363 insertions, 245 deletions
@@ -69,25 +69,20 @@ readExpression(std::istream& in) stack<SExp> stk; string tok; -#define APPEND_TOK() \ - if (stk.empty()) return tok; else stk.top().list.push_back(SExp(tok)) +#define APPEND_TOK() { if (stk.empty()) { return tok; } else {\ + stk.top().list.push_back(SExp(tok)); tok = ""; } } while (char ch = in.get()) { switch (ch) { case EOF: return SExp(); case ' ': case '\t': case '\n': - if (tok == "") - continue; - else - APPEND_TOK(); - tok = ""; + if (tok != "") APPEND_TOK(); break; case '"': do { tok.push_back(ch); } while ((ch = in.get()) != '"'); tok.push_back('"'); APPEND_TOK(); - tok = ""; break; case '(': stk.push(SExp()); @@ -95,8 +90,7 @@ readExpression(std::istream& in) case ')': switch (stk.size()) { case 0: - throw SyntaxError("Missing '('"); - break; + throw SyntaxError("Unexpected ')'"); case 1: if (tok != "") stk.top().list.push_back(SExp(tok)); return stk.top(); @@ -122,105 +116,29 @@ readExpression(std::istream& in) } - -/*************************************************************************** - * Environment * - ***************************************************************************/ - -class AST; -class ASTSymbol; - -/// Generic Recursive Environment (stack of key:value dictionaries) -template<typename K, typename V> -struct Env : public list< map<K,V> > { - Env() : list< map<K, V> >(1) {} - void push() { this->push_front(map<K,V>()); } - void push(const map<K,V>& frame) { this->push_front(frame); } - map<K,V>& pop() { - map<K,V>& front = this->front(); - this->pop_front(); - return front; - } - void def(const K& k, const V& v) { - if (this->front().find(k) != this->front().end()) - throw SyntaxError("Redefinition"); - this->front()[k] = v; - } - V* ref(const K& name) { - typename Env::iterator i = this->begin(); - for (; i != this->end(); ++i) { - typename map<K,V>::iterator s = i->find(name); - if (s != i->end()) - return &s->second; - } - return 0; - } -}; - -class PEnv; - -/// Compile-time environment -struct CEnv { - CEnv(PEnv& p, Module* m, const TargetData* target) - : penv(p), module(m), emp(module), fpm(&emp), symID(0), tID(0) - { - // Set up the optimizer pipeline. - // Register info about how the target lays out data structures. - fpm.add(new TargetData(*target)); - // Do simple "peephole" and bit-twiddling optimizations. - fpm.add(createInstructionCombiningPass()); - // Reassociate expressions. - fpm.add(createReassociatePass()); - // Eliminate Common SubExpressions. - fpm.add(createGVNPass()); - // Simplify control flow graph (delete unreachable blocks, etc). - fpm.add(createCFGSimplificationPass()); - } - string gensym(const char* base="_") { - ostringstream s; s << base << symID++; return s.str(); - } - typedef Env<const AST*, AST*> Code; - typedef Env<const ASTSymbol*, Value*> Vals; - - PEnv& penv; - IRBuilder<> builder; - Module* module; - ExistingModuleProvider emp; - FunctionPassManager fpm; - unsigned symID; - unsigned tID; - Code code; - Vals vals; -}; - -/// LLVM Operation -struct Op { Op(int o=0, int a=0) : op(o), arg(a) {} int op; int arg; }; - - - /*************************************************************************** * Abstract Syntax Tree * ***************************************************************************/ -struct AType; -struct TypeError : public Error { TypeError (const char* m) : Error(m) {} }; - -struct CEnv; ///< Compile Time Environment +struct TEnv; ///< Type-Time Environment +struct CEnv; ///< Compile-Time Environment +struct AType; ///< Abstract Type /// Base class for all AST nodes struct AST { virtual ~AST() {} - virtual string str(CEnv& cenv) const = 0; - virtual AType* type(CEnv& cenv) = 0; - virtual Value* compile(CEnv& cenv) = 0; - virtual void lift(CEnv& cenv) {} + virtual bool operator==(const AST& rhs) const = 0; + virtual string str() const = 0; + virtual void constrain(TEnv& tenv) const {} + virtual void lift(CEnv& cenv) {} + virtual Value* compile(CEnv& cenv) = 0; }; /// Symbol, e.g. "a" struct ASTSymbol : public AST { ASTSymbol(const string& s) : cppstr(s) {} - std::string str(CEnv&) const { return cppstr; } - AType* type(CEnv& cenv); + bool operator==(const AST& rhs) const { return this == &rhs; } + string str() const { return cppstr; } Value* compile(CEnv& cenv); private: const string cppstr; @@ -229,33 +147,47 @@ private: /// Tuple (heterogeneous sequence of known length), e.g. "(a b c)" struct ASTTuple : public AST { ASTTuple(vector<AST*> t=vector<AST*>()) : tup(t) {} - string str(CEnv& cenv) const { + string str() const { string ret = "("; for (size_t i = 0; i != tup.size(); ++i) - ret += tup[i]->str(cenv) + ((i != tup.size() - 1) ? " " : ""); + ret += tup[i]->str() + ((i != tup.size() - 1) ? " " : ""); ret.append(")"); return ret; } + bool operator==(const AST& rhs) const { + const ASTTuple* rhst = dynamic_cast<const ASTTuple*>(&rhs); + if (!rhst) return false; + if (rhst->tup.size() != tup.size()) return false; + for (size_t i = 0; i < tup.size(); ++i) + if (tup[i] != rhst->tup[i]) + return false; + return true; + } + bool operator!=(const ASTTuple& t) const { return ! operator==(t); } void lift(CEnv& cenv) { FOREACH(vector<AST*>::iterator, t, tup) (*t)->lift(cenv); } - AType* type(CEnv& cenv); + void constrain(TEnv& tenv) const; Value* compile(CEnv& cenv) { return NULL; } vector<AST*> tup; }; -/// TExpr ::= (TName TExpr*) | ?Num +/// Type Expression ::= (TName TExpr*) | ?Num struct AType : public ASTTuple { + AType(const vector<AST*>& t) : ASTTuple(t), var(false), ctype(0) {} AType(unsigned i) : var(true), ctype(0), id(id) {} - AType(const string& n, const Type* t) : var(false), ctype(t) { - tup.push_back(new ASTSymbol(n)); + AType(ASTSymbol* n, const Type* t) : var(false), ctype(t) { + tup.push_back(n); } - AType(const vector<AST*>& t) : ASTTuple(t), var(false), ctype(0) {} - inline bool operator==(const AType& t) const { return tup[0] == t.tup[0]; } - inline bool operator!=(const AType& t) const { return tup[0] != t.tup[0]; } - string str(CEnv& cenv) const { return var ? "?" : ASTTuple::str(cenv); } - AType* type(CEnv& cenv) { return this; } + string str() const { + if (var) { + ostringstream s; s << "?" << id; return s.str(); + } else { + return ASTTuple::str(); + } + } + void constrain(TEnv& tenv) const {} Value* compile(CEnv& cenv) { return NULL; } bool var; const Type* ctype; @@ -266,103 +198,49 @@ struct AType : public ASTTuple { template<typename VT> struct ASTLiteral : public AST { ASTLiteral(VT v) : val(v) {} - string str(CEnv& env) const { return "(Literal)"; } - AType* type(CEnv& cenv); + bool operator==(const AST& rhs) const { + const ASTLiteral<VT>* rhsl = dynamic_cast<const ASTLiteral<VT>*>(&rhs); + return rhsl && val == rhsl->val; + } + string str() const { ostringstream s; s << val; return s.str(); } + void constrain(TEnv& tenv) const; Value* compile(CEnv& cenv); const VT val; }; -#define LITERAL(CT, VT, NAME, COMPILED) \ -template<> string \ -ASTLiteral<CT>::str(CEnv& cenv) const { return NAME; } \ -template<> AType* \ -ASTLiteral<CT>::type(CEnv& cenv) { return new AType(NAME, VT); } \ -template<> Value* \ -ASTLiteral<CT>::compile(CEnv& cenv) { return (COMPILED); } - -/// Literal template instantiations -LITERAL(int32_t, Type::Int32Ty, "Int", ConstantInt::get(Type::Int32Ty, val, true)); -LITERAL(float, Type::FloatTy, "Float", ConstantFP::get(Type::FloatTy, val)); -LITERAL(bool, Type::Int1Ty, "Bool", ConstantInt::get(Type::Int1Ty, val, false)); - -typedef Op UD; // User Data argument for parse functions - -// Parse Time Environment (symbol table) -struct PEnv : private map<const string, ASTSymbol*> { - typedef AST* (*PF)(PEnv&, const list<SExp>&, UD); // Parse Function - struct Parser { Parser(PF f, UD d) : pf(f), ud(d) {} PF pf; UD ud; }; - map<const ASTSymbol*, Parser> parsers; - void reg(const ASTSymbol* s, const Parser& p) { - parsers.insert(make_pair(s, p)); - } - const Parser* parser(const ASTSymbol* s) const { - map<const ASTSymbol*, Parser>::const_iterator i = parsers.find(s); - return (i != parsers.end()) ? &i->second : NULL; - } - ASTSymbol* sym(const string& s) { - const const_iterator i = find(s); - return ((i != end()) - ? i->second - : insert(make_pair(s, new ASTSymbol(s))).first->second); - } -}; /// Closure (first-class function with captured lexical bindings) struct ASTClosure : public AST { ASTClosure(ASTTuple* p, AST* b) : prot(p), body(b), func(0) {} - string str(CEnv& env) const { return "(fn)"; } - AType* type(CEnv& cenv) { - vector<AST*> texp(3); - texp[0] = cenv.penv.sym("Fn"); - texp[1] = prot; - texp[2] = body; - return new AType(texp); - } + bool operator==(const AST& rhs) const { return this == &rhs; } + string str() const { return "(fn)"; } + void constrain(TEnv& tenv) const; + void lift(CEnv& cenv); Value* compile(CEnv& cenv); - void lift(CEnv& cenv); ASTTuple* const prot; AST* const body; - vector<const ASTSymbol*> bindings; private: Function* func; }; - + /// Function call/application, e.g. "(func arg1 arg2)" struct ASTCall : public ASTTuple { ASTCall(const vector<AST*>& t) : ASTTuple(t) {} - AType* type(CEnv& cenv) { - AST* callee = tup[0]; - ASTSymbol* sym = dynamic_cast<ASTSymbol*>(tup[0]); - if (sym) { - AST** val = cenv.code.ref(sym); - if (val) - callee = *val; - } - ASTClosure* c = dynamic_cast<ASTClosure*>(callee); - if (!c) throw TypeError("Call to non-closure"); - return c->body->type(cenv); - } - void lift(CEnv& cenv); + void constrain(TEnv& tenv) const; + void lift(CEnv& cenv); Value* compile(CEnv& cenv); }; /// Definition special form, e.g. "(def x 2)" or "(def (next y) (+ y 1))" struct ASTDefinition : public ASTCall { ASTDefinition(const vector<AST*>& c) : ASTCall(c) {} - AType* type(CEnv& cenv) { return tup[2]->type(cenv); } + void constrain(TEnv& tenv) const; Value* compile(CEnv& cenv); }; /// Conditional special form, e.g. "(if cond thenexp elseexp)" struct ASTIf : public ASTCall { ASTIf(const vector<AST*>& c) : ASTCall(c) {} - AType* type(CEnv& cenv) { - AType* cT = tup[1]->type(cenv); - AType* tT = tup[2]->type(cenv); - AType* eT = tup[3]->type(cenv); - if (cT->ctype != Type::Int1Ty) throw TypeError("If condition is not a boolean"); - if (*tT != *eT) throw TypeError("If branches have different types"); - return tT; - } + void constrain(TEnv& tenv) const; Value* compile(CEnv& cenv); }; @@ -370,26 +248,42 @@ struct ASTIf : public ASTCall { struct ASTPrimitive : public ASTCall { ASTPrimitive(const vector<AST*>& c, unsigned o, unsigned a=0) : ASTCall(c), op(o), arg(a) {} - AType* type(CEnv& cenv); + void constrain(TEnv& tenv) const; Value* compile(CEnv& cenv); unsigned op; unsigned arg; }; -AType* -ASTTuple::type(CEnv& cenv) -{ - vector<AST*> texp; - FOREACH(vector<AST*>::const_iterator, p, tup) - texp.push_back((*p)->type(cenv)); - return new AType(texp); -} - /*************************************************************************** * Parser - S-Expressions (SExp) -> AST Nodes (AST) * ***************************************************************************/ +/// LLVM Operation +struct Op { Op(int o=0, int a=0) : op(o), arg(a) {} int op; int arg; }; + +typedef Op UD; // User Data argument for parse functions + +// Parse Time Environment (symbol table) +struct PEnv : private map<const string, ASTSymbol*> { + typedef AST* (*PF)(PEnv&, const list<SExp>&, UD); // Parse Function + struct Parser { Parser(PF f, UD d) : pf(f), ud(d) {} PF pf; UD ud; }; + map<const ASTSymbol*, Parser> parsers; + void reg(const ASTSymbol* s, const Parser& p) { + parsers.insert(make_pair(s, p)); + } + const Parser* parser(const ASTSymbol* s) const { + map<const ASTSymbol*, Parser>::const_iterator i = parsers.find(s); + return (i != parsers.end()) ? &i->second : NULL; + } + ASTSymbol* sym(const string& s) { + const const_iterator i = find(s); + return ((i != end()) + ? i->second + : insert(make_pair(s, new ASTSymbol(s))).first->second); + } +}; + /// The fundamental parser method static AST* parseExpression(PEnv& penv, const SExp& exp) @@ -445,39 +339,286 @@ static AST* parsePrim(PEnv& penv, const list<SExp>& c, UD data) { return new ASTPrimitive(pmap(penv, c), data.op, data.arg); } -static ASTTuple* -parsePrototype(PEnv& penv, const SExp& e, UD) - { return new ASTTuple(pmap(penv, e.list)); } - static AST* parseFn(PEnv& penv, const list<SExp>& c, UD) { list<SExp>::const_iterator a = c.begin(); ++a; return new ASTClosure( - parsePrototype(penv, *a++, UD()), + new ASTTuple(pmap(penv, (*a++).list)), parseExpression(penv, *a++)); } /*************************************************************************** + * Lexical Environment * + ***************************************************************************/ + +template<typename K, typename V> +struct Env : public list< map<K,V> > { + typedef map<K,V> Frame; + Env() : list<Frame>(1) {} + void push_front() { list<Frame>::push_front(Frame()); } + void def(const K& k, const V& v) { + if (this->front().find(k) != this->front().end()) + throw SyntaxError("Redefinition"); + this->front()[k] = v; + } + V* ref(const K& name) { + typename Frame::iterator s; + for (typename Env::iterator i = this->begin(); i != this->end(); ++i) + if ((s = i ->find(name)) != i->end()) + return &s->second; + return 0; + } +}; + + +/*************************************************************************** + * Typing * + ***************************************************************************/ + +struct TypeError : public Error { TypeError (const char* m) : Error(m) {} }; + +/// Type-Time Environment +struct TEnv { + TEnv(PEnv& p) : penv(p), varID(0) {} + typedef map<const AST*, AType*> Types; + typedef multimap<const AST*, AType*> Constraints; + AType* var() { return new AType(varID++); } + AType* type(const AST* ast) { + Types::iterator t = types.find(ast); + if (t != types.end()) + return t->second; + else + return (types[ast] = var()); + } + AType* named(const string& name) const { + Types::const_iterator i = namedTypes.find(penv.sym(name)); + if (i == namedTypes.end()) throw TypeError("Unknown named type"); + return i->second; + } + void name(const string& name, const Type* type) { + ASTSymbol* sym = penv.sym(name); + namedTypes[sym] = new AType(penv.sym(name), type); + } + void constrain(const AST* ast, AType* type) { + constraints.insert(make_pair(ast, type)); + } + void unify(); + + PEnv& penv; + Types types; + Types namedTypes; + Constraints constraints; + unsigned varID; +}; + +#define OP_IS_A(o, t) ((o) >= t ## Begin && (o) < t ## End) + +void +ASTTuple::constrain(TEnv& tenv) const +{ + vector<AST*> texp; + FOREACH(vector<AST*>::const_iterator, p, tup) { + (*p)->constrain(tenv); + AType* tvar = tenv.var(); + texp.push_back(tvar); + tenv.constrain(tvar, tenv.type(*p)); + } + tenv.constrain(this, new AType(texp)); +} + +void +ASTClosure::constrain(TEnv& tenv) const +{ + prot->constrain(tenv); + body->constrain(tenv); + vector<AST*> texp(3); + texp[0] = tenv.penv.sym("Fn"); + texp[1] = prot; + texp[2] = body; + AType* tvar = tenv.var(); + tenv.constrain(texp[2], tvar); + tenv.constrain(this, tvar); +} + +void +ASTCall::constrain(TEnv& tenv) const +{ + ASTTuple::constrain(tenv); +#if 0 + AST* callee = tup[0]; + ASTSymbol* sym = dynamic_cast<ASTSymbol*>(tup[0]); + if (sym) { + AST** val = tenv.code.ref(sym); + if (val) + callee = *val; + } + ASTClosure* c = dynamic_cast<ASTClosure*>(callee); + if (!c) throw TypeError("Call to non-closure"); + tenv.contraints[this] = c->body->type(tenv); +#endif +} + +void +ASTDefinition::constrain(TEnv& tenv) const +{ + FOREACH(vector<AST*>::const_iterator, p, tup) + (*p)->constrain(tenv); + AType* tvar = tenv.var(); + tenv.constrain(tup[1], tvar); + tenv.constrain(tup[2], tvar); + tenv.constrain(this, tvar); +} + +void +ASTIf::constrain(TEnv& tenv) const +{ + FOREACH(vector<AST*>::const_iterator, p, tup) + (*p)->constrain(tenv); + AType* tvar = tenv.var(); + tenv.constrain(tup[1], tenv.named("Bool")); + tenv.constrain(tup[2], tvar); + tenv.constrain(tup[3], tvar); + tenv.constrain(this, tvar); +} + +void +ASTPrimitive::constrain(TEnv& tenv) const +{ + FOREACH(vector<AST*>::const_iterator, p, tup) + (*p)->constrain(tenv); + if (OP_IS_A(op, Instruction::BinaryOps)) { + if (tup.size() <= 1) throw SyntaxError("Primitive call with 0 args"); + AType* tvar = tenv.var(); + for (size_t i = 1; i < tup.size(); ++i) + tenv.constrain(tup[i], tvar); + tenv.constrain(this, tvar); + } else if (op == Instruction::ICmp) { + if (tup.size() != 3) throw SyntaxError("Comparison call with != 2 args"); + tenv.constrain(tup[1], tenv.type(tup[2])); + tenv.constrain(this, tenv.named("Bool")); + } else { + throw TypeError("Unknown primitive"); + } +} + +void +TEnv::unify() +{ + typedef map<const AType*, AType*> Substitutions; + + bool progress = false; + do { + progress = false; + //std::cout << "========" << endl; + Substitutions subst; + for (Constraints::iterator c = constraints.begin(); c != constraints.end();) { + Constraints::iterator next = c; + ++next; + const AST* o = c->first; + AType* t = c->second; + //std::cout << "Constraint: " << o->str() << " = " << t->str() << endl; + if (t->var) { + Types::iterator ot = types.find(o); + if (ot != types.end()) + subst[t] = ot->second; + } else { + Types::iterator ot = types.find(o); + if (ot == types.end()) { + //std::cout << "Resolve: " << o->str() << endl; + types.insert(make_pair(o, t)); + constraints.erase(c); + } + } + c = next; + } + + for (Substitutions::iterator s = subst.begin(); s != subst.end(); ++s) { + for (Constraints::iterator c = constraints.begin(); c != constraints.end(); ++c) { + if (c->second == s->first) { + //std::cout << c->second->str() << " => " << s->second->str() << endl; + c->second = s->second; + progress = true; + } + } + } + + } while (progress); + + //std::cout << "======== Done unification" << endl; + + constraints.clear(); +} + + +/*************************************************************************** * Code Generation * ***************************************************************************/ struct CompileError : public Error { CompileError(const char* m) : Error(m) {} }; +class PEnv; + +/// Compile-Time Environment +struct CEnv { + CEnv(PEnv& p, Module* m, const TargetData* target) + : penv(p), tenv(p), module(m), emp(module), fpm(&emp), symID(0) + { + // Set up the optimizer pipeline. + // Register info about how the target lays out data structures. + fpm.add(new TargetData(*target)); + // Do simple "peephole" and bit-twiddling optimizations. + fpm.add(createInstructionCombiningPass()); + // Reassociate expressions. + fpm.add(createReassociatePass()); + // Eliminate Common SubExpressions. + fpm.add(createGVNPass()); + // Simplify control flow graph (delete unreachable blocks, etc). + fpm.add(createCFGSimplificationPass()); + } + string gensym(const char* base="_") { + ostringstream s; s << base << symID++; return s.str(); + } + typedef Env<const AST*, AST*> Code; + typedef Env<const ASTSymbol*, Value*> Vals; + + PEnv& penv; + TEnv tenv; + IRBuilder<> builder; + Module* module; + ExistingModuleProvider emp; + FunctionPassManager fpm; + unsigned symID; + Code code; + Vals vals; +}; + +#define LITERAL(CT, VT, NAME, COMPILED) \ +template<> Value* \ +ASTLiteral<CT>::compile(CEnv& cenv) { return (COMPILED); } \ +template<> void \ +ASTLiteral<CT>::constrain(TEnv& tenv) const { tenv.constrain(this, tenv.named(NAME)); } + +/// Literal template instantiations +LITERAL(int32_t, Type::Int32Ty, "Int", ConstantInt::get(Type::Int32Ty, val, true)); +LITERAL(float, Type::FloatTy, "Float", ConstantFP::get(Type::FloatTy, val)); +LITERAL(bool, Type::Int1Ty, "Bool", ConstantInt::get(Type::Int1Ty, val, false)); + static Function* compileFunction(CEnv& cenv, const std::string& name, ASTTuple& prot, const Type* retT) { Function::LinkageTypes linkage = Function::ExternalLinkage; - const vector<AST*>& texp = prot.type(cenv)->tup; + const vector<AST*>& texp = cenv.tenv.type(&prot)->tup; vector<const Type*> cprot; for (size_t i = 0; i < texp.size(); ++i) { - const Type* t = texp[i]->type(cenv)->ctype; + const Type* t = cenv.tenv.type(texp[i])->ctype; if (!t) throw CompileError("Function prototype contains NULL"); cprot.push_back(t); } + if (!retT) throw CompileError("Function return value type is NULL"); FunctionType* fT = FunctionType::get(retT, cprot, false); Function* f = Function::Create(fT, linkage, name, cenv.module); @@ -489,17 +630,11 @@ compileFunction(CEnv& cenv, const std::string& name, ASTTuple& prot, const Type* // Set argument names in generated code Function::arg_iterator a = f->arg_begin(); for (size_t i = 0; i != prot.tup.size(); ++a, ++i) - a->setName(prot.tup[i]->str(cenv)); + a->setName(prot.tup[i]->str()); return f; } -AType* -ASTSymbol::type(CEnv& cenv) -{ - AST** t = cenv.code.ref(this); - return t ? (*t)->type(cenv) : new AType(cenv.tID++); -} Value* ASTSymbol::compile(CEnv& cenv) @@ -552,7 +687,7 @@ ASTCall::lift(CEnv& cenv) tup[i]->lift(cenv); // Extend environment with bound and typed parameters - cenv.code.push(); + cenv.code.push_front(); if (c->prot->tup.size() != tup.size() - 1) throw CompileError("Call to closure with mismatched arguments"); @@ -562,7 +697,7 @@ ASTCall::lift(CEnv& cenv) // Lift callee closure tup[0]->lift(cenv); - cenv.code.pop(); + cenv.code.pop_front(); } Value* @@ -576,9 +711,9 @@ ASTCall::compile(CEnv& cenv) if (!c) throw CompileError("Call to non-closure"); Value* v = c->compile(cenv); - if (!v) throw SyntaxError("Callee failed to compile"); + if (!v) throw CompileError("Callee failed to compile"); Function* f = dynamic_cast<Function*>(c->compile(cenv)); - if (!f) throw SyntaxError("Callee compiled to non-function"); + if (!f) throw CompileError("Callee compiled to non-function"); vector<Value*> params; for (size_t i = 1; i < tup.size(); ++i) @@ -603,25 +738,23 @@ ASTIf::compile(CEnv& cenv) // Emit then value. cenv.builder.SetInsertPoint(thenBB); - Value* thenV = tup[2]->compile(cenv); + Value* thenV = tup[2]->compile(cenv); // Can change current block, so... cenv.builder.CreateBr(mergeBB); - // compile of 'Then' can change the current block, update thenBB - thenBB = cenv.builder.GetInsertBlock(); + thenBB = cenv.builder.GetInsertBlock(); // ... update thenBB afterwards // Emit else block. parent->getBasicBlockList().push_back(elseBB); cenv.builder.SetInsertPoint(elseBB); - Value* elseV = tup[3]->compile(cenv); + Value* elseV = tup[3]->compile(cenv); // Can change current block, so... cenv.builder.CreateBr(mergeBB); - // compile of 'Else' can change the current block, update elseBB - elseBB = cenv.builder.GetInsertBlock(); + elseBB = cenv.builder.GetInsertBlock(); // ... update elseBB afterwards // Emit merge block. parent->getBasicBlockList().push_back(mergeBB); cenv.builder.SetInsertPoint(mergeBB); - PHINode* pn = cenv.builder.CreatePHI(type(cenv)->ctype, "iftmp"); + PHINode* pn = cenv.builder.CreatePHI(cenv.tenv.type(this)->ctype, "iftmp"); pn->addIncoming(thenV, thenBB); pn->addIncoming(elseV, elseBB); @@ -635,23 +768,16 @@ ASTClosure::lift(CEnv& cenv) // Can't lift a closure with variable types (lift later when called) for (size_t i = 0; i < prot->tup.size(); ++i) - if (prot->tup[i]->type(cenv)->var) + if (cenv.tenv.type(prot->tup[i])->var) return; - cenv.code.push(); + if (cenv.tenv.type(body)->var) + return; + + cenv.code.push_front(); - ASTSymbol* sym = dynamic_cast<ASTSymbol*>(body); - if (sym) { - AST** obj = cenv.code.ref(sym); - if (!obj) { - std::cout << "UNDEFINED SYMBOL BODY\n"; - prot->tup.push_back(sym); - bindings.push_back(sym); - } - } - // Write function declaration - Function* f = compileFunction(cenv, cenv.gensym("_fn"), *prot, body->type(cenv)->ctype); + Function* f = compileFunction(cenv, cenv.gensym("_fn"), *prot, cenv.tenv.type(body)->ctype); BasicBlock* bb = BasicBlock::Create("entry", f); cenv.builder.SetInsertPoint(bb); @@ -673,7 +799,7 @@ ASTClosure::lift(CEnv& cenv) throw e; } - cenv.code.pop(); + cenv.code.pop_front(); } Value* @@ -683,22 +809,6 @@ ASTClosure::compile(CEnv& cenv) return func; } -#define OP_IS_A(o, t) ((o) >= t ## Begin && (o) < t ## End) - -AType* -ASTPrimitive::type(CEnv& cenv) -{ - if (OP_IS_A(op, Instruction::BinaryOps)) { - if (tup.size() <= 1) throw SyntaxError("Primitive call with 0 args"); - return tup[1]->type(cenv); - } else if (op == Instruction::ICmp) { - if (tup.size() != 3) throw SyntaxError("Comparison call with != 2 args"); - return new AType("Bool", Type::Int1Ty); - } else { - throw CompileError("Unknown primitive"); - } -} - Value* ASTPrimitive::compile(CEnv& cenv) { @@ -723,7 +833,7 @@ ASTPrimitive::compile(CEnv& cenv) return val; } } else if (op == Instruction::ICmp) { - bool isInt = tup[1]->type(cenv)->str(cenv) == "(Int)"; + bool isInt = cenv.tenv.type(tup[1])->str() == "(Int)"; if (isInt) { return cenv.builder.CreateICmp((CmpInst::Predicate)arg, a, b); } else { @@ -778,6 +888,10 @@ main() cenv.code.def(penv.sym("true"), new ASTLiteral<bool>(true)); cenv.code.def(penv.sym("false"), new ASTLiteral<bool>(false)); + cenv.tenv.name("Bool", Type::Int1Ty); + cenv.tenv.name("Int", Type::Int32Ty); + cenv.tenv.name("Float", Type::FloatTy); + while (1) { std::cout << "(=>) "; std::cout.flush(); @@ -786,9 +900,13 @@ main() break; try { - AST* body = parseExpression(penv, exp); + AST* body = parseExpression(penv, exp); + + body->constrain(cenv.tenv); + cenv.tenv.unify(); + ASTTuple* prot = new ASTTuple(); - AType* bodyT = body->type(cenv); + AType* bodyT = cenv.tenv.type(body); if (!bodyT) throw TypeError("REPL call to untyped body"); if (bodyT->var) throw TypeError("REPL call to variable typed body"); @@ -825,7 +943,7 @@ main() Value* val = body->compile(cenv); std::cout << val; } - std::cout << " : " << body->type(cenv)->str(cenv) << endl; + std::cout << " : " << cenv.tenv.type(body)->str() << endl; } catch (Error e) { std::cerr << "Error: " << e.what() << endl; |