aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDavid Robillard <d@drobilla.net>2009-03-14 21:28:01 +0000
committerDavid Robillard <d@drobilla.net>2009-03-14 21:28:01 +0000
commit053a6e588b143bdd765113cdcd53ae7ef39c1a6c (patch)
treed541eceaff86e14fd618bd999c41a2fb33045cb5
parent6bed972454d60c503b760bcac92c4e36ba95520d (diff)
downloadresp-053a6e588b143bdd765113cdcd53ae7ef39c1a6c.tar.gz
resp-053a6e588b143bdd765113cdcd53ae7ef39c1a6c.tar.bz2
resp-053a6e588b143bdd765113cdcd53ae7ef39c1a6c.zip
Lexical addressing, work towards true parametric polymorphism.
git-svn-id: http://svn.drobilla.net/resp/tuplr@88 ad02d1e2-f140-0410-9f75-f8b11f17cedd
-rw-r--r--llvm.cpp179
-rw-r--r--tuplr.cpp12
-rw-r--r--tuplr.hpp210
-rw-r--r--typing.cpp231
4 files changed, 432 insertions, 200 deletions
diff --git a/llvm.cpp b/llvm.cpp
index baaeb9e..32577f9 100644
--- a/llvm.cpp
+++ b/llvm.cpp
@@ -46,10 +46,11 @@ struct LLVMEngine {
};
static const Type*
-lltype(AType* t)
+lltype(const AType* t)
{
switch (t->kind) {
case AType::VAR:
+ throw Error((format("non-compilable type `%1%'") % t->str()).str(), t->loc);
return NULL;
case AType::PRIM:
if (t->at(0)->str() == "Bool") return Type::Int1Ty;
@@ -117,7 +118,7 @@ CValue
CEnv::compile(AST* obj)
{
CValue* v = vals.ref(obj);
- return (v) ? *v : vals.def(obj, obj->compile(*this));
+ return (v && *v) ? *v : vals.def(obj, obj->compile(*this));
}
void
@@ -138,7 +139,7 @@ CEnv::write(std::ostream& os)
template<> CValue \
ALiteral<CT>::compile(CEnv& cenv) { return (COMPILED); } \
template<> void \
-ALiteral<CT>::constrain(TEnv& tenv) const { tenv.constrain(this, tenv.named(NAME)); }
+ALiteral<CT>::constrain(TEnv& tenv, Constraints& c) { c.constrain(tenv, this, tenv.named(NAME)); }
/// Literal template instantiations
LITERAL(int32_t, "Int", ConstantInt::get(Type::Int32Ty, val, true))
@@ -146,14 +147,15 @@ LITERAL(float, "Float", ConstantFP::get(Type::FloatTy, val))
LITERAL(bool, "Bool", ConstantInt::get(Type::Int1Ty, val, false))
static Function*
-compileFunction(CEnv& cenv, const std::string& name, const Type* retT, const ATuple& prot,
+compileFunction(CEnv& cenv, const std::string& name, const Type* retT, const ATuple& protT,
const vector<string> argNames=vector<string>())
{
Function::LinkageTypes linkage = Function::ExternalLinkage;
vector<const Type*> cprot;
- for (size_t i = 0; i < prot.size(); ++i) {
- AType* at = cenv.tenv.type(prot.at(i));
+ for (size_t i = 0; i < protT.size(); ++i) {
+ AType* at = dynamic_cast<AType*>(protT.at(i));
+ if (!at) throw Error("function parameter type isn't");
if (!lltype(at)) throw Error("function parameter is untyped");
cprot.push_back(lltype(at));
}
@@ -170,11 +172,8 @@ compileFunction(CEnv& cenv, const std::string& name, const Type* retT, const ATu
// Set argument names in generated code
Function::arg_iterator a = f->arg_begin();
if (!argNames.empty())
- for (size_t i = 0; i != prot.size(); ++a, ++i)
+ for (size_t i = 0; i != protT.size(); ++a, ++i)
a->setName(argNames.at(i));
- else
- for (size_t i = 0; i != prot.size(); ++a, ++i)
- a->setName(prot.at(i)->str());
BasicBlock* bb = BasicBlock::Create("entry", f);
llengine(cenv)->builder.SetInsertPoint(bb);
@@ -187,53 +186,110 @@ compileFunction(CEnv& cenv, const std::string& name, const Type* retT, const ATu
* Code Generation *
***************************************************************************/
-void
-ASymbol::lift(CEnv& cenv)
-{
- if (!cenv.code.ref(this))
- throw Error((string("undefined symbol `") + cppstr + "'").c_str(), loc);
-}
-
CValue
ASymbol::compile(CEnv& cenv)
{
- return cenv.compile(*cenv.code.ref(this));
+ return cenv.vals.ref(this);
}
void
AClosure::lift(CEnv& cenv)
{
- AType* type = cenv.tenv.type(this);
+ AType* type = cenv.type(this);
if (!type->concrete() || funcs.find(type))
return;
- cenv.push();
-
// Write function declaration
string name = this->name == "" ? cenv.gensym("_fn") : this->name;
- Function* f = compileFunction(cenv, name, lltype(cenv.tenv.type(at(2))), *prot());
+ ATuple* protT = dynamic_cast<ATuple*>(type->at(1));
+ assert(protT);
+ Function* f = compileFunction(cenv, name,
+ lltype(dynamic_cast<AType*>(type->at(type->size() - 1))),
+ *protT);
+
+ cenv.push();
+ Subst oldSubst = cenv.tsubst;
+ cenv.tsubst = Subst::compose(cenv.tsubst, *subst);
// Bind argument values in CEnv
vector<Value*> args;
const_iterator p = prot()->begin();
- for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p)
+ size_t i = 0;
+ for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p) {
+ cenv.tenv.def(*p, dynamic_cast<AType*>(protT->at(i++)));
cenv.vals.def(dynamic_cast<ASymbol*>(*p), &*a);
+ }
// Write function body
try {
- cenv.precompile(this, f); // Define our value first for recursion
+ // Define value first for recursion
+ cenv.precompile(this, f);
+ funcs.push_back(make_pair(type, f));
+
CValue retVal = cenv.compile(at(2));
llengine(cenv)->builder.CreateRet(LLVal(retVal)); // Finish function
cenv.optimise(LLFunc(f));
- funcs.push_back(make_pair(type, f));
+
} catch (Error& e) {
f->eraseFromParent(); // Error reading body, remove function
+ cenv.pop();
throw e;
}
-
+ cenv.tsubst = oldSubst;
cenv.pop();
}
+void
+AClosure::liftPoly(CEnv& cenv, const vector<AType*>& argsT)
+{
+ if (type->concrete())
+ return;
+
+ throw Error("No polymorphism");
+
+#if 0
+ //Subst tsubst;
+ assert(argsT.size() == prot()->size());
+ for (size_t i = 0; i < argsT.size(); ++i) {
+ cenv.err << " " << argsT.at(i)->str();
+ //tsubst[*cenv.tenv.ref(prot()->at(i))] = argsT.at(i);
+ }
+ cenv.err << endl;
+#endif
+}
+
+CValue
+AClosure::compile(CEnv& cenv)
+{
+ /*
+ cenv.err << "***********************************************" << endl;
+ cenv.err << cenv.type(this) << endl;
+
+ cenv.err << "COMPILING FOR TYPE:";
+ Subst tsubst;
+ assert(cenv.code.front().size() == prot()->size());
+ for (size_t i = 0; i < cenv.code.front().size(); ++i) {
+ cenv.err << " (" << cenv.type(prot()->at(i))->str()
+ << " -> " << cenv.type(cenv.code.front().at(i).second)->str() << ")";
+ tsubst[cenv.tenv.types[prot()->at(i)]] =
+ cenv.type(cenv.code.front().at(i).second);
+ }
+ cenv.err << endl;
+
+ Subst subst = Subst::compose(tsubst, cenv.tsubst);
+ AType* concreteType = subst.apply(type);
+ if (!concreteType->concrete())
+ throw Error("compiled function has non-concrete type", loc);
+
+ cenv.err << "*********** CONCRETE TYPE: " << concreteType->str() << endl;
+ */
+
+ //CValue ret = funcs.find(concreteType);
+ //cenv.err << "VALUE FOR TYPE " << concreteType->str() << " : " << ret << endl;
+ //return ret;
+ return NULL;
+}
+
template<typename T>
T
checked_cast(AST* ast)
@@ -250,27 +306,23 @@ AST*
maybeLookup(CEnv& cenv, AST* ast)
{
ASymbol* s = dynamic_cast<ASymbol*>(ast);
- if (s) {
- AST** val = cenv.code.ref(s);
- if (val) return *val;
- }
+ if (s)
+ return cenv.code.deref(s->addr);
return ast;
}
-CValue
-AClosure::compile(CEnv& cenv)
-{
- return funcs.find(cenv.tenv.type(this));
-}
-
void
ACall::lift(CEnv& cenv)
{
AClosure* c = dynamic_cast<AClosure*>(maybeLookup(cenv, at(0)));
+ vector<AType*> argsT;
+
// Lift arguments
- for (size_t i = 1; i < size(); ++i)
+ for (size_t i = 1; i < size(); ++i) {
at(i)->lift(cenv);
+ argsT.push_back(cenv.type(at(i)));
+ }
if (!c) return; // Primitive
@@ -284,15 +336,30 @@ ACall::lift(CEnv& cenv)
for (size_t i = 1; i < size(); ++i)
cenv.code.def(checked_cast<ASymbol*>(c->prot()->at(i-1)), at(i));
- c->lift(cenv); // Lift called closure
+ c->liftPoly(cenv, argsT); // Lift called closure
cenv.pop(); // Restore environment
}
CValue
ACall::compile(CEnv& cenv)
{
- AST* c = maybeLookup(cenv, at(0));
- Function* f = dynamic_cast<Function*>(LLVal(cenv.compile(c)));
+ AClosure* c = dynamic_cast<AClosure*>(maybeLookup(cenv, at(0)));
+
+ if (!c) return NULL; // Primitive
+
+ if (c->prot()->size() < size() - 1)
+ throw Error((format("too many arguments to function `%1%'") % at(0)->str()).str(), loc);
+ if (c->prot()->size() > size() - 1)
+ throw Error((format("too few arguments to function `%1%'") % at(0)->str()).str(), loc);
+
+ AType* protT = new AType(loc, NULL);
+ for (size_t i = 1; i < size(); ++i)
+ protT->push_back(cenv.type(at(i)));
+
+ AType* polyT = c->type;
+ AType* fnT = new AType(loc, cenv.penv.sym("Fn"), protT, polyT->at(2), 0);
+
+ Function* f = (Function*)c->funcs.find(fnT);
if (!f) throw Error("callee failed to compile", loc);
vector<Value*> params(size() - 1);
@@ -305,7 +372,7 @@ ACall::compile(CEnv& cenv)
void
ADefinition::lift(CEnv& cenv)
{
- if (cenv.code.ref(checked_cast<ASymbol*>(at(1))))
+ if (cenv.code.lookup(checked_cast<ASymbol*>(at(1))))
throw Error(string("`") + at(1)->str() + "' redefined", loc);
cenv.code.def((ASymbol*)at(1), at(2)); // Define first for recursion
at(2)->lift(cenv);
@@ -353,7 +420,7 @@ AIf::compile(CEnv& cenv)
// Emit merge block (Phi node)
parent->getBasicBlockList().push_back(mergeBB);
llengine(cenv)->builder.SetInsertPoint(mergeBB);
- PHINode* pn = llengine(cenv)->builder.CreatePHI(lltype(cenv.tenv.type(this)), "ifval");
+ PHINode* pn = llengine(cenv)->builder.CreatePHI(lltype(cenv.type(this)), "ifval");
FOREACH(Branches::iterator, i, branches)
pn->addIncoming(i->first, i->second);
@@ -366,7 +433,7 @@ APrimitive::compile(CEnv& cenv)
{
Value* a = LLVal(cenv.compile(at(1)));
Value* b = LLVal(cenv.compile(at(2)));
- bool isInt = cenv.tenv.type(at(1))->str() == "Int";
+ bool isInt = cenv.type(at(1))->str() == "Int";
const string n = dynamic_cast<ASymbol*>(at(0))->str();
// Binary arithmetic operations
@@ -407,9 +474,9 @@ APrimitive::compile(CEnv& cenv)
AType*
AConsCall::functionType(CEnv& cenv)
{
- ATuple* protTypes = new ATuple(loc, cenv.tenv.type(at(1)), cenv.tenv.type(at(2)), 0);
+ ATuple* protTypes = new ATuple(loc, cenv.type(at(1)), cenv.type(at(2)), 0);
AType* cellType = new AType(loc,
- cenv.penv.sym("Pair"), cenv.tenv.type(at(1)), cenv.tenv.type(at(2)), 0);
+ cenv.penv.sym("Pair"), cenv.type(at(1)), cenv.type(at(2)), 0);
return new AType(at(0)->loc, cenv.penv.sym("Fn"), protTypes, cellType, 0);
}
@@ -427,7 +494,7 @@ AConsCall::lift(CEnv& cenv)
vector<const Type*> types;
size_t sz = 0;
for (size_t i = 1; i < size(); ++i) {
- const Type* t = lltype(cenv.tenv.type(at(i)));
+ const Type* t = lltype(cenv.type(at(i)));
types.push_back(t);
sz += t->getPrimitiveSizeInBits();
}
@@ -520,15 +587,16 @@ eval(CEnv& cenv, const string& name, istream& is)
list< pair<SExp, AST*> > exprs;
Cursor cursor(name);
try {
+ Constraints c;
while (true) {
SExp exp = readExpression(cursor, is);
if (exp.type == SExp::LIST && exp.list.empty())
break;
result = cenv.penv.parse(exp); // Parse input
- result->constrain(cenv.tenv); // Constrain types
- cenv.tenv.solve(); // Solve and apply type constraints
- resultType = cenv.tenv.type(result);
+ result->constrain(cenv.tenv, c); // Constrain types
+ cenv.tsubst = TEnv::unify(c); // Solve type constraints
+ resultType = cenv.type(result);
result->lift(cenv); // Lift functions
exprs.push_back(make_pair(exp, result));
}
@@ -562,24 +630,27 @@ eval(CEnv& cenv, const string& name, istream& is)
int
repl(CEnv& cenv)
{
+ Constraints c;
while (1) {
cenv.out << "() ";
cenv.out.flush();
Cursor cursor("(stdin)");
try {
+
SExp exp = readExpression(cursor, std::cin);
if (exp.type == SExp::LIST && exp.list.empty())
break;
AST* body = cenv.penv.parse(exp); // Parse input
- body->constrain(cenv.tenv); // Constrain types
- cenv.tenv.solve(); // Solve and apply type constraints
+ body->constrain(cenv.tenv, c); // Constrain types
+
+ cenv.tsubst = Subst::compose(cenv.tsubst, TEnv::unify(c)); // Solve type constraints
- AType* bodyT = cenv.tenv.type(body);
+ AType* bodyT = cenv.type(body);
if (!bodyT) throw Error("call to untyped body", cursor);
body->lift(cenv);
-
+
if (lltype(bodyT)) {
// Create anonymous function to insert code into
Function* f = compileFunction(cenv, cenv.gensym("_repl"), lltype(bodyT), ATuple());
@@ -595,7 +666,7 @@ repl(CEnv& cenv)
} else {
cenv.out << "; " << cenv.compile(body);
}
- cenv.out << " : " << cenv.tenv.type(body) << endl;
+ cenv.out << " : " << cenv.type(body) << endl;
} catch (Error& e) {
cenv.err << e.what() << endl;
}
diff --git a/tuplr.cpp b/tuplr.cpp
index bfd6192..3429a9f 100644
--- a/tuplr.cpp
+++ b/tuplr.cpp
@@ -123,10 +123,15 @@ parseLiteral(PEnv& penv, const SExp& exp, void* arg)
inline AST*
parseFn(PEnv& penv, const SExp& exp, void* arg)
{
+ if (exp.list.size() < 2)
+ throw Error("Missing function parameters and body", exp.loc);
+ else if (exp.list.size() < 3)
+ throw Error("Missing function body", exp.loc);
SExp::List::const_iterator a = exp.list.begin(); ++a;
- return new AClosure(exp.loc, penv.sym("fn"),
- new ATuple(penv.parseTuple(*a++)),
- penv.parse(*a++));
+ AClosure* ret = new AClosure(exp.loc, penv.sym("fn"), new ATuple(penv.parseTuple(*a++)));
+ while (a != exp.list.end())
+ ret->push_back(penv.parse(*a++));
+ return ret;
}
@@ -200,6 +205,7 @@ main(int argc, char** argv)
initLang(penv, tenv);
CEnv* cenv = newCenv(penv, tenv);
+ cenv->push();
map<string,string> args;
list<string> files;
diff --git a/tuplr.hpp b/tuplr.hpp
index a16072d..4d714d4 100644
--- a/tuplr.hpp
+++ b/tuplr.hpp
@@ -67,25 +67,53 @@ struct Exp {
List list;
};
+/// Lexical Address
+struct LAddr {
+ LAddr(unsigned u=0, unsigned o=0) : up(u), over(o) {}
+ operator bool() const { return !(up == 0 && over == 0); }
+ unsigned up, over;
+};
+
+inline ostream& operator<<(ostream& out, const LAddr& addr) {
+ out << addr.up << ":" << addr.over;
+ return out;
+}
+
/// Generic Lexical Environment
template<typename K, typename V>
-struct Env : public list< map<K,V> > {
- typedef map<K,V> Frame;
+struct Env : public list< vector< pair<K,V> > > {
+ typedef vector< pair<K,V> > Frame;
Env() : list<Frame>(1) {}
- void push() { list<Frame>::push_front(Frame()); }
- void pop() { assert(!this->empty()); list<Frame>::pop_front(); }
+ void push(Frame f=Frame()) { list<Frame>::push_front(f); }
+ void pop() { assert(!this->empty()); list<Frame>::pop_front(); }
const V& def(const K& k, const V& v) {
- typename Frame::iterator existing = this->front().find(k);
- if (existing != this->front().end() && existing->second != v)
- throw Error("redefinition");
- return (this->front()[k] = v);
+ for (typename Frame::iterator b = this->begin()->begin(); b != this->begin()->end(); ++b)
+ if (b->first == k)
+ return (b->second = v);
+ this->front().push_back(make_pair(k, v));
+ return v;
+ }
+ V* ref(const K& key) {
+ for (typename Env::iterator f = this->begin(); f != this->end(); ++f)
+ for (typename Frame::iterator b = f->begin(); b != f->end(); ++b)
+ if (b->first == key)
+ return &b->second;
+ return NULL;
}
- V* ref(const K& name) {
- typename Frame::iterator s;
- for (typename Env::iterator i = this->begin(); i != this->end(); ++i)
- if ((s = i ->find(name)) != i->end())
- return &s->second;
- return 0;
+ LAddr lookup(const K& key) const {
+ unsigned up = 0;
+ for (typename Env::const_iterator f = this->begin(); f != this->end(); ++f, ++up)
+ for (unsigned over = 0; over < f->size(); ++over)
+ if ((*f)[over].first == key)
+ return LAddr(up + 1, over + 1);
+ return LAddr();
+ }
+ V& deref(LAddr addr) {
+ assert(addr);
+ typename Env::iterator f = this->begin();
+ for (unsigned u = 1; u < addr.up; ++u, ++f) { assert(f != this->end()); }
+ assert(f->size() > addr.over - 1);
+ return (*f)[addr.over - 1].second;
}
};
@@ -112,10 +140,13 @@ typedef void* CEngine; ///< Compiler Engine (opaque)
* Abstract Syntax Tree *
***************************************************************************/
-struct TEnv; ///< Type-Time Environment
-struct CEnv; ///< Compile-Time Environment
-
+struct Constraint; ///< Type Constraint
+struct TEnv; ///< Type-Time Environment
+struct CEnv; ///< Compile-Time Environment
struct AST;
+struct Constraints;
+struct Subst;
+
extern ostream& operator<<(ostream& out, const AST* ast);
/// Base class for all AST nodes
@@ -124,7 +155,7 @@ struct AST {
virtual ~AST() {}
virtual bool operator==(const AST& o) const = 0;
virtual bool contains(const AST* child) const { return false; }
- virtual void constrain(TEnv& tenv) const {}
+ virtual void constrain(TEnv& tenv, Constraints& c) {}
virtual void lift(CEnv& cenv) {}
string str() const { ostringstream ss; ss << this; return ss.str(); }
Cursor loc;
@@ -141,7 +172,7 @@ struct ALiteral : public AST {
const ALiteral<VT>* r = dynamic_cast<const ALiteral<VT>*>(&rhs);
return (r && (val == r->val));
}
- void constrain(TEnv& tenv) const;
+ void constrain(TEnv& tenv, Constraints& c);
CValue compile(CEnv& cenv);
const VT val;
};
@@ -149,13 +180,16 @@ struct ALiteral : public AST {
/// Symbol, e.g. "a"
struct ASymbol : public AST {
bool operator==(const AST& rhs) const { return this == &rhs; }
- void lift(CEnv& cenv);
+ void lookup(TEnv& tenv);
+ void constrain(TEnv& tenv, Constraints& c);
CValue compile(CEnv& cenv);
+ LAddr addr;
private:
friend class PEnv;
ASymbol(const string& s, Cursor c) : AST(c), cppstr(s) {}
friend ostream& operator<<(ostream&, const AST*);
const string cppstr;
+
};
/// Tuple (heterogeneous sequence of fixed length), e.g. "(a b c)"
@@ -163,6 +197,8 @@ struct ATuple : public AST, public vector<AST*> {
ATuple(const vector<AST*>& t=vector<AST*>(), Cursor c=Cursor()) : AST(c), vector<AST*>(t) {}
ATuple(size_t size, Cursor c) : AST(c), vector<AST*>(size) {}
ATuple(Cursor c, AST* ast, ...) : AST(c) {
+ if (!ast)
+ return;
va_list args; va_start(args, ast);
push_back(ast);
for (AST* a = va_arg(args, AST*); a; a = va_arg(args, AST*))
@@ -189,23 +225,23 @@ struct ATuple : public AST, public vector<AST*> {
return true;
return false;
}
- void constrain(TEnv& tenv) const;
+ void constrain(TEnv& tenv, Constraints& c);
CValue compile(CEnv& cenv) { throw Error("tuple compiled"); }
};
/// Type Expression, e.g. "Int", "(Fn (Int Int) Float)"
struct AType : public ATuple {
- AType(unsigned i, Cursor c=Cursor()) : ATuple(0, c), kind(VAR), id(i) {}
+ AType(unsigned i, LAddr a, Cursor c=Cursor()) : ATuple(0, c), kind(VAR), addr(a), id(i) {}
AType(ASymbol* s) : ATuple(0, s->loc), kind(PRIM), id(0) { push_back(s); }
AType(const ATuple& t, Cursor c) : ATuple(t, c), kind(EXPR), id(0) {}
AType(Cursor c, AST* ast, ...) : ATuple(0, c), kind(EXPR), id(0) {
va_list args; va_start(args, ast);
+ if (!ast) return;
push_back(ast);
for (AST* a = va_arg(args, AST*); a; a = va_arg(args, AST*))
push_back(a);
va_end(args);
}
- void constrain(TEnv& tenv) const {}
CValue compile(CEnv& cenv) { return NULL; }
bool var() const { return kind == VAR; }
bool concrete() const {
@@ -234,6 +270,7 @@ struct AType : public ATuple {
return false; // never reached
}
enum { VAR, PRIM, EXPR } kind;
+ LAddr addr;
unsigned id;
};
@@ -249,22 +286,25 @@ struct Funcs : public list< pair<AType*, CFunction> > {
/// Closure (first-class function with captured lexical bindings)
struct AClosure : public ATuple {
- AClosure(Cursor c, ASymbol* fn, ATuple* p, AST* b, const string& n="")
- : ATuple(c, fn, p, b, NULL), name(n) {}
+ AClosure(Cursor c, ASymbol* fn, ATuple* p, const string& n="")
+ : ATuple(c, fn, p, NULL), type(0), subst(0), name(n) {}
bool operator==(const AST& rhs) const { return this == &rhs; }
- void constrain(TEnv& tenv) const;
+ void constrain(TEnv& tenv, Constraints& c);
void lift(CEnv& cenv);
+ void liftPoly(CEnv& cenv, const vector<AType*>& argsT);
CValue compile(CEnv& cenv);
ATuple* prot() const { return dynamic_cast<ATuple*>(at(1)); }
-private:
+ AType* type;
Funcs funcs;
+ Subst* subst;
+private:
string name;
};
/// Function call/application, e.g. "(func arg1 arg2)"
struct ACall : public ATuple {
ACall(const SExp& e, const ATuple& t) : ATuple(t, e.loc) {}
- void constrain(TEnv& tenv) const;
+ void constrain(TEnv& tenv, Constraints& c);
void lift(CEnv& cenv);
CValue compile(CEnv& cenv);
};
@@ -272,7 +312,8 @@ struct ACall : public ATuple {
/// Definition special form, e.g. "(def x 2)"
struct ADefinition : public ACall {
ADefinition(const SExp& e, const ATuple& t) : ACall(e, t) {}
- void constrain(TEnv& tenv) const;
+ ASymbol* sym() const { return dynamic_cast<ASymbol*>(at(1)); }
+ void constrain(TEnv& tenv, Constraints& c);
void lift(CEnv& cenv);
CValue compile(CEnv& cenv);
};
@@ -280,14 +321,14 @@ struct ADefinition : public ACall {
/// Conditional special form, e.g. "(if cond thenexp elseexp)"
struct AIf : public ACall {
AIf(const SExp& e, const ATuple& t) : ACall(e, t) {}
- void constrain(TEnv& tenv) const;
+ void constrain(TEnv& tenv, Constraints& c);
CValue compile(CEnv& cenv);
};
/// Primitive (builtin arithmetic function), e.g. "(+ 2 3)"
struct APrimitive : public ACall {
APrimitive(const SExp& e, const ATuple& t) : ACall(e, t) {}
- void constrain(TEnv& tenv) const;
+ void constrain(TEnv& tenv, Constraints& c);
CValue compile(CEnv& cenv);
};
@@ -295,7 +336,7 @@ struct APrimitive : public ACall {
struct AConsCall : public ACall {
AConsCall(const SExp& e, const ATuple& t) : ACall(e, t) {}
AType* functionType(CEnv& cenv);
- void constrain(TEnv& tenv) const;
+ void constrain(TEnv& tenv, Constraints& c);
void lift(CEnv& cenv);
CValue compile(CEnv& cenv);
static Funcs funcs;
@@ -304,14 +345,14 @@ struct AConsCall : public ACall {
/// Car special form, e.g. "(car p)"
struct ACarCall : public ACall {
ACarCall(const SExp& e, const ATuple& t) : ACall(e, t) {}
- void constrain(TEnv& tenv) const;
+ void constrain(TEnv& tenv, Constraints& c);
CValue compile(CEnv& cenv);
};
/// Cdr special form, e.g. "(cdr p)"
struct ACdrCall : public ACall {
ACdrCall(const SExp& e, const ATuple& t) : ACall(e, t) {}
- void constrain(TEnv& tenv) const;
+ void constrain(TEnv& tenv, Constraints& c);
CValue compile(CEnv& cenv);
};
@@ -375,33 +416,70 @@ struct PEnv : private map<const string, ASymbol*> {
* Typing *
***************************************************************************/
+struct Constraint : public pair<AType*,AType*> {
+ Constraint(AType* a, AType* b, Cursor c) : pair<AType*,AType*>(a, b), loc(c) {}
+ Cursor loc;
+};
+
+struct Constraints : public list<Constraint> {
+ void constrain(TEnv& tenv, const AST* o, AType* t);
+};
+
+inline ostream& operator<<(ostream& out, const Constraints& c) {
+ for (Constraints::const_iterator i = c.begin(); i != c.end(); ++i)
+ out << i->first << " : " << i->second << endl;
+ return out;
+}
+
+struct Subst : public map<const AType*,AType*> {
+ Subst(AType* s=0, AType* t=0) { if (s && t) insert(make_pair(s, t)); }
+ static Subst compose(const Subst& delta, const Subst& gamma);
+ AST* apply(AST* ast) const {
+ AType* in = dynamic_cast<AType*>(ast);
+ if (!in) return ast;
+ if (in->kind == AType::EXPR) {
+ AType* out = new AType(in->loc, NULL);
+ for (size_t i = 0; i < in->size(); ++i)
+ out->push_back(apply(in->at(i)));
+ return out;
+ } else {
+ const_iterator i;
+ while ((i = find(in)) != end())
+ in = i->second;
+ return in;
+ }
+ }
+};
+
/// Type-Time Environment
struct TEnv : public Env<const AST*,AType*> {
TEnv(PEnv& p) : penv(p), varID(1) {}
- struct Constraint : public pair<AType*,AType*> {
- Constraint(AType* a, AType* b, Cursor c) : pair<AType*,AType*>(a, b), loc(c) {}
- Cursor loc;
- };
- struct Subst : public map<AType*, AType*> {
- Subst(AType* s=0, AType* t=0) { if (s && t) insert(make_pair(s, t)); }
- };
- typedef list<Constraint> Constraints;
- AType* var(Cursor c=Cursor()) { return new AType(varID++, c); }
- AType* type(const AST* ast) {
- AType** t = ref(ast);
- return t ? *t : def(ast, var(ast->loc));
+
+ AType* fresh(const ASymbol* sym) {
+ assert(sym);
+ return def(sym, new AType(varID++, LAddr(), sym->loc));
+ }
+ AType* var(const AST* ast=0) {
+ const ASymbol* sym = dynamic_cast<const ASymbol*>(ast);
+ if (sym)
+ return deref(lookup(sym));
+
+ map<const AST*, AType*>::iterator v = vars.find(ast);
+ if (v != vars.end())
+ return v->second;
+
+ AType* ret = new AType(varID++, LAddr(), ast ? ast->loc : Cursor());
+ if (ast)
+ vars[ast] = ret;
+
+ return ret;
}
AType* named(const string& name) {
return *ref(penv.sym(name));
}
- void constrain(const AST* o, AType* t) {
- assert(!dynamic_cast<const AType*>(o));
- constraints.push_back(Constraint(type(o), t, o->loc));
- }
- void solve() { apply(unify(constraints)); }
- void apply(const Subst& substs);
static Subst unify(const Constraints& c);
+ map<const AST*, AType*> vars;
PEnv& penv;
Constraints constraints;
unsigned varID;
@@ -422,21 +500,29 @@ struct CEnv {
CEngine engine();
string gensym(const char* s="_") { return (format("%s%d") % s % symID++).str(); }
- void push() { code.push(); vals.push(); }
- void pop() { code.pop(); vals.pop(); }
+ void push() { code.push(); vals.push(); tenv.push(); }
+ void pop() { code.pop(); vals.pop(); tenv.pop(); }
void precompile(AST* obj, CValue value) { vals.def(obj, value); }
CValue compile(AST* obj);
void optimise(CFunction f);
void write(std::ostream& os);
+ AType* type(AST* ast, const Subst& subst = Subst()) const {
+ ASymbol* sym = dynamic_cast<ASymbol*>(ast);
+ if (sym)
+ return tenv.deref(sym->addr);
+ return dynamic_cast<AType*>(tsubst.apply(subst.apply(tenv.vars[ast])));
+ }
- ostream& out;
- ostream& err;
- PEnv& penv;
- TEnv& tenv;
- Code code;
- Vals vals;
- unsigned symID;
- CFunction alloc;
+ ostream& out;
+ ostream& err;
+ PEnv& penv;
+ TEnv& tenv;
+ Code code;
+ Vals vals;
+
+ unsigned symID;
+ CFunction alloc;
+ Subst tsubst;
private:
struct PImpl; ///< Private Implementation
diff --git a/typing.cpp b/typing.cpp
index fdea255..ef28409 100644
--- a/typing.cpp
+++ b/typing.cpp
@@ -15,78 +15,157 @@
* along with Tuplr. If not, see <http://www.gnu.org/licenses/>.
*/
+#include <set>
#include "tuplr.hpp"
+void
+Constraints::constrain(TEnv& tenv, const AST* o, AType* t)
+{
+ assert(!dynamic_cast<const AType*>(o));
+ push_back(Constraint(tenv.var(o), t, o->loc));
+}
+
/***************************************************************************
* AST Type Constraints *
***************************************************************************/
void
-ATuple::constrain(TEnv& tenv) const
+ASymbol::lookup(TEnv& tenv)
+{
+ addr = tenv.lookup(this);
+ if (!addr)
+ throw Error((format("undefined symbol `%1%'") % cppstr).str(), loc);
+}
+
+void
+ASymbol::constrain(TEnv& tenv, Constraints& c)
+{
+ lookup(tenv);
+ AType* t = tenv.deref(addr);
+ if (!t)
+ throw Error((format("unresolved symbol `%1%'") % cppstr).str(), loc);
+ c.push_back(Constraint(tenv.var(this), tenv.deref(addr), loc));
+}
+
+void
+ATuple::constrain(TEnv& tenv, Constraints& c)
{
AType* t = new AType(ATuple(), loc);
- FOREACH(const_iterator, p, *this) {
- (*p)->constrain(tenv);
- t->push_back(tenv.type(*p));
+ FOREACH(iterator, p, *this) {
+ (*p)->constrain(tenv, c);
+ t->push_back(tenv.var(*p));
}
- tenv.constrain(this, t);
+ c.push_back(Constraint(tenv.var(this), t, loc));
}
void
-AClosure::constrain(TEnv& tenv) const
+AClosure::constrain(TEnv& tenv, Constraints& c)
{
- at(1)->constrain(tenv);
- at(2)->constrain(tenv);
- AType* protT = tenv.type(at(1));
- AType* bodyT = tenv.type(at(2));
- tenv.constrain(this, new AType(loc, tenv.penv.sym("Fn"), protT, bodyT, 0));
+ set<ASymbol*> defined;
+ TEnv::Frame frame;
+
+ // Add parameters to environment frame
+ for (size_t i = 0; i < prot()->size(); ++i) {
+ ASymbol* sym = dynamic_cast<ASymbol*>(prot()->at(i));
+ if (!sym)
+ throw Error("parameter name is not a symbol", prot()->at(i)->loc);
+ if (defined.find(sym) != defined.end())
+ throw Error((format("duplicate parameter `%1%'") % sym->str()).str(), sym->loc);
+ defined.insert(sym);
+ frame.push_back(make_pair(sym, (AType*)NULL));
+ }
+
+ // Add internal definitions to environment frame
+ size_t e = 2;
+ for (; e < size(); ++e) {
+ AST* exp = at(e);
+ ADefinition* def = dynamic_cast<ADefinition*>(exp);
+ if (def) {
+ ASymbol* sym = def->sym();
+ if (defined.find(sym) != defined.end())
+ throw Error((format("`%1%' defined twice") % sym->str()).str(), def->loc);
+ defined.insert(def->sym());
+ frame.push_back(make_pair(def->sym(), (AType*)NULL));
+ }
+ }
+
+ tenv.push(frame);
+
+ Constraints cp;
+ cp.push_back(Constraint(tenv.var(this), tenv.var(), loc));
+
+ AType* protT = new AType(ATuple(), loc);
+ for (size_t i = 0; i < prot()->size(); ++i) {
+ AType* tvar = tenv.fresh(dynamic_cast<ASymbol*>(prot()->at(i)));
+ protT->push_back(tvar);
+ assert(frame[i].first == prot()->at(i));
+ frame[i].second = tvar;
+ }
+ c.push_back(Constraint(tenv.var(at(1)), protT, at(1)->loc));
+
+ for (size_t i = 2; i < size(); ++i)
+ at(i)->constrain(tenv, cp);
+
+ AType* bodyT = tenv.var(at(e-1));
+ Subst tsubst = TEnv::unify(cp);
+ type = new AType(loc, tenv.penv.sym("Fn"), tsubst.apply(protT), tsubst.apply(bodyT), 0);
+
+ tenv.pop();
+
+ c.constrain(tenv, this, type);
+ subst = new Subst(tsubst);
}
void
-ACall::constrain(TEnv& tenv) const
+ACall::constrain(TEnv& tenv, Constraints& c)
{
- FOREACH(const_iterator, p, *this)
- (*p)->constrain(tenv);
- AType* retT = tenv.type(this);
AType* argsT = new AType(ATuple(), loc);
- for (size_t i = 1; i < size(); ++i)
- argsT->push_back(tenv.type(at(i)));
- tenv.constrain(at(0), new AType(at(0)->loc, tenv.penv.sym("Fn"), argsT, retT, 0));
+ TEnv::Frame frame;
+ for (size_t i = 1; i < size(); ++i) {
+ at(i)->constrain(tenv, c);
+ argsT->push_back(tenv.var(at(i)));
+ frame.push_back(make_pair((AST*)NULL, tenv.var(at(i))));
+ }
+ AType* retT = tenv.var();
+
+ at(0)->constrain(tenv, c);
+
+ c.constrain(tenv, at(0), new AType(at(0)->loc, tenv.penv.sym("Fn"), argsT, retT, 0));
+ c.constrain(tenv, this, retT);
}
void
-ADefinition::constrain(TEnv& tenv) const
+ADefinition::constrain(TEnv& tenv, Constraints& c)
{
if (size() != 3) throw Error("`def' requires exactly 2 arguments", loc);
if (!dynamic_cast<const ASymbol*>(at(1)))
throw Error("`def' name is not a symbol", loc);
- FOREACH(const_iterator, p, *this)
- (*p)->constrain(tenv);
- AType* tvar = tenv.type(this);
- tenv.constrain(at(1), tvar);
- tenv.constrain(at(2), tvar);
+ AType* tvar = tenv.var(this);
+ tenv.def(at(1), tvar);
+ at(2)->constrain(tenv, c);
+ c.constrain(tenv, at(2), tvar);
}
void
-AIf::constrain(TEnv& tenv) const
+AIf::constrain(TEnv& tenv, Constraints& c)
{
if (size() < 3) throw Error("`if' requires exactly 3 arguments", loc);
if (size() % 2 != 0) throw Error("`if' missing final else clause", loc);
- FOREACH(const_iterator, p, *this)
- (*p)->constrain(tenv);
- AType* retT = tenv.type(this);
+ for (size_t i = 1; i < size(); ++i)
+ at(i)->constrain(tenv, c);
+ AType* retT = tenv.var(this);
for (size_t i = 1; i < size(); i += 2) {
if (i == size() - 1) {
- tenv.constrain(at(i), retT);
+ c.constrain(tenv, at(i), retT);
} else {
- tenv.constrain(at(i), tenv.named("Bool"));
- tenv.constrain(at(i+1), retT);
+ c.constrain(tenv, at(i), tenv.named("Bool"));
+ c.constrain(tenv, at(i+1), retT);
}
}
}
void
-APrimitive::constrain(TEnv& tenv) const
+APrimitive::constrain(TEnv& tenv, Constraints& c)
{
const string n = dynamic_cast<ASymbol*>(at(0))->str();
enum { ARITHMETIC, BINARY, LOGICAL, COMPARISON } type;
@@ -101,34 +180,34 @@ APrimitive::constrain(TEnv& tenv) const
else
throw Error((format("unknown primitive `%1%'") % n).str(), loc);
- FOREACH(const_iterator, p, *this)
- (*p)->constrain(tenv);
+ for (size_t i = 1; i < size(); ++i)
+ at(i)->constrain(tenv, c);
switch (type) {
case ARITHMETIC:
if (size() < 3)
throw Error((format("`%1%' requires at least 2 arguments") % n).str(), loc);
for (size_t i = 1; i < size(); ++i)
- tenv.constrain(at(i), tenv.type(this));
+ c.constrain(tenv, at(i), tenv.var(this));
break;
case BINARY:
if (size() != 3)
throw Error((format("`%1%' requires exactly 2 arguments") % n).str(), loc);
- tenv.constrain(at(1), tenv.type(this));
- tenv.constrain(at(2), tenv.type(this));
+ c.constrain(tenv, at(1), tenv.var(this));
+ c.constrain(tenv, at(2), tenv.var(this));
break;
case LOGICAL:
if (size() != 3)
throw Error((format("`%1%' requires exactly 2 arguments") % n).str(), loc);
- tenv.constrain(this, tenv.named("Bool"));
- tenv.constrain(at(1), tenv.named("Bool"));
- tenv.constrain(at(2), tenv.named("Bool"));
+ c.constrain(tenv, this, tenv.named("Bool"));
+ c.constrain(tenv, at(1), tenv.named("Bool"));
+ c.constrain(tenv, at(2), tenv.named("Bool"));
break;
case COMPARISON:
if (size() != 3)
throw Error((format("`%1%' requires exactly 2 arguments") % n).str(), loc);
- tenv.constrain(this, tenv.named("Bool"));
- tenv.constrain(at(1), tenv.type(at(2)));
+ c.constrain(tenv, this, tenv.named("Bool"));
+ c.constrain(tenv, at(1), tenv.var(at(2)));
break;
default:
throw Error((format("unknown primitive `%1%'") % n).str(), loc);
@@ -136,37 +215,37 @@ APrimitive::constrain(TEnv& tenv) const
}
void
-AConsCall::constrain(TEnv& tenv) const
+AConsCall::constrain(TEnv& tenv, Constraints& c)
{
if (size() != 3) throw Error("`cons' requires exactly 2 arguments", loc);
AType* t = new AType(loc, tenv.penv.sym("Pair"), 0);
for (size_t i = 1; i < size(); ++i) {
- at(i)->constrain(tenv);
- t->push_back(tenv.type(at(i)));
+ at(i)->constrain(tenv, c);
+ t->push_back(tenv.var(at(i)));
}
- tenv.constrain(this, t);
+ c.constrain(tenv, this, t);
}
void
-ACarCall::constrain(TEnv& tenv) const
+ACarCall::constrain(TEnv& tenv, Constraints& c)
{
if (size() != 2) throw Error("`car' requires exactly 1 argument", loc);
- at(1)->constrain(tenv);
- AType* carT = tenv.type(this);
+ at(1)->constrain(tenv, c);
+ AType* carT = tenv.var(this);
AType* pairT = new AType(at(1)->loc, tenv.penv.sym("Pair"), carT, tenv.var(), 0);
- tenv.constrain(at(1), pairT);
- tenv.constrain(this, carT);
+ c.constrain(tenv, at(1), pairT);
+ c.constrain(tenv, this, carT);
}
void
-ACdrCall::constrain(TEnv& tenv) const
+ACdrCall::constrain(TEnv& tenv, Constraints& c)
{
if (size() != 2) throw Error("`cdr' requires exactly 1 argument", loc);
- at(1)->constrain(tenv);
- AType* cdrT = tenv.type(this);
+ at(1)->constrain(tenv, c);
+ AType* cdrT = tenv.var(this);
AType* pairT = new AType(at(1)->loc, tenv.penv.sym("Pair"), tenv.var(), cdrT, 0);
- tenv.constrain(at(1), pairT);
- tenv.constrain(this, cdrT);
+ c.constrain(tenv, at(1), pairT);
+ c.constrain(tenv, this, cdrT);
}
@@ -175,25 +254,26 @@ ACdrCall::constrain(TEnv& tenv) const
***************************************************************************/
static void
-substitute(ATuple* tup, AST* from, AST* to)
+substitute(ATuple* tup, const AST* from, AST* to)
{
if (!tup) return;
for (size_t i = 0; i < tup->size(); ++i)
if (*tup->at(i) == *from)
tup->at(i) = to;
- else
+ else if (tup->at(i) != to)
substitute(dynamic_cast<ATuple*>(tup->at(i)), from, to);
}
-TEnv::Subst
-compose(const TEnv::Subst& delta, const TEnv::Subst& gamma) // TAPL 22.1.1
+
+Subst
+Subst::compose(const Subst& delta, const Subst& gamma) // TAPL 22.1.1
{
- TEnv::Subst r;
- for (TEnv::Subst::const_iterator g = gamma.begin(); g != gamma.end(); ++g) {
- TEnv::Subst::const_iterator d = delta.find(g->second);
+ Subst r;
+ for (Subst::const_iterator g = gamma.begin(); g != gamma.end(); ++g) {
+ Subst::const_iterator d = delta.find(g->second);
r.insert(make_pair(g->first, ((d != delta.end()) ? d : g)->second));
}
- for (TEnv::Subst::const_iterator d = delta.begin(); d != delta.end(); ++d) {
+ for (Subst::const_iterator d = delta.begin(); d != delta.end(); ++d) {
if (gamma.find(d->first) == gamma.end())
r.insert(*d);
}
@@ -201,10 +281,10 @@ compose(const TEnv::Subst& delta, const TEnv::Subst& gamma) // TAPL 22.1.1
}
void
-substConstraints(TEnv::Constraints& constraints, AType* s, AType* t)
+substConstraints(Constraints& constraints, AType* s, AType* t)
{
- for (TEnv::Constraints::iterator c = constraints.begin(); c != constraints.end();) {
- TEnv::Constraints::iterator next = c; ++next;
+ for (Constraints::iterator c = constraints.begin(); c != constraints.end();) {
+ Constraints::iterator next = c; ++next;
if (*c->first == *s) c->first = t;
if (*c->second == *s) c->second = t;
substitute(c->first, s, t);
@@ -213,7 +293,7 @@ substConstraints(TEnv::Constraints& constraints, AType* s, AType* t)
}
}
-TEnv::Subst
+Subst
TEnv::unify(const Constraints& constraints) // TAPL 22.4
{
if (constraints.empty()) return Subst();
@@ -226,10 +306,10 @@ TEnv::unify(const Constraints& constraints) // TAPL 22.4
return unify(cp);
} else if (s->var() && !t->contains(s)) {
substConstraints(cp, s, t);
- return compose(unify(cp), Subst(s, t));
+ return Subst::compose(unify(cp), Subst(s, t));
} else if (t->var() && !s->contains(t)) {
substConstraints(cp, t, s);
- return compose(unify(cp), Subst(t, s));
+ return Subst::compose(unify(cp), Subst(t, s));
} else if (s->kind == AType::EXPR && s->kind == t->kind && s->size() == t->size()) {
for (size_t i = 0; i < s->size(); ++i) {
AType* si = dynamic_cast<AType*>(s->at(i));
@@ -244,14 +324,3 @@ TEnv::unify(const Constraints& constraints) // TAPL 22.4
}
}
-void
-TEnv::apply(const TEnv::Subst& substs)
-{
- FOREACH(Subst::const_iterator, s, substs)
- FOREACH(Frame::iterator, t, front())
- if (*t->second == *s->first)
- t->second = s->second;
- else
- substitute(t->second, s->first, s->second);
-}
-