aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDavid Robillard <d@drobilla.net>2009-01-26 04:39:31 +0000
committerDavid Robillard <d@drobilla.net>2009-01-26 04:39:31 +0000
commit022f55e2ab4da12ae45321c7f2cca71b66c417a4 (patch)
treee3a8fc7d33f63b467dc005eb9fdd749bfb680b0e
parent57951dddc871bb8afd681f8205db29fb653b3a58 (diff)
downloadresp-022f55e2ab4da12ae45321c7f2cca71b66c417a4.tar.gz
resp-022f55e2ab4da12ae45321c7f2cca71b66c417a4.tar.bz2
resp-022f55e2ab4da12ae45321c7f2cca71b66c417a4.zip
Type inference.
git-svn-id: http://svn.drobilla.net/resp/llvm-lisp@15 ad02d1e2-f140-0410-9f75-f8b11f17cedd
-rw-r--r--ll.cpp608
1 files changed, 363 insertions, 245 deletions
diff --git a/ll.cpp b/ll.cpp
index a9e80c5..cc971a5 100644
--- a/ll.cpp
+++ b/ll.cpp
@@ -69,25 +69,20 @@ readExpression(std::istream& in)
stack<SExp> stk;
string tok;
-#define APPEND_TOK() \
- if (stk.empty()) return tok; else stk.top().list.push_back(SExp(tok))
+#define APPEND_TOK() { if (stk.empty()) { return tok; } else {\
+ stk.top().list.push_back(SExp(tok)); tok = ""; } }
while (char ch = in.get()) {
switch (ch) {
case EOF:
return SExp();
case ' ': case '\t': case '\n':
- if (tok == "")
- continue;
- else
- APPEND_TOK();
- tok = "";
+ if (tok != "") APPEND_TOK();
break;
case '"':
do { tok.push_back(ch); } while ((ch = in.get()) != '"');
tok.push_back('"');
APPEND_TOK();
- tok = "";
break;
case '(':
stk.push(SExp());
@@ -95,8 +90,7 @@ readExpression(std::istream& in)
case ')':
switch (stk.size()) {
case 0:
- throw SyntaxError("Missing '('");
- break;
+ throw SyntaxError("Unexpected ')'");
case 1:
if (tok != "") stk.top().list.push_back(SExp(tok));
return stk.top();
@@ -122,105 +116,29 @@ readExpression(std::istream& in)
}
-
-/***************************************************************************
- * Environment *
- ***************************************************************************/
-
-class AST;
-class ASTSymbol;
-
-/// Generic Recursive Environment (stack of key:value dictionaries)
-template<typename K, typename V>
-struct Env : public list< map<K,V> > {
- Env() : list< map<K, V> >(1) {}
- void push() { this->push_front(map<K,V>()); }
- void push(const map<K,V>& frame) { this->push_front(frame); }
- map<K,V>& pop() {
- map<K,V>& front = this->front();
- this->pop_front();
- return front;
- }
- void def(const K& k, const V& v) {
- if (this->front().find(k) != this->front().end())
- throw SyntaxError("Redefinition");
- this->front()[k] = v;
- }
- V* ref(const K& name) {
- typename Env::iterator i = this->begin();
- for (; i != this->end(); ++i) {
- typename map<K,V>::iterator s = i->find(name);
- if (s != i->end())
- return &s->second;
- }
- return 0;
- }
-};
-
-class PEnv;
-
-/// Compile-time environment
-struct CEnv {
- CEnv(PEnv& p, Module* m, const TargetData* target)
- : penv(p), module(m), emp(module), fpm(&emp), symID(0), tID(0)
- {
- // Set up the optimizer pipeline.
- // Register info about how the target lays out data structures.
- fpm.add(new TargetData(*target));
- // Do simple "peephole" and bit-twiddling optimizations.
- fpm.add(createInstructionCombiningPass());
- // Reassociate expressions.
- fpm.add(createReassociatePass());
- // Eliminate Common SubExpressions.
- fpm.add(createGVNPass());
- // Simplify control flow graph (delete unreachable blocks, etc).
- fpm.add(createCFGSimplificationPass());
- }
- string gensym(const char* base="_") {
- ostringstream s; s << base << symID++; return s.str();
- }
- typedef Env<const AST*, AST*> Code;
- typedef Env<const ASTSymbol*, Value*> Vals;
-
- PEnv& penv;
- IRBuilder<> builder;
- Module* module;
- ExistingModuleProvider emp;
- FunctionPassManager fpm;
- unsigned symID;
- unsigned tID;
- Code code;
- Vals vals;
-};
-
-/// LLVM Operation
-struct Op { Op(int o=0, int a=0) : op(o), arg(a) {} int op; int arg; };
-
-
-
/***************************************************************************
* Abstract Syntax Tree *
***************************************************************************/
-struct AType;
-struct TypeError : public Error { TypeError (const char* m) : Error(m) {} };
-
-struct CEnv; ///< Compile Time Environment
+struct TEnv; ///< Type-Time Environment
+struct CEnv; ///< Compile-Time Environment
+struct AType; ///< Abstract Type
/// Base class for all AST nodes
struct AST {
virtual ~AST() {}
- virtual string str(CEnv& cenv) const = 0;
- virtual AType* type(CEnv& cenv) = 0;
- virtual Value* compile(CEnv& cenv) = 0;
- virtual void lift(CEnv& cenv) {}
+ virtual bool operator==(const AST& rhs) const = 0;
+ virtual string str() const = 0;
+ virtual void constrain(TEnv& tenv) const {}
+ virtual void lift(CEnv& cenv) {}
+ virtual Value* compile(CEnv& cenv) = 0;
};
/// Symbol, e.g. "a"
struct ASTSymbol : public AST {
ASTSymbol(const string& s) : cppstr(s) {}
- std::string str(CEnv&) const { return cppstr; }
- AType* type(CEnv& cenv);
+ bool operator==(const AST& rhs) const { return this == &rhs; }
+ string str() const { return cppstr; }
Value* compile(CEnv& cenv);
private:
const string cppstr;
@@ -229,33 +147,47 @@ private:
/// Tuple (heterogeneous sequence of known length), e.g. "(a b c)"
struct ASTTuple : public AST {
ASTTuple(vector<AST*> t=vector<AST*>()) : tup(t) {}
- string str(CEnv& cenv) const {
+ string str() const {
string ret = "(";
for (size_t i = 0; i != tup.size(); ++i)
- ret += tup[i]->str(cenv) + ((i != tup.size() - 1) ? " " : "");
+ ret += tup[i]->str() + ((i != tup.size() - 1) ? " " : "");
ret.append(")");
return ret;
}
+ bool operator==(const AST& rhs) const {
+ const ASTTuple* rhst = dynamic_cast<const ASTTuple*>(&rhs);
+ if (!rhst) return false;
+ if (rhst->tup.size() != tup.size()) return false;
+ for (size_t i = 0; i < tup.size(); ++i)
+ if (tup[i] != rhst->tup[i])
+ return false;
+ return true;
+ }
+ bool operator!=(const ASTTuple& t) const { return ! operator==(t); }
void lift(CEnv& cenv) {
FOREACH(vector<AST*>::iterator, t, tup)
(*t)->lift(cenv);
}
- AType* type(CEnv& cenv);
+ void constrain(TEnv& tenv) const;
Value* compile(CEnv& cenv) { return NULL; }
vector<AST*> tup;
};
-/// TExpr ::= (TName TExpr*) | ?Num
+/// Type Expression ::= (TName TExpr*) | ?Num
struct AType : public ASTTuple {
+ AType(const vector<AST*>& t) : ASTTuple(t), var(false), ctype(0) {}
AType(unsigned i) : var(true), ctype(0), id(id) {}
- AType(const string& n, const Type* t) : var(false), ctype(t) {
- tup.push_back(new ASTSymbol(n));
+ AType(ASTSymbol* n, const Type* t) : var(false), ctype(t) {
+ tup.push_back(n);
}
- AType(const vector<AST*>& t) : ASTTuple(t), var(false), ctype(0) {}
- inline bool operator==(const AType& t) const { return tup[0] == t.tup[0]; }
- inline bool operator!=(const AType& t) const { return tup[0] != t.tup[0]; }
- string str(CEnv& cenv) const { return var ? "?" : ASTTuple::str(cenv); }
- AType* type(CEnv& cenv) { return this; }
+ string str() const {
+ if (var) {
+ ostringstream s; s << "?" << id; return s.str();
+ } else {
+ return ASTTuple::str();
+ }
+ }
+ void constrain(TEnv& tenv) const {}
Value* compile(CEnv& cenv) { return NULL; }
bool var;
const Type* ctype;
@@ -266,103 +198,49 @@ struct AType : public ASTTuple {
template<typename VT>
struct ASTLiteral : public AST {
ASTLiteral(VT v) : val(v) {}
- string str(CEnv& env) const { return "(Literal)"; }
- AType* type(CEnv& cenv);
+ bool operator==(const AST& rhs) const {
+ const ASTLiteral<VT>* rhsl = dynamic_cast<const ASTLiteral<VT>*>(&rhs);
+ return rhsl && val == rhsl->val;
+ }
+ string str() const { ostringstream s; s << val; return s.str(); }
+ void constrain(TEnv& tenv) const;
Value* compile(CEnv& cenv);
const VT val;
};
-#define LITERAL(CT, VT, NAME, COMPILED) \
-template<> string \
-ASTLiteral<CT>::str(CEnv& cenv) const { return NAME; } \
-template<> AType* \
-ASTLiteral<CT>::type(CEnv& cenv) { return new AType(NAME, VT); } \
-template<> Value* \
-ASTLiteral<CT>::compile(CEnv& cenv) { return (COMPILED); }
-
-/// Literal template instantiations
-LITERAL(int32_t, Type::Int32Ty, "Int", ConstantInt::get(Type::Int32Ty, val, true));
-LITERAL(float, Type::FloatTy, "Float", ConstantFP::get(Type::FloatTy, val));
-LITERAL(bool, Type::Int1Ty, "Bool", ConstantInt::get(Type::Int1Ty, val, false));
-
-typedef Op UD; // User Data argument for parse functions
-
-// Parse Time Environment (symbol table)
-struct PEnv : private map<const string, ASTSymbol*> {
- typedef AST* (*PF)(PEnv&, const list<SExp>&, UD); // Parse Function
- struct Parser { Parser(PF f, UD d) : pf(f), ud(d) {} PF pf; UD ud; };
- map<const ASTSymbol*, Parser> parsers;
- void reg(const ASTSymbol* s, const Parser& p) {
- parsers.insert(make_pair(s, p));
- }
- const Parser* parser(const ASTSymbol* s) const {
- map<const ASTSymbol*, Parser>::const_iterator i = parsers.find(s);
- return (i != parsers.end()) ? &i->second : NULL;
- }
- ASTSymbol* sym(const string& s) {
- const const_iterator i = find(s);
- return ((i != end())
- ? i->second
- : insert(make_pair(s, new ASTSymbol(s))).first->second);
- }
-};
/// Closure (first-class function with captured lexical bindings)
struct ASTClosure : public AST {
ASTClosure(ASTTuple* p, AST* b) : prot(p), body(b), func(0) {}
- string str(CEnv& env) const { return "(fn)"; }
- AType* type(CEnv& cenv) {
- vector<AST*> texp(3);
- texp[0] = cenv.penv.sym("Fn");
- texp[1] = prot;
- texp[2] = body;
- return new AType(texp);
- }
+ bool operator==(const AST& rhs) const { return this == &rhs; }
+ string str() const { return "(fn)"; }
+ void constrain(TEnv& tenv) const;
+ void lift(CEnv& cenv);
Value* compile(CEnv& cenv);
- void lift(CEnv& cenv);
ASTTuple* const prot;
AST* const body;
- vector<const ASTSymbol*> bindings;
private:
Function* func;
};
-
+
/// Function call/application, e.g. "(func arg1 arg2)"
struct ASTCall : public ASTTuple {
ASTCall(const vector<AST*>& t) : ASTTuple(t) {}
- AType* type(CEnv& cenv) {
- AST* callee = tup[0];
- ASTSymbol* sym = dynamic_cast<ASTSymbol*>(tup[0]);
- if (sym) {
- AST** val = cenv.code.ref(sym);
- if (val)
- callee = *val;
- }
- ASTClosure* c = dynamic_cast<ASTClosure*>(callee);
- if (!c) throw TypeError("Call to non-closure");
- return c->body->type(cenv);
- }
- void lift(CEnv& cenv);
+ void constrain(TEnv& tenv) const;
+ void lift(CEnv& cenv);
Value* compile(CEnv& cenv);
};
/// Definition special form, e.g. "(def x 2)" or "(def (next y) (+ y 1))"
struct ASTDefinition : public ASTCall {
ASTDefinition(const vector<AST*>& c) : ASTCall(c) {}
- AType* type(CEnv& cenv) { return tup[2]->type(cenv); }
+ void constrain(TEnv& tenv) const;
Value* compile(CEnv& cenv);
};
/// Conditional special form, e.g. "(if cond thenexp elseexp)"
struct ASTIf : public ASTCall {
ASTIf(const vector<AST*>& c) : ASTCall(c) {}
- AType* type(CEnv& cenv) {
- AType* cT = tup[1]->type(cenv);
- AType* tT = tup[2]->type(cenv);
- AType* eT = tup[3]->type(cenv);
- if (cT->ctype != Type::Int1Ty) throw TypeError("If condition is not a boolean");
- if (*tT != *eT) throw TypeError("If branches have different types");
- return tT;
- }
+ void constrain(TEnv& tenv) const;
Value* compile(CEnv& cenv);
};
@@ -370,26 +248,42 @@ struct ASTIf : public ASTCall {
struct ASTPrimitive : public ASTCall {
ASTPrimitive(const vector<AST*>& c, unsigned o, unsigned a=0)
: ASTCall(c), op(o), arg(a) {}
- AType* type(CEnv& cenv);
+ void constrain(TEnv& tenv) const;
Value* compile(CEnv& cenv);
unsigned op;
unsigned arg;
};
-AType*
-ASTTuple::type(CEnv& cenv)
-{
- vector<AST*> texp;
- FOREACH(vector<AST*>::const_iterator, p, tup)
- texp.push_back((*p)->type(cenv));
- return new AType(texp);
-}
-
/***************************************************************************
* Parser - S-Expressions (SExp) -> AST Nodes (AST) *
***************************************************************************/
+/// LLVM Operation
+struct Op { Op(int o=0, int a=0) : op(o), arg(a) {} int op; int arg; };
+
+typedef Op UD; // User Data argument for parse functions
+
+// Parse Time Environment (symbol table)
+struct PEnv : private map<const string, ASTSymbol*> {
+ typedef AST* (*PF)(PEnv&, const list<SExp>&, UD); // Parse Function
+ struct Parser { Parser(PF f, UD d) : pf(f), ud(d) {} PF pf; UD ud; };
+ map<const ASTSymbol*, Parser> parsers;
+ void reg(const ASTSymbol* s, const Parser& p) {
+ parsers.insert(make_pair(s, p));
+ }
+ const Parser* parser(const ASTSymbol* s) const {
+ map<const ASTSymbol*, Parser>::const_iterator i = parsers.find(s);
+ return (i != parsers.end()) ? &i->second : NULL;
+ }
+ ASTSymbol* sym(const string& s) {
+ const const_iterator i = find(s);
+ return ((i != end())
+ ? i->second
+ : insert(make_pair(s, new ASTSymbol(s))).first->second);
+ }
+};
+
/// The fundamental parser method
static AST*
parseExpression(PEnv& penv, const SExp& exp)
@@ -445,39 +339,286 @@ static AST*
parsePrim(PEnv& penv, const list<SExp>& c, UD data)
{ return new ASTPrimitive(pmap(penv, c), data.op, data.arg); }
-static ASTTuple*
-parsePrototype(PEnv& penv, const SExp& e, UD)
- { return new ASTTuple(pmap(penv, e.list)); }
-
static AST*
parseFn(PEnv& penv, const list<SExp>& c, UD)
{
list<SExp>::const_iterator a = c.begin(); ++a;
return new ASTClosure(
- parsePrototype(penv, *a++, UD()),
+ new ASTTuple(pmap(penv, (*a++).list)),
parseExpression(penv, *a++));
}
/***************************************************************************
+ * Lexical Environment *
+ ***************************************************************************/
+
+template<typename K, typename V>
+struct Env : public list< map<K,V> > {
+ typedef map<K,V> Frame;
+ Env() : list<Frame>(1) {}
+ void push_front() { list<Frame>::push_front(Frame()); }
+ void def(const K& k, const V& v) {
+ if (this->front().find(k) != this->front().end())
+ throw SyntaxError("Redefinition");
+ this->front()[k] = v;
+ }
+ V* ref(const K& name) {
+ typename Frame::iterator s;
+ for (typename Env::iterator i = this->begin(); i != this->end(); ++i)
+ if ((s = i ->find(name)) != i->end())
+ return &s->second;
+ return 0;
+ }
+};
+
+
+/***************************************************************************
+ * Typing *
+ ***************************************************************************/
+
+struct TypeError : public Error { TypeError (const char* m) : Error(m) {} };
+
+/// Type-Time Environment
+struct TEnv {
+ TEnv(PEnv& p) : penv(p), varID(0) {}
+ typedef map<const AST*, AType*> Types;
+ typedef multimap<const AST*, AType*> Constraints;
+ AType* var() { return new AType(varID++); }
+ AType* type(const AST* ast) {
+ Types::iterator t = types.find(ast);
+ if (t != types.end())
+ return t->second;
+ else
+ return (types[ast] = var());
+ }
+ AType* named(const string& name) const {
+ Types::const_iterator i = namedTypes.find(penv.sym(name));
+ if (i == namedTypes.end()) throw TypeError("Unknown named type");
+ return i->second;
+ }
+ void name(const string& name, const Type* type) {
+ ASTSymbol* sym = penv.sym(name);
+ namedTypes[sym] = new AType(penv.sym(name), type);
+ }
+ void constrain(const AST* ast, AType* type) {
+ constraints.insert(make_pair(ast, type));
+ }
+ void unify();
+
+ PEnv& penv;
+ Types types;
+ Types namedTypes;
+ Constraints constraints;
+ unsigned varID;
+};
+
+#define OP_IS_A(o, t) ((o) >= t ## Begin && (o) < t ## End)
+
+void
+ASTTuple::constrain(TEnv& tenv) const
+{
+ vector<AST*> texp;
+ FOREACH(vector<AST*>::const_iterator, p, tup) {
+ (*p)->constrain(tenv);
+ AType* tvar = tenv.var();
+ texp.push_back(tvar);
+ tenv.constrain(tvar, tenv.type(*p));
+ }
+ tenv.constrain(this, new AType(texp));
+}
+
+void
+ASTClosure::constrain(TEnv& tenv) const
+{
+ prot->constrain(tenv);
+ body->constrain(tenv);
+ vector<AST*> texp(3);
+ texp[0] = tenv.penv.sym("Fn");
+ texp[1] = prot;
+ texp[2] = body;
+ AType* tvar = tenv.var();
+ tenv.constrain(texp[2], tvar);
+ tenv.constrain(this, tvar);
+}
+
+void
+ASTCall::constrain(TEnv& tenv) const
+{
+ ASTTuple::constrain(tenv);
+#if 0
+ AST* callee = tup[0];
+ ASTSymbol* sym = dynamic_cast<ASTSymbol*>(tup[0]);
+ if (sym) {
+ AST** val = tenv.code.ref(sym);
+ if (val)
+ callee = *val;
+ }
+ ASTClosure* c = dynamic_cast<ASTClosure*>(callee);
+ if (!c) throw TypeError("Call to non-closure");
+ tenv.contraints[this] = c->body->type(tenv);
+#endif
+}
+
+void
+ASTDefinition::constrain(TEnv& tenv) const
+{
+ FOREACH(vector<AST*>::const_iterator, p, tup)
+ (*p)->constrain(tenv);
+ AType* tvar = tenv.var();
+ tenv.constrain(tup[1], tvar);
+ tenv.constrain(tup[2], tvar);
+ tenv.constrain(this, tvar);
+}
+
+void
+ASTIf::constrain(TEnv& tenv) const
+{
+ FOREACH(vector<AST*>::const_iterator, p, tup)
+ (*p)->constrain(tenv);
+ AType* tvar = tenv.var();
+ tenv.constrain(tup[1], tenv.named("Bool"));
+ tenv.constrain(tup[2], tvar);
+ tenv.constrain(tup[3], tvar);
+ tenv.constrain(this, tvar);
+}
+
+void
+ASTPrimitive::constrain(TEnv& tenv) const
+{
+ FOREACH(vector<AST*>::const_iterator, p, tup)
+ (*p)->constrain(tenv);
+ if (OP_IS_A(op, Instruction::BinaryOps)) {
+ if (tup.size() <= 1) throw SyntaxError("Primitive call with 0 args");
+ AType* tvar = tenv.var();
+ for (size_t i = 1; i < tup.size(); ++i)
+ tenv.constrain(tup[i], tvar);
+ tenv.constrain(this, tvar);
+ } else if (op == Instruction::ICmp) {
+ if (tup.size() != 3) throw SyntaxError("Comparison call with != 2 args");
+ tenv.constrain(tup[1], tenv.type(tup[2]));
+ tenv.constrain(this, tenv.named("Bool"));
+ } else {
+ throw TypeError("Unknown primitive");
+ }
+}
+
+void
+TEnv::unify()
+{
+ typedef map<const AType*, AType*> Substitutions;
+
+ bool progress = false;
+ do {
+ progress = false;
+ //std::cout << "========" << endl;
+ Substitutions subst;
+ for (Constraints::iterator c = constraints.begin(); c != constraints.end();) {
+ Constraints::iterator next = c;
+ ++next;
+ const AST* o = c->first;
+ AType* t = c->second;
+ //std::cout << "Constraint: " << o->str() << " = " << t->str() << endl;
+ if (t->var) {
+ Types::iterator ot = types.find(o);
+ if (ot != types.end())
+ subst[t] = ot->second;
+ } else {
+ Types::iterator ot = types.find(o);
+ if (ot == types.end()) {
+ //std::cout << "Resolve: " << o->str() << endl;
+ types.insert(make_pair(o, t));
+ constraints.erase(c);
+ }
+ }
+ c = next;
+ }
+
+ for (Substitutions::iterator s = subst.begin(); s != subst.end(); ++s) {
+ for (Constraints::iterator c = constraints.begin(); c != constraints.end(); ++c) {
+ if (c->second == s->first) {
+ //std::cout << c->second->str() << " => " << s->second->str() << endl;
+ c->second = s->second;
+ progress = true;
+ }
+ }
+ }
+
+ } while (progress);
+
+ //std::cout << "======== Done unification" << endl;
+
+ constraints.clear();
+}
+
+
+/***************************************************************************
* Code Generation *
***************************************************************************/
struct CompileError : public Error { CompileError(const char* m) : Error(m) {} };
+class PEnv;
+
+/// Compile-Time Environment
+struct CEnv {
+ CEnv(PEnv& p, Module* m, const TargetData* target)
+ : penv(p), tenv(p), module(m), emp(module), fpm(&emp), symID(0)
+ {
+ // Set up the optimizer pipeline.
+ // Register info about how the target lays out data structures.
+ fpm.add(new TargetData(*target));
+ // Do simple "peephole" and bit-twiddling optimizations.
+ fpm.add(createInstructionCombiningPass());
+ // Reassociate expressions.
+ fpm.add(createReassociatePass());
+ // Eliminate Common SubExpressions.
+ fpm.add(createGVNPass());
+ // Simplify control flow graph (delete unreachable blocks, etc).
+ fpm.add(createCFGSimplificationPass());
+ }
+ string gensym(const char* base="_") {
+ ostringstream s; s << base << symID++; return s.str();
+ }
+ typedef Env<const AST*, AST*> Code;
+ typedef Env<const ASTSymbol*, Value*> Vals;
+
+ PEnv& penv;
+ TEnv tenv;
+ IRBuilder<> builder;
+ Module* module;
+ ExistingModuleProvider emp;
+ FunctionPassManager fpm;
+ unsigned symID;
+ Code code;
+ Vals vals;
+};
+
+#define LITERAL(CT, VT, NAME, COMPILED) \
+template<> Value* \
+ASTLiteral<CT>::compile(CEnv& cenv) { return (COMPILED); } \
+template<> void \
+ASTLiteral<CT>::constrain(TEnv& tenv) const { tenv.constrain(this, tenv.named(NAME)); }
+
+/// Literal template instantiations
+LITERAL(int32_t, Type::Int32Ty, "Int", ConstantInt::get(Type::Int32Ty, val, true));
+LITERAL(float, Type::FloatTy, "Float", ConstantFP::get(Type::FloatTy, val));
+LITERAL(bool, Type::Int1Ty, "Bool", ConstantInt::get(Type::Int1Ty, val, false));
+
static Function*
compileFunction(CEnv& cenv, const std::string& name, ASTTuple& prot, const Type* retT)
{
Function::LinkageTypes linkage = Function::ExternalLinkage;
- const vector<AST*>& texp = prot.type(cenv)->tup;
+ const vector<AST*>& texp = cenv.tenv.type(&prot)->tup;
vector<const Type*> cprot;
for (size_t i = 0; i < texp.size(); ++i) {
- const Type* t = texp[i]->type(cenv)->ctype;
+ const Type* t = cenv.tenv.type(texp[i])->ctype;
if (!t) throw CompileError("Function prototype contains NULL");
cprot.push_back(t);
}
+ if (!retT) throw CompileError("Function return value type is NULL");
FunctionType* fT = FunctionType::get(retT, cprot, false);
Function* f = Function::Create(fT, linkage, name, cenv.module);
@@ -489,17 +630,11 @@ compileFunction(CEnv& cenv, const std::string& name, ASTTuple& prot, const Type*
// Set argument names in generated code
Function::arg_iterator a = f->arg_begin();
for (size_t i = 0; i != prot.tup.size(); ++a, ++i)
- a->setName(prot.tup[i]->str(cenv));
+ a->setName(prot.tup[i]->str());
return f;
}
-AType*
-ASTSymbol::type(CEnv& cenv)
-{
- AST** t = cenv.code.ref(this);
- return t ? (*t)->type(cenv) : new AType(cenv.tID++);
-}
Value*
ASTSymbol::compile(CEnv& cenv)
@@ -552,7 +687,7 @@ ASTCall::lift(CEnv& cenv)
tup[i]->lift(cenv);
// Extend environment with bound and typed parameters
- cenv.code.push();
+ cenv.code.push_front();
if (c->prot->tup.size() != tup.size() - 1)
throw CompileError("Call to closure with mismatched arguments");
@@ -562,7 +697,7 @@ ASTCall::lift(CEnv& cenv)
// Lift callee closure
tup[0]->lift(cenv);
- cenv.code.pop();
+ cenv.code.pop_front();
}
Value*
@@ -576,9 +711,9 @@ ASTCall::compile(CEnv& cenv)
if (!c) throw CompileError("Call to non-closure");
Value* v = c->compile(cenv);
- if (!v) throw SyntaxError("Callee failed to compile");
+ if (!v) throw CompileError("Callee failed to compile");
Function* f = dynamic_cast<Function*>(c->compile(cenv));
- if (!f) throw SyntaxError("Callee compiled to non-function");
+ if (!f) throw CompileError("Callee compiled to non-function");
vector<Value*> params;
for (size_t i = 1; i < tup.size(); ++i)
@@ -603,25 +738,23 @@ ASTIf::compile(CEnv& cenv)
// Emit then value.
cenv.builder.SetInsertPoint(thenBB);
- Value* thenV = tup[2]->compile(cenv);
+ Value* thenV = tup[2]->compile(cenv); // Can change current block, so...
cenv.builder.CreateBr(mergeBB);
- // compile of 'Then' can change the current block, update thenBB
- thenBB = cenv.builder.GetInsertBlock();
+ thenBB = cenv.builder.GetInsertBlock(); // ... update thenBB afterwards
// Emit else block.
parent->getBasicBlockList().push_back(elseBB);
cenv.builder.SetInsertPoint(elseBB);
- Value* elseV = tup[3]->compile(cenv);
+ Value* elseV = tup[3]->compile(cenv); // Can change current block, so...
cenv.builder.CreateBr(mergeBB);
- // compile of 'Else' can change the current block, update elseBB
- elseBB = cenv.builder.GetInsertBlock();
+ elseBB = cenv.builder.GetInsertBlock(); // ... update elseBB afterwards
// Emit merge block.
parent->getBasicBlockList().push_back(mergeBB);
cenv.builder.SetInsertPoint(mergeBB);
- PHINode* pn = cenv.builder.CreatePHI(type(cenv)->ctype, "iftmp");
+ PHINode* pn = cenv.builder.CreatePHI(cenv.tenv.type(this)->ctype, "iftmp");
pn->addIncoming(thenV, thenBB);
pn->addIncoming(elseV, elseBB);
@@ -635,23 +768,16 @@ ASTClosure::lift(CEnv& cenv)
// Can't lift a closure with variable types (lift later when called)
for (size_t i = 0; i < prot->tup.size(); ++i)
- if (prot->tup[i]->type(cenv)->var)
+ if (cenv.tenv.type(prot->tup[i])->var)
return;
- cenv.code.push();
+ if (cenv.tenv.type(body)->var)
+ return;
+
+ cenv.code.push_front();
- ASTSymbol* sym = dynamic_cast<ASTSymbol*>(body);
- if (sym) {
- AST** obj = cenv.code.ref(sym);
- if (!obj) {
- std::cout << "UNDEFINED SYMBOL BODY\n";
- prot->tup.push_back(sym);
- bindings.push_back(sym);
- }
- }
-
// Write function declaration
- Function* f = compileFunction(cenv, cenv.gensym("_fn"), *prot, body->type(cenv)->ctype);
+ Function* f = compileFunction(cenv, cenv.gensym("_fn"), *prot, cenv.tenv.type(body)->ctype);
BasicBlock* bb = BasicBlock::Create("entry", f);
cenv.builder.SetInsertPoint(bb);
@@ -673,7 +799,7 @@ ASTClosure::lift(CEnv& cenv)
throw e;
}
- cenv.code.pop();
+ cenv.code.pop_front();
}
Value*
@@ -683,22 +809,6 @@ ASTClosure::compile(CEnv& cenv)
return func;
}
-#define OP_IS_A(o, t) ((o) >= t ## Begin && (o) < t ## End)
-
-AType*
-ASTPrimitive::type(CEnv& cenv)
-{
- if (OP_IS_A(op, Instruction::BinaryOps)) {
- if (tup.size() <= 1) throw SyntaxError("Primitive call with 0 args");
- return tup[1]->type(cenv);
- } else if (op == Instruction::ICmp) {
- if (tup.size() != 3) throw SyntaxError("Comparison call with != 2 args");
- return new AType("Bool", Type::Int1Ty);
- } else {
- throw CompileError("Unknown primitive");
- }
-}
-
Value*
ASTPrimitive::compile(CEnv& cenv)
{
@@ -723,7 +833,7 @@ ASTPrimitive::compile(CEnv& cenv)
return val;
}
} else if (op == Instruction::ICmp) {
- bool isInt = tup[1]->type(cenv)->str(cenv) == "(Int)";
+ bool isInt = cenv.tenv.type(tup[1])->str() == "(Int)";
if (isInt) {
return cenv.builder.CreateICmp((CmpInst::Predicate)arg, a, b);
} else {
@@ -778,6 +888,10 @@ main()
cenv.code.def(penv.sym("true"), new ASTLiteral<bool>(true));
cenv.code.def(penv.sym("false"), new ASTLiteral<bool>(false));
+ cenv.tenv.name("Bool", Type::Int1Ty);
+ cenv.tenv.name("Int", Type::Int32Ty);
+ cenv.tenv.name("Float", Type::FloatTy);
+
while (1) {
std::cout << "(=>) ";
std::cout.flush();
@@ -786,9 +900,13 @@ main()
break;
try {
- AST* body = parseExpression(penv, exp);
+ AST* body = parseExpression(penv, exp);
+
+ body->constrain(cenv.tenv);
+ cenv.tenv.unify();
+
ASTTuple* prot = new ASTTuple();
- AType* bodyT = body->type(cenv);
+ AType* bodyT = cenv.tenv.type(body);
if (!bodyT) throw TypeError("REPL call to untyped body");
if (bodyT->var) throw TypeError("REPL call to variable typed body");
@@ -825,7 +943,7 @@ main()
Value* val = body->compile(cenv);
std::cout << val;
}
- std::cout << " : " << body->type(cenv)->str(cenv) << endl;
+ std::cout << " : " << cenv.tenv.type(body)->str() << endl;
} catch (Error e) {
std::cerr << "Error: " << e.what() << endl;