aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDavid Robillard <d@drobilla.net>2009-01-26 07:37:38 +0000
committerDavid Robillard <d@drobilla.net>2009-01-26 07:37:38 +0000
commit6671a54283fc7d9323fc14c9feee525f01b2821d (patch)
tree23374308cbc59f3d89d390c9c9cb45fbdf66998c
parent022f55e2ab4da12ae45321c7f2cca71b66c417a4 (diff)
downloadresp-6671a54283fc7d9323fc14c9feee525f01b2821d.tar.gz
resp-6671a54283fc7d9323fc14c9feee525f01b2821d.tar.bz2
resp-6671a54283fc7d9323fc14c9feee525f01b2821d.zip
Somewhat functional type inference.
git-svn-id: http://svn.drobilla.net/resp/llvm-lisp@16 ad02d1e2-f140-0410-9f75-f8b11f17cedd
-rw-r--r--ll.cpp179
1 files changed, 104 insertions, 75 deletions
diff --git a/ll.cpp b/ll.cpp
index cc971a5..0dfa04e 100644
--- a/ll.cpp
+++ b/ll.cpp
@@ -18,6 +18,7 @@
* along with This program. If not, see <http://www.gnu.org/licenses/>.
*/
+#include <stdarg.h>
#include <iostream>
#include <list>
#include <map>
@@ -176,7 +177,7 @@ struct ASTTuple : public AST {
/// Type Expression ::= (TName TExpr*) | ?Num
struct AType : public ASTTuple {
AType(const vector<AST*>& t) : ASTTuple(t), var(false), ctype(0) {}
- AType(unsigned i) : var(true), ctype(0), id(id) {}
+ AType(unsigned i) : var(true), ctype(0), id(i) {}
AType(ASTSymbol* n, const Type* t) : var(false), ctype(t) {
tup.push_back(n);
}
@@ -189,6 +190,15 @@ struct AType : public ASTTuple {
}
void constrain(TEnv& tenv) const {}
Value* compile(CEnv& cenv) { return NULL; }
+ bool concrete() const {
+ if (var) return false;
+ FOREACH(vector<AST*>::const_iterator, t, tup) {
+ AType* kid = dynamic_cast<AType*>(*t);
+ if (kid && !kid->concrete())
+ return false;
+ }
+ return true;
+ }
bool var;
const Type* ctype;
unsigned id;
@@ -212,7 +222,7 @@ struct ASTLiteral : public AST {
struct ASTClosure : public AST {
ASTClosure(ASTTuple* p, AST* b) : prot(p), body(b), func(0) {}
bool operator==(const AST& rhs) const { return this == &rhs; }
- string str() const { return "(fn)"; }
+ string str() const { ostringstream s; s << this; return s.str(); }
void constrain(TEnv& tenv) const;
void lift(CEnv& cenv);
Value* compile(CEnv& cenv);
@@ -381,16 +391,19 @@ struct TypeError : public Error { TypeError (const char* m) : Error(m) {} };
/// Type-Time Environment
struct TEnv {
- TEnv(PEnv& p) : penv(p), varID(0) {}
+ TEnv(PEnv& p) : penv(p), varID(1) {}
typedef map<const AST*, AType*> Types;
typedef multimap<const AST*, AType*> Constraints;
AType* var() { return new AType(varID++); }
AType* type(const AST* ast) {
Types::iterator t = types.find(ast);
- if (t != types.end())
+ if (t != types.end()) {
return t->second;
- else
- return (types[ast] = var());
+ } else {
+ AType* tvar = var();
+ constrain(ast, tvar);
+ return tvar;
+ }
}
AType* named(const string& name) const {
Types::const_iterator i = namedTypes.find(penv.sym(name));
@@ -404,8 +417,7 @@ struct TEnv {
void constrain(const AST* ast, AType* type) {
constraints.insert(make_pair(ast, type));
}
- void unify();
-
+ AType* unify(AST* root);
PEnv& penv;
Types types;
Types namedTypes;
@@ -415,6 +427,18 @@ struct TEnv {
#define OP_IS_A(o, t) ((o) >= t ## Begin && (o) < t ## End)
+vector<AST*>
+tuple(AST* ast, ...)
+{
+ vector<AST*> tup(1, ast);
+ va_list args;
+ va_start(args, ast);
+ for (AST* a = va_arg(args, AST*); a; a = va_arg(args, AST*))
+ tup.push_back(a);
+ va_end(args);
+ return tup;
+}
+
void
ASTTuple::constrain(TEnv& tenv) const
{
@@ -423,7 +447,7 @@ ASTTuple::constrain(TEnv& tenv) const
(*p)->constrain(tenv);
AType* tvar = tenv.var();
texp.push_back(tvar);
- tenv.constrain(tvar, tenv.type(*p));
+ tenv.constrain(*p, tvar);
}
tenv.constrain(this, new AType(texp));
}
@@ -433,31 +457,21 @@ ASTClosure::constrain(TEnv& tenv) const
{
prot->constrain(tenv);
body->constrain(tenv);
- vector<AST*> texp(3);
- texp[0] = tenv.penv.sym("Fn");
- texp[1] = prot;
- texp[2] = body;
- AType* tvar = tenv.var();
- tenv.constrain(texp[2], tvar);
- tenv.constrain(this, tvar);
+ AType* bodyT = tenv.var();
+ tenv.constrain(body, bodyT);
+ tenv.constrain(this, new AType(tuple(
+ tenv.penv.sym("Fn"), tenv.type(prot), bodyT, 0)));
}
void
ASTCall::constrain(TEnv& tenv) const
{
- ASTTuple::constrain(tenv);
-#if 0
- AST* callee = tup[0];
- ASTSymbol* sym = dynamic_cast<ASTSymbol*>(tup[0]);
- if (sym) {
- AST** val = tenv.code.ref(sym);
- if (val)
- callee = *val;
- }
- ASTClosure* c = dynamic_cast<ASTClosure*>(callee);
- if (!c) throw TypeError("Call to non-closure");
- tenv.contraints[this] = c->body->type(tenv);
-#endif
+ FOREACH(vector<AST*>::const_iterator, p, tup)
+ (*p)->constrain(tenv);
+ AType* retT = tenv.var();
+ vector<AST*> texp = tuple(tenv.penv.sym("Fn"), tenv.var(), retT, NULL);
+ tenv.constrain(new AType(texp), tenv.var());
+ tenv.constrain(this, retT);
}
void
@@ -503,52 +517,71 @@ ASTPrimitive::constrain(TEnv& tenv) const
}
}
-void
-TEnv::unify()
+static bool
+substitute(ASTTuple* tup, AST* from, AST* to)
{
- typedef map<const AType*, AType*> Substitutions;
-
bool progress = false;
- do {
- progress = false;
- //std::cout << "========" << endl;
- Substitutions subst;
+ for (size_t i = 0; i < tup->tup.size(); ++i) {
+ if (*tup->tup[i] == *from) {
+ tup->tup[i] = to;
+ progress = true;
+ }
+ }
+ return progress;
+}
+
+AType*
+TEnv::unify(AST* root)
+{
+ root->constrain(*this);
+ assert(constraints.find(root) != constraints.end());
+ //constrain(root, var());
+ typedef map<AType*, AType*> Substitutions;
+ Substitutions subst;
+ for (bool progress = true; progress; progress = false) {
+ //std::cout << "==== " << constraints.size() << endl;
for (Constraints::iterator c = constraints.begin(); c != constraints.end();) {
Constraints::iterator next = c;
++next;
const AST* o = c->first;
AType* t = c->second;
- //std::cout << "Constraint: " << o->str() << " = " << t->str() << endl;
- if (t->var) {
- Types::iterator ot = types.find(o);
- if (ot != types.end())
- subst[t] = ot->second;
- } else {
+ //std::cout << "Constr : " << o->str() << " = " << t->str() << endl;
+ if (t->concrete()) {
Types::iterator ot = types.find(o);
if (ot == types.end()) {
- //std::cout << "Resolve: " << o->str() << endl;
+ //std::cout << "Resolv : " << o->str() << endl;
types.insert(make_pair(o, t));
- constraints.erase(c);
+ //constraints.erase(c);
+ progress = true;
+ }
+ } else {
+ Types::iterator ot = types.find(o);
+ if (ot != types.end()) {
+ subst[t] = ot->second;
+ progress = true;
}
}
c = next;
}
-
for (Substitutions::iterator s = subst.begin(); s != subst.end(); ++s) {
+ //std::cout << "Subst : " << s->first->str() << " => " << s->second->str() << endl;
for (Constraints::iterator c = constraints.begin(); c != constraints.end(); ++c) {
- if (c->second == s->first) {
- //std::cout << c->second->str() << " => " << s->second->str() << endl;
+ AType* objT = c->second;
+ if (objT == s->first && objT != s->second) {
c->second = s->second;
progress = true;
}
+ progress = substitute(c->second, s->first, s->second) || progress;
}
}
-
- } while (progress);
-
- //std::cout << "======== Done unification" << endl;
-
- constraints.clear();
+ }
+ //std::cout << "======== Done Unifying Types " << constraints.size() << endl;
+ Constraints::iterator i = constraints.find(root);
+ if (i != constraints.end()) {
+ types[root] = i->second;
+ return i->second;
+ }
+ return NULL;
}
@@ -594,16 +627,16 @@ struct CEnv {
Vals vals;
};
-#define LITERAL(CT, VT, NAME, COMPILED) \
+#define LITERAL(CT, NAME, COMPILED) \
template<> Value* \
ASTLiteral<CT>::compile(CEnv& cenv) { return (COMPILED); } \
template<> void \
ASTLiteral<CT>::constrain(TEnv& tenv) const { tenv.constrain(this, tenv.named(NAME)); }
/// Literal template instantiations
-LITERAL(int32_t, Type::Int32Ty, "Int", ConstantInt::get(Type::Int32Ty, val, true));
-LITERAL(float, Type::FloatTy, "Float", ConstantFP::get(Type::FloatTy, val));
-LITERAL(bool, Type::Int1Ty, "Bool", ConstantInt::get(Type::Int1Ty, val, false));
+LITERAL(int32_t, "Int", ConstantInt::get(Type::Int32Ty, val, true));
+LITERAL(float, "Float", ConstantFP::get(Type::FloatTy, val));
+LITERAL(bool, "Bool", ConstantInt::get(Type::Int1Ty, val, false));
static Function*
compileFunction(CEnv& cenv, const std::string& name, ASTTuple& prot, const Type* retT)
@@ -893,29 +926,25 @@ main()
cenv.tenv.name("Float", Type::FloatTy);
while (1) {
- std::cout << "(=>) ";
+ std::cout << "(= ";
std::cout.flush();
SExp exp = readExpression(std::cin);
if (exp.type == SExp::LIST && exp.list.empty())
break;
try {
- AST* body = parseExpression(penv, exp);
-
- body->constrain(cenv.tenv);
- cenv.tenv.unify();
-
- ASTTuple* prot = new ASTTuple();
- AType* bodyT = cenv.tenv.type(body);
+ AST* body = parseExpression(penv, exp);
+ AType* bodyT = cenv.tenv.unify(body);
if (!bodyT) throw TypeError("REPL call to untyped body");
if (bodyT->var) throw TypeError("REPL call to variable typed body");
body->lift(cenv);
-
+
if (bodyT->ctype) {
// Create anonymous function to insert code into.
- Function* f = compileFunction(cenv, cenv.gensym("_repl"), *prot, bodyT->ctype);
+ ASTTuple* prot = new ASTTuple();
+ Function* f = compileFunction(cenv, cenv.gensym("_repl"), *prot, bodyT->ctype);
BasicBlock* bb = BasicBlock::Create("entry", f);
cenv.builder.SetInsertPoint(bb);
@@ -931,26 +960,26 @@ main()
void* fp = engine->getPointerToFunction(f);
if (bodyT->ctype == Type::Int32Ty)
- std::cout << ((int32_t (*)())fp)();
+ std::cout << " " <<((int32_t (*)())fp)();
else if (bodyT->ctype == Type::FloatTy)
- std::cout << ((float (*)())fp)();
+ std::cout << " " <<((float (*)())fp)();
else if (bodyT->ctype == Type::Int1Ty)
- std::cout << ((bool (*)())fp)();
+ std::cout << " " << ((bool (*)())fp)();
else
- std::cout << "?";
+ std::cout << " ?";
} else {
Value* val = body->compile(cenv);
- std::cout << val;
+ std::cout << " " << val;
}
- std::cout << " : " << cenv.tenv.type(body)->str() << endl;
+ std::cout << " : " << cenv.tenv.type(body)->str() << ")" << endl;
} catch (Error e) {
std::cerr << "Error: " << e.what() << endl;
}
}
- std::cout << "Generated code:" << endl;
+ std::cout << endl << "Generated code:" << endl;
module->dump();
return 0;
}