aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--llvm.cpp6
-rw-r--r--tuplr.cpp6
-rw-r--r--tuplr.hpp28
-rw-r--r--typing.cpp46
4 files changed, 56 insertions, 30 deletions
diff --git a/llvm.cpp b/llvm.cpp
index 240cb2f..2493056 100644
--- a/llvm.cpp
+++ b/llvm.cpp
@@ -308,7 +308,7 @@ maybeLookup(CEnv& cenv, AST* ast)
{
ASymbol* s = dynamic_cast<ASymbol*>(ast);
if (s && s->addr)
- return cenv.code.deref(s->addr);
+ return cenv.tenv.deref(s->addr).first;
return ast;
}
@@ -371,8 +371,6 @@ ACall::compile(CEnv& cenv)
void
ADefinition::lift(CEnv& cenv)
{
- if (cenv.code.lookup(at(1)->as<ASymbol*>()))
- throw Error(string("`") + at(1)->str() + "' redefined", loc);
// Define first for recursion
cenv.def(at(1)->as<ASymbol*>(), at(2), cenv.type(at(2)), NULL);
at(2)->lift(cenv);
@@ -381,6 +379,8 @@ ADefinition::lift(CEnv& cenv)
CValue
ADefinition::compile(CEnv& cenv)
{
+ // Define first for recursion
+ cenv.def(at(1)->as<ASymbol*>(), at(2), cenv.type(at(2)), NULL);
return cenv.compile(at(2));
}
diff --git a/tuplr.cpp b/tuplr.cpp
index 3429a9f..1a059a5 100644
--- a/tuplr.cpp
+++ b/tuplr.cpp
@@ -143,9 +143,9 @@ void
initLang(PEnv& penv, TEnv& tenv)
{
// Types
- tenv.def(penv.sym("Bool"), new AType(penv.sym("Bool")));
- tenv.def(penv.sym("Int"), new AType(penv.sym("Int")));
- tenv.def(penv.sym("Float"), new AType(penv.sym("Float")));
+ tenv.def(penv.sym("Bool"), make_pair((AST*)NULL, new AType(penv.sym("Bool"))));
+ tenv.def(penv.sym("Int"), make_pair((AST*)NULL, new AType(penv.sym("Int"))));
+ tenv.def(penv.sym("Float"), make_pair((AST*)NULL, new AType(penv.sym("Float"))));
// Literals
static bool trueVal = true;
diff --git a/tuplr.hpp b/tuplr.hpp
index 506742a..5e60e75 100644
--- a/tuplr.hpp
+++ b/tuplr.hpp
@@ -453,17 +453,18 @@ struct Subst : public map<const AType*,AType*> {
};
/// Type-Time Environment
-struct TEnv : public Env<const ASymbol*,AType*> {
+struct TEnv : public Env< const ASymbol*, pair<AST*, AType*> > {
TEnv(PEnv& p) : penv(p), varID(1) {}
-
AType* fresh(const ASymbol* sym) {
assert(sym);
- return def(sym, new AType(varID++, LAddr(), sym->loc));
+ AType* ret = new AType(varID++, LAddr(), sym->loc);
+ def(sym, make_pair((AST*)NULL, ret));
+ return ret;
}
AType* var(const AST* ast=0) {
const ASymbol* sym = dynamic_cast<const ASymbol*>(ast);
if (sym)
- return deref(lookup(sym));
+ return deref(lookup(sym)).second;
Vars::iterator v = vars.find(ast);
if (v != vars.end())
@@ -476,7 +477,14 @@ struct TEnv : public Env<const ASymbol*,AType*> {
return ret;
}
AType* named(const string& name) {
- return *ref(penv.sym(name));
+ return ref(penv.sym(name))->second;
+ }
+ AST* resolve(AST* ast) {
+ ASymbol* sym = dynamic_cast<ASymbol*>(ast);
+ if (sym)
+ return ref(sym)->first;
+ else
+ return ast;
}
static Subst unify(const Constraints& c);
@@ -504,8 +512,8 @@ struct CEnv {
CEngine engine();
string gensym(const char* s="_") { return (format("%s%d") % s % symID++).str(); }
- void push() { code.push(); vals.push(); tenv.push(); }
- void pop() { code.pop(); vals.pop(); tenv.pop(); }
+ void push() { tenv.push(); vals.push(); }
+ void pop() { tenv.pop(); vals.pop(); }
void precompile(AST* obj, CValue value) { vals.def(obj, value); }
CValue compile(AST* obj);
void optimise(CFunction f);
@@ -513,18 +521,18 @@ struct CEnv {
AType* type(AST* ast, const Subst& subst = Subst()) const {
ASymbol* sym = dynamic_cast<ASymbol*>(ast);
if (sym)
- return tenv.deref(sym->addr);
+ return tenv.deref(sym->addr).second;
return dynamic_cast<AType*>(tsubst.apply(subst.apply(tenv.vars[ast])));
}
void def(ASymbol* sym, AST* c, AType* t, CValue v) {
- code.def(sym, c); tenv.def(sym, t); vals.def(sym, v);
+ tenv.def(sym, make_pair(c, t));
+ vals.def(sym, v);
}
ostream& out;
ostream& err;
PEnv& penv;
TEnv& tenv;
- Code code;
Vals vals;
unsigned symID;
diff --git a/typing.cpp b/typing.cpp
index 105140e..e042283 100644
--- a/typing.cpp
+++ b/typing.cpp
@@ -35,10 +35,10 @@ ASymbol::constrain(TEnv& tenv, Constraints& c) const
addr = tenv.lookup(this);
if (!addr)
throw Error((format("undefined symbol `%1%'") % cppstr).str(), loc);
- AType* t = tenv.deref(addr);
- assert(t);
- t->addr = addr;
- c.push_back(Constraint(tenv.var(this), t, loc));
+ pair<AST*, AType*>& t = tenv.deref(addr);
+ AType* tvar = tenv.var(t.second);
+ c.push_back(Constraint(tenv.var(this), tvar, loc));
+ c.push_back(Constraint(t.second, tvar, loc));
}
void
@@ -71,7 +71,7 @@ AClosure::constrain(TEnv& tenv, Constraints& c) const
if (defined.find(sym) != defined.end())
throw Error((format("duplicate parameter `%1%'") % sym->str()).str(), sym->loc);
defined.insert(sym);
- frame.push_back(make_pair(sym, (AType*)NULL));
+ frame.push_back(make_pair(sym, make_pair((AST*)NULL, (AType*)NULL)));
}
// Add internal definitions to environment frame
@@ -84,7 +84,7 @@ AClosure::constrain(TEnv& tenv, Constraints& c) const
if (defined.find(sym) != defined.end())
throw Error((format("`%1%' defined twice") % sym->str()).str(), def->loc);
defined.insert(def->sym());
- frame.push_back(make_pair(def->sym(), (AType*)NULL));
+ frame.push_back(make_pair(def->sym(), make_pair(def->at(2), (AType*)NULL)));
}
}
@@ -98,7 +98,8 @@ AClosure::constrain(TEnv& tenv, Constraints& c) const
AType* tvar = tenv.fresh(dynamic_cast<ASymbol*>(prot()->at(i)));
protT->push_back(tvar);
assert(frame[i].first == prot()->at(i));
- frame[i].second = tvar;
+ frame[i].second.first = prot()->at(i);
+ frame[i].second.second = tvar;
}
c.push_back(Constraint(tenv.var(at(1)), protT, at(1)->loc));
@@ -110,7 +111,6 @@ AClosure::constrain(TEnv& tenv, Constraints& c) const
genericType = new AType(loc, tenv.penv.sym("Fn"),
tsubst.apply(protT), tsubst.apply(bodyT), 0);
tenv.genericTypes.insert(make_pair(this, genericType));
- //tenv.def(this, genericType);
tenv.pop();
subst = new Subst(tsubst);
@@ -123,13 +123,28 @@ void
ACall::constrain(TEnv& tenv, Constraints& c) const
{
at(0)->constrain(tenv, c);
- AType* argsT = new AType(ATuple(), loc);
- for (size_t i = 1; i < size(); ++i) {
+ for (size_t i = 1; i < size(); ++i)
at(i)->constrain(tenv, c);
- argsT->push_back(tenv.var(at(i)));
+
+ AST* callee = tenv.resolve(at(0));
+ AClosure* closure = dynamic_cast<AClosure*>(callee);
+ if (closure) {
+ if (size() - 1 != closure->prot()->size())
+ throw Error("incorrect number of arguments", loc);
+ TEnv::GenericTypes::iterator gt = tenv.genericTypes.find(closure);
+ if (gt != tenv.genericTypes.end()) {
+ for (size_t i = 1; i < size(); ++i)
+ c.constrain(tenv, at(i), gt->second->at(1)->as<ATuple*>()->at(i-1)->as<AType*>());
+ AType* retT = tenv.var(this);
+ c.constrain(tenv, at(0), new AType(at(0)->loc, tenv.penv.sym("Fn"), tenv.var(), retT, 0));
+ c.constrain(tenv, this, retT);
+ return;
+ }
}
+ AType* argsT = new AType(ATuple(), loc);
+ for (size_t i = 1; i < size(); ++i)
+ argsT->push_back(tenv.var(at(i)));
AType* retT = tenv.var();
-
c.constrain(tenv, at(0), new AType(at(0)->loc, tenv.penv.sym("Fn"), argsT, retT, 0));
c.constrain(tenv, this, retT);
}
@@ -138,10 +153,13 @@ void
ADefinition::constrain(TEnv& tenv, Constraints& c) const
{
if (size() != 3) throw Error("`def' requires exactly 2 arguments", loc);
- if (!dynamic_cast<const ASymbol*>(at(1)))
+ const ASymbol* sym = dynamic_cast<const ASymbol*>(at(1));
+ if (!sym)
throw Error("`def' name is not a symbol", loc);
+ if (tenv.lookup(sym))
+ throw Error(string("`") + at(1)->str() + "' redefined", loc);
AType* tvar = tenv.var(at(2));
- tenv.def(at(1)->as<ASymbol*>(), tvar);
+ tenv.def(sym, make_pair(at(2), tvar));
at(2)->constrain(tenv, c);
c.constrain(tenv, this, tvar);
}