aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--llvm.cpp158
-rw-r--r--tuplr.hpp96
-rw-r--r--typing.cpp148
3 files changed, 222 insertions, 180 deletions
diff --git a/llvm.cpp b/llvm.cpp
index 32577f9..a599872 100644
--- a/llvm.cpp
+++ b/llvm.cpp
@@ -124,6 +124,7 @@ CEnv::compile(AST* obj)
void
CEnv::optimise(CFunction f)
{
+ _pimpl->module->dump();
verifyFunction(*static_cast<Function*>(f));
_pimpl->opt.run(*static_cast<Function*>(f));
}
@@ -139,13 +140,24 @@ CEnv::write(std::ostream& os)
template<> CValue \
ALiteral<CT>::compile(CEnv& cenv) { return (COMPILED); } \
template<> void \
-ALiteral<CT>::constrain(TEnv& tenv, Constraints& c) { c.constrain(tenv, this, tenv.named(NAME)); }
+ALiteral<CT>::constrain(TEnv& tenv, Constraints& c) const { c.constrain(tenv, this, tenv.named(NAME)); }
/// Literal template instantiations
LITERAL(int32_t, "Int", ConstantInt::get(Type::Int32Ty, val, true))
LITERAL(float, "Float", ConstantFP::get(Type::FloatTy, val))
LITERAL(bool, "Bool", ConstantInt::get(Type::Int1Ty, val, false))
+template<typename T>
+T
+checked_cast(AST* ast)
+{
+ T t = dynamic_cast<T>(ast);
+ if (!t)
+ throw Error((format("internal error: `%1%' should be a `%2%'")
+ % typeid(ast).name() % typeid(T).name()).str(), ast->loc);
+ return t;
+}
+
static Function*
compileFunction(CEnv& cenv, const std::string& name, const Type* retT, const ATuple& protT,
const vector<string> argNames=vector<string>())
@@ -154,8 +166,7 @@ compileFunction(CEnv& cenv, const std::string& name, const Type* retT, const ATu
vector<const Type*> cprot;
for (size_t i = 0; i < protT.size(); ++i) {
- AType* at = dynamic_cast<AType*>(protT.at(i));
- if (!at) throw Error("function parameter type isn't");
+ AType* at = checked_cast<AType*>(protT.at(i));
if (!lltype(at)) throw Error("function parameter is untyped");
cprot.push_back(lltype(at));
}
@@ -215,10 +226,8 @@ AClosure::lift(CEnv& cenv)
vector<Value*> args;
const_iterator p = prot()->begin();
size_t i = 0;
- for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p) {
- cenv.tenv.def(*p, dynamic_cast<AType*>(protT->at(i++)));
- cenv.vals.def(dynamic_cast<ASymbol*>(*p), &*a);
- }
+ for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p)
+ cenv.def(checked_cast<ASymbol*>(*p), *p, checked_cast<AType*>(protT->at(i++)), &*a);
// Write function body
try {
@@ -240,65 +249,71 @@ AClosure::lift(CEnv& cenv)
}
void
-AClosure::liftPoly(CEnv& cenv, const vector<AType*>& argsT)
+AClosure::liftCall(CEnv& cenv, const vector<AType*>& argsT)
{
- if (type->concrete())
- return;
-
- throw Error("No polymorphism");
+ TEnv::GenericTypes::const_iterator gt = cenv.tenv.genericTypes.find(this);
+ assert(gt != cenv.tenv.genericTypes.end());
+ AType* genericType = gt->second;
-#if 0
- //Subst tsubst;
+ // Find type and build substitution
assert(argsT.size() == prot()->size());
- for (size_t i = 0; i < argsT.size(); ++i) {
- cenv.err << " " << argsT.at(i)->str();
- //tsubst[*cenv.tenv.ref(prot()->at(i))] = argsT.at(i);
- }
- cenv.err << endl;
-#endif
-}
+ Subst argsSubst;
+ ATuple* genericProtT = dynamic_cast<ATuple*>(genericType->at(1));
+ assert(genericProtT);
+ for (size_t i = 0; i < argsT.size(); ++i)
+ argsSubst[dynamic_cast<AType*>(genericProtT->at(i))] = dynamic_cast<AType*>(argsT.at(i));
+
+ AType* thisType = new AType(*dynamic_cast<ATuple*>(argsSubst.apply(genericType)), loc);
+ cenv.err << "THIS TYPE: " << thisType << endl;
-CValue
-AClosure::compile(CEnv& cenv)
-{
- /*
- cenv.err << "***********************************************" << endl;
- cenv.err << cenv.type(this) << endl;
+ //if (!thisType->concrete())
+ // throw Error("unable to resolve concrete type for function", loc);
- cenv.err << "COMPILING FOR TYPE:";
- Subst tsubst;
- assert(cenv.code.front().size() == prot()->size());
- for (size_t i = 0; i < cenv.code.front().size(); ++i) {
- cenv.err << " (" << cenv.type(prot()->at(i))->str()
- << " -> " << cenv.type(cenv.code.front().at(i).second)->str() << ")";
- tsubst[cenv.tenv.types[prot()->at(i)]] =
- cenv.type(cenv.code.front().at(i).second);
- }
- cenv.err << endl;
+ if (funcs.find(thisType))
+ return;
- Subst subst = Subst::compose(tsubst, cenv.tsubst);
- AType* concreteType = subst.apply(type);
- if (!concreteType->concrete())
- throw Error("compiled function has non-concrete type", loc);
+ ATuple* protT = dynamic_cast<ATuple*>(thisType->at(1));
- cenv.err << "*********** CONCRETE TYPE: " << concreteType->str() << endl;
- */
+ // Write function declaration
+ string name = this->name == "" ? cenv.gensym("_fn") : this->name;
+ Function* f = compileFunction(cenv, name,
+ lltype(dynamic_cast<AType*>(thisType->at(thisType->size() - 1))),
+ *protT);
- //CValue ret = funcs.find(concreteType);
- //cenv.err << "VALUE FOR TYPE " << concreteType->str() << " : " << ret << endl;
- //return ret;
- return NULL;
+ cenv.push();
+ Subst oldSubst = cenv.tsubst;
+ cenv.tsubst = Subst::compose(cenv.tsubst, Subst::compose(argsSubst, *subst));
+
+ // Bind argument values in CEnv
+ vector<Value*> args;
+ const_iterator p = prot()->begin();
+ size_t i = 0;
+ for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p)
+ cenv.def(checked_cast<ASymbol*>(*p), *p, checked_cast<AType*>(protT->at(i++)), &*a);
+
+ // Write function body
+ try {
+ // Define value first for recursion
+ cenv.precompile(this, f);
+ funcs.push_back(make_pair(thisType, f));
+
+ CValue retVal = cenv.compile(at(2));
+ llengine(cenv)->builder.CreateRet(LLVal(retVal)); // Finish function
+ cenv.optimise(LLFunc(f));
+
+ } catch (Error& e) {
+ f->eraseFromParent(); // Error reading body, remove function
+ cenv.pop();
+ throw e;
+ }
+ cenv.tsubst = oldSubst;
+ cenv.pop();
}
-template<typename T>
-T
-checked_cast(AST* ast)
+CValue
+AClosure::compile(CEnv& cenv)
{
- T t = dynamic_cast<T>(ast);
- if (!t)
- throw Error((format("internal error: `%1%' should be a `%2%'")
- % typeid(ast).name() % typeid(T).name()).str(), ast->loc);
- return t;
+ return NULL;
}
static
@@ -306,7 +321,7 @@ AST*
maybeLookup(CEnv& cenv, AST* ast)
{
ASymbol* s = dynamic_cast<ASymbol*>(ast);
- if (s)
+ if (s && s->addr)
return cenv.code.deref(s->addr);
return ast;
}
@@ -325,18 +340,19 @@ ACall::lift(CEnv& cenv)
}
if (!c) return; // Primitive
-
- // Extend environment with bound and typed parameters
- cenv.push();
+
if (c->prot()->size() < size() - 1)
throw Error((format("too many arguments to function `%1%'") % at(0)->str()).str(), loc);
if (c->prot()->size() > size() - 1)
throw Error((format("too few arguments to function `%1%'") % at(0)->str()).str(), loc);
+ // Extend environment with bound and typed parameters
+ cenv.push();
+
for (size_t i = 1; i < size(); ++i)
- cenv.code.def(checked_cast<ASymbol*>(c->prot()->at(i-1)), at(i));
+ cenv.def(checked_cast<ASymbol*>(c->prot()->at(i-1)), at(i), cenv.type(at(i)), NULL);
- c->liftPoly(cenv, argsT); // Lift called closure
+ c->liftCall(cenv, argsT); // Lift called closure
cenv.pop(); // Restore environment
}
@@ -347,16 +363,13 @@ ACall::compile(CEnv& cenv)
if (!c) return NULL; // Primitive
- if (c->prot()->size() < size() - 1)
- throw Error((format("too many arguments to function `%1%'") % at(0)->str()).str(), loc);
- if (c->prot()->size() > size() - 1)
- throw Error((format("too few arguments to function `%1%'") % at(0)->str()).str(), loc);
-
AType* protT = new AType(loc, NULL);
for (size_t i = 1; i < size(); ++i)
protT->push_back(cenv.type(at(i)));
-
- AType* polyT = c->type;
+
+ TEnv::GenericTypes::const_iterator gt = cenv.tenv.genericTypes.find(c);
+ assert(gt != cenv.tenv.genericTypes.end());
+ AType* polyT = gt->second;
AType* fnT = new AType(loc, cenv.penv.sym("Fn"), protT, polyT->at(2), 0);
Function* f = (Function*)c->funcs.find(fnT);
@@ -374,7 +387,8 @@ ADefinition::lift(CEnv& cenv)
{
if (cenv.code.lookup(checked_cast<ASymbol*>(at(1))))
throw Error(string("`") + at(1)->str() + "' redefined", loc);
- cenv.code.def((ASymbol*)at(1), at(2)); // Define first for recursion
+ // Define first for recursion
+ cenv.def(checked_cast<ASymbol*>(at(1)), at(2), cenv.type(at(2)), NULL);
at(2)->lift(cenv);
}
@@ -630,11 +644,11 @@ eval(CEnv& cenv, const string& name, istream& is)
int
repl(CEnv& cenv)
{
- Constraints c;
while (1) {
cenv.out << "() ";
cenv.out.flush();
Cursor cursor("(stdin)");
+ Constraints c;
try {
SExp exp = readExpression(cursor, std::cin);
@@ -643,6 +657,10 @@ repl(CEnv& cenv)
AST* body = cenv.penv.parse(exp); // Parse input
body->constrain(cenv.tenv, c); // Constrain types
+
+ for (TEnv::GenericTypes::const_iterator i = cenv.tenv.genericTypes.begin();
+ i != cenv.tenv.genericTypes.end(); ++i)
+ c.push_back(Constraint(cenv.tenv.var(i->first), i->second, i->first->loc));
cenv.tsubst = Subst::compose(cenv.tsubst, TEnv::unify(c)); // Solve type constraints
diff --git a/tuplr.hpp b/tuplr.hpp
index 4d714d4..72c01f3 100644
--- a/tuplr.hpp
+++ b/tuplr.hpp
@@ -84,8 +84,8 @@ template<typename K, typename V>
struct Env : public list< vector< pair<K,V> > > {
typedef vector< pair<K,V> > Frame;
Env() : list<Frame>(1) {}
- void push(Frame f=Frame()) { list<Frame>::push_front(f); }
- void pop() { assert(!this->empty()); list<Frame>::pop_front(); }
+ virtual void push(Frame f=Frame()) { list<Frame>::push_front(f); }
+ virtual void pop() { assert(!this->empty()); list<Frame>::pop_front(); }
const V& def(const K& k, const V& v) {
for (typename Frame::iterator b = this->begin()->begin(); b != this->begin()->end(); ++b)
if (b->first == k)
@@ -112,6 +112,9 @@ struct Env : public list< vector< pair<K,V> > > {
assert(addr);
typename Env::iterator f = this->begin();
for (unsigned u = 1; u < addr.up; ++u, ++f) { assert(f != this->end()); }
+ if (!(f->size() > addr.over - 1)) {
+ std::cerr << "WTF: " << addr << " : " << this->size() << "." << f->size() << endl;
+ }
assert(f->size() > addr.over - 1);
return (*f)[addr.over - 1].second;
}
@@ -155,7 +158,7 @@ struct AST {
virtual ~AST() {}
virtual bool operator==(const AST& o) const = 0;
virtual bool contains(const AST* child) const { return false; }
- virtual void constrain(TEnv& tenv, Constraints& c) {}
+ virtual void constrain(TEnv& tenv, Constraints& c) const {}
virtual void lift(CEnv& cenv) {}
string str() const { ostringstream ss; ss << this; return ss.str(); }
Cursor loc;
@@ -172,7 +175,7 @@ struct ALiteral : public AST {
const ALiteral<VT>* r = dynamic_cast<const ALiteral<VT>*>(&rhs);
return (r && (val == r->val));
}
- void constrain(TEnv& tenv, Constraints& c);
+ void constrain(TEnv& tenv, Constraints& c) const;
CValue compile(CEnv& cenv);
const VT val;
};
@@ -180,10 +183,9 @@ struct ALiteral : public AST {
/// Symbol, e.g. "a"
struct ASymbol : public AST {
bool operator==(const AST& rhs) const { return this == &rhs; }
- void lookup(TEnv& tenv);
- void constrain(TEnv& tenv, Constraints& c);
+ void constrain(TEnv& tenv, Constraints& c) const;
CValue compile(CEnv& cenv);
- LAddr addr;
+ mutable LAddr addr;
private:
friend class PEnv;
ASymbol(const string& s, Cursor c) : AST(c), cppstr(s) {}
@@ -225,7 +227,7 @@ struct ATuple : public AST, public vector<AST*> {
return true;
return false;
}
- void constrain(TEnv& tenv, Constraints& c);
+ void constrain(TEnv& tenv, Constraints& c) const;
CValue compile(CEnv& cenv) { throw Error("tuple compiled"); }
};
@@ -287,16 +289,15 @@ struct Funcs : public list< pair<AType*, CFunction> > {
/// Closure (first-class function with captured lexical bindings)
struct AClosure : public ATuple {
AClosure(Cursor c, ASymbol* fn, ATuple* p, const string& n="")
- : ATuple(c, fn, p, NULL), type(0), subst(0), name(n) {}
+ : ATuple(c, fn, p, NULL), subst(0), name(n) {}
bool operator==(const AST& rhs) const { return this == &rhs; }
- void constrain(TEnv& tenv, Constraints& c);
+ void constrain(TEnv& tenv, Constraints& c) const;
void lift(CEnv& cenv);
- void liftPoly(CEnv& cenv, const vector<AType*>& argsT);
+ void liftCall(CEnv& cenv, const vector<AType*>& argsT);
CValue compile(CEnv& cenv);
ATuple* prot() const { return dynamic_cast<ATuple*>(at(1)); }
- AType* type;
- Funcs funcs;
- Subst* subst;
+ Funcs funcs;
+ mutable Subst* subst;
private:
string name;
};
@@ -304,7 +305,7 @@ private:
/// Function call/application, e.g. "(func arg1 arg2)"
struct ACall : public ATuple {
ACall(const SExp& e, const ATuple& t) : ATuple(t, e.loc) {}
- void constrain(TEnv& tenv, Constraints& c);
+ void constrain(TEnv& tenv, Constraints& c) const;
void lift(CEnv& cenv);
CValue compile(CEnv& cenv);
};
@@ -313,7 +314,7 @@ struct ACall : public ATuple {
struct ADefinition : public ACall {
ADefinition(const SExp& e, const ATuple& t) : ACall(e, t) {}
ASymbol* sym() const { return dynamic_cast<ASymbol*>(at(1)); }
- void constrain(TEnv& tenv, Constraints& c);
+ void constrain(TEnv& tenv, Constraints& c) const;
void lift(CEnv& cenv);
CValue compile(CEnv& cenv);
};
@@ -321,14 +322,14 @@ struct ADefinition : public ACall {
/// Conditional special form, e.g. "(if cond thenexp elseexp)"
struct AIf : public ACall {
AIf(const SExp& e, const ATuple& t) : ACall(e, t) {}
- void constrain(TEnv& tenv, Constraints& c);
+ void constrain(TEnv& tenv, Constraints& c) const;
CValue compile(CEnv& cenv);
};
/// Primitive (builtin arithmetic function), e.g. "(+ 2 3)"
struct APrimitive : public ACall {
APrimitive(const SExp& e, const ATuple& t) : ACall(e, t) {}
- void constrain(TEnv& tenv, Constraints& c);
+ void constrain(TEnv& tenv, Constraints& c) const;
CValue compile(CEnv& cenv);
};
@@ -336,7 +337,7 @@ struct APrimitive : public ACall {
struct AConsCall : public ACall {
AConsCall(const SExp& e, const ATuple& t) : ACall(e, t) {}
AType* functionType(CEnv& cenv);
- void constrain(TEnv& tenv, Constraints& c);
+ void constrain(TEnv& tenv, Constraints& c) const;
void lift(CEnv& cenv);
CValue compile(CEnv& cenv);
static Funcs funcs;
@@ -345,14 +346,14 @@ struct AConsCall : public ACall {
/// Car special form, e.g. "(car p)"
struct ACarCall : public ACall {
ACarCall(const SExp& e, const ATuple& t) : ACall(e, t) {}
- void constrain(TEnv& tenv, Constraints& c);
+ void constrain(TEnv& tenv, Constraints& c) const;
CValue compile(CEnv& cenv);
};
/// Cdr special form, e.g. "(cdr p)"
struct ACdrCall : public ACall {
ACdrCall(const SExp& e, const ATuple& t) : ACall(e, t) {}
- void constrain(TEnv& tenv, Constraints& c);
+ void constrain(TEnv& tenv, Constraints& c) const;
CValue compile(CEnv& cenv);
};
@@ -432,7 +433,7 @@ inline ostream& operator<<(ostream& out, const Constraints& c) {
}
struct Subst : public map<const AType*,AType*> {
- Subst(AType* s=0, AType* t=0) { if (s && t) insert(make_pair(s, t)); }
+ Subst(AType* s=0, AType* t=0) { if (s && t) { assert(s != t); insert(make_pair(s, t)); } }
static Subst compose(const Subst& delta, const Subst& gamma);
AST* apply(AST* ast) const {
AType* in = dynamic_cast<AType*>(ast);
@@ -443,9 +444,13 @@ struct Subst : public map<const AType*,AType*> {
out->push_back(apply(in->at(i)));
return out;
} else {
- const_iterator i;
- while ((i = find(in)) != end())
+ Subst copy(*this);
+ iterator i;
+ while ((i = copy.find(in)) != copy.end()) {
+ cerr << "IN: " << in << endl;
in = i->second;
+ copy.erase(i);
+ }
return in;
}
}
@@ -460,11 +465,15 @@ struct TEnv : public Env<const AST*,AType*> {
return def(sym, new AType(varID++, LAddr(), sym->loc));
}
AType* var(const AST* ast=0) {
+ /*GenericTypes::iterator g = genericTypes.find(dynamic_casdt<AClosure*>(ast));
+ if (g != vars.end())
+ return g->second;*/
+
const ASymbol* sym = dynamic_cast<const ASymbol*>(ast);
if (sym)
return deref(lookup(sym));
- map<const AST*, AType*>::iterator v = vars.find(ast);
+ Vars::iterator v = vars.find(ast);
if (v != vars.end())
return v->second;
@@ -479,10 +488,13 @@ struct TEnv : public Env<const AST*,AType*> {
}
static Subst unify(const Constraints& c);
- map<const AST*, AType*> vars;
- PEnv& penv;
- Constraints constraints;
- unsigned varID;
+ typedef map<const AST*, AType*> Vars;
+ typedef map<const AClosure*, AType*> GenericTypes;
+ Vars vars;
+ GenericTypes genericTypes;
+ PEnv& penv;
+ Constraints constraints;
+ unsigned varID;
};
@@ -512,17 +524,23 @@ struct CEnv {
return tenv.deref(sym->addr);
return dynamic_cast<AType*>(tsubst.apply(subst.apply(tenv.vars[ast])));
}
+ void def(ASymbol* sym, AST* c, AType* t, CValue v) {
+ code.def(sym, c);
+ tenv.def(sym, t);
+ vals.def(sym, v);
+ }
- ostream& out;
- ostream& err;
- PEnv& penv;
- TEnv& tenv;
- Code code;
- Vals vals;
-
- unsigned symID;
- CFunction alloc;
- Subst tsubst;
+ ostream& out;
+ ostream& err;
+ PEnv& penv;
+ TEnv& tenv;
+ Code code;
+ Vals vals;
+
+
+ unsigned symID;
+ CFunction alloc;
+ Subst tsubst;
private:
struct PImpl; ///< Private Implementation
diff --git a/typing.cpp b/typing.cpp
index ef28409..1d86f3d 100644
--- a/typing.cpp
+++ b/typing.cpp
@@ -30,28 +30,24 @@ Constraints::constrain(TEnv& tenv, const AST* o, AType* t)
***************************************************************************/
void
-ASymbol::lookup(TEnv& tenv)
+ASymbol::constrain(TEnv& tenv, Constraints& c) const
{
addr = tenv.lookup(this);
if (!addr)
throw Error((format("undefined symbol `%1%'") % cppstr).str(), loc);
-}
-
-void
-ASymbol::constrain(TEnv& tenv, Constraints& c)
-{
- lookup(tenv);
AType* t = tenv.deref(addr);
if (!t)
throw Error((format("unresolved symbol `%1%'") % cppstr).str(), loc);
- c.push_back(Constraint(tenv.var(this), tenv.deref(addr), loc));
+ AType* var = tenv.deref(addr);
+ var->addr = addr;
+ c.push_back(Constraint(tenv.var(this), var, loc));
}
void
-ATuple::constrain(TEnv& tenv, Constraints& c)
+ATuple::constrain(TEnv& tenv, Constraints& c) const
{
AType* t = new AType(ATuple(), loc);
- FOREACH(iterator, p, *this) {
+ FOREACH(const_iterator, p, *this) {
(*p)->constrain(tenv, c);
t->push_back(tenv.var(*p));
}
@@ -59,95 +55,105 @@ ATuple::constrain(TEnv& tenv, Constraints& c)
}
void
-AClosure::constrain(TEnv& tenv, Constraints& c)
+AClosure::constrain(TEnv& tenv, Constraints& c) const
{
- set<ASymbol*> defined;
- TEnv::Frame frame;
-
- // Add parameters to environment frame
- for (size_t i = 0; i < prot()->size(); ++i) {
- ASymbol* sym = dynamic_cast<ASymbol*>(prot()->at(i));
- if (!sym)
- throw Error("parameter name is not a symbol", prot()->at(i)->loc);
- if (defined.find(sym) != defined.end())
- throw Error((format("duplicate parameter `%1%'") % sym->str()).str(), sym->loc);
- defined.insert(sym);
- frame.push_back(make_pair(sym, (AType*)NULL));
- }
-
- // Add internal definitions to environment frame
- size_t e = 2;
- for (; e < size(); ++e) {
- AST* exp = at(e);
- ADefinition* def = dynamic_cast<ADefinition*>(exp);
- if (def) {
- ASymbol* sym = def->sym();
+ AType* genericType;
+ TEnv::GenericTypes::const_iterator gt = tenv.genericTypes.find(this);
+ if (gt != tenv.genericTypes.end()) {
+ genericType = gt->second;
+ } else {
+ set<ASymbol*> defined;
+ TEnv::Frame frame;
+
+ // Add parameters to environment frame
+ for (size_t i = 0; i < prot()->size(); ++i) {
+ ASymbol* sym = dynamic_cast<ASymbol*>(prot()->at(i));
+ if (!sym)
+ throw Error("parameter name is not a symbol", prot()->at(i)->loc);
if (defined.find(sym) != defined.end())
- throw Error((format("`%1%' defined twice") % sym->str()).str(), def->loc);
- defined.insert(def->sym());
- frame.push_back(make_pair(def->sym(), (AType*)NULL));
+ throw Error((format("duplicate parameter `%1%'") % sym->str()).str(), sym->loc);
+ defined.insert(sym);
+ frame.push_back(make_pair(sym, (AType*)NULL));
}
- }
-
- tenv.push(frame);
-
- Constraints cp;
- cp.push_back(Constraint(tenv.var(this), tenv.var(), loc));
-
- AType* protT = new AType(ATuple(), loc);
- for (size_t i = 0; i < prot()->size(); ++i) {
- AType* tvar = tenv.fresh(dynamic_cast<ASymbol*>(prot()->at(i)));
- protT->push_back(tvar);
- assert(frame[i].first == prot()->at(i));
- frame[i].second = tvar;
- }
- c.push_back(Constraint(tenv.var(at(1)), protT, at(1)->loc));
- for (size_t i = 2; i < size(); ++i)
- at(i)->constrain(tenv, cp);
+ // Add internal definitions to environment frame
+ size_t e = 2;
+ for (; e < size(); ++e) {
+ AST* exp = at(e);
+ ADefinition* def = dynamic_cast<ADefinition*>(exp);
+ if (def) {
+ ASymbol* sym = def->sym();
+ if (defined.find(sym) != defined.end())
+ throw Error((format("`%1%' defined twice") % sym->str()).str(), def->loc);
+ defined.insert(def->sym());
+ frame.push_back(make_pair(def->sym(), (AType*)NULL));
+ }
+ }
- AType* bodyT = tenv.var(at(e-1));
- Subst tsubst = TEnv::unify(cp);
- type = new AType(loc, tenv.penv.sym("Fn"), tsubst.apply(protT), tsubst.apply(bodyT), 0);
+ tenv.push(frame);
- tenv.pop();
+ Constraints cp;
+ cp.push_back(Constraint(tenv.var(this), tenv.var(), loc));
- c.constrain(tenv, this, type);
- subst = new Subst(tsubst);
+ AType* protT = new AType(ATuple(), loc);
+ for (size_t i = 0; i < prot()->size(); ++i) {
+ AType* tvar = tenv.fresh(dynamic_cast<ASymbol*>(prot()->at(i)));
+ protT->push_back(tvar);
+ assert(frame[i].first == prot()->at(i));
+ frame[i].second = tvar;
+ }
+ c.push_back(Constraint(tenv.var(at(1)), protT, at(1)->loc));
+
+ for (size_t i = 2; i < size(); ++i)
+ at(i)->constrain(tenv, cp);
+
+ AType* bodyT = tenv.var(at(e-1));
+ Subst tsubst = TEnv::unify(cp);
+ genericType = new AType(loc, tenv.penv.sym("Fn"),
+ tsubst.apply(protT), tsubst.apply(bodyT), 0);
+ tenv.genericTypes.insert(make_pair(this, genericType));
+ tenv.def(this, genericType);
+
+ tenv.pop();
+ subst = new Subst(tsubst);
+ }
+
+ c.constrain(tenv, this, genericType);
+ //for (Constraints::const_iterator i = cp.begin(); i != cp.end(); ++i)
+ // c.push_back(*i);
}
void
-ACall::constrain(TEnv& tenv, Constraints& c)
+ACall::constrain(TEnv& tenv, Constraints& c) const
{
+ std::cerr << "CONSTRAIN CALL" << endl;
+
+ at(0)->constrain(tenv, c);
AType* argsT = new AType(ATuple(), loc);
- TEnv::Frame frame;
for (size_t i = 1; i < size(); ++i) {
at(i)->constrain(tenv, c);
argsT->push_back(tenv.var(at(i)));
- frame.push_back(make_pair((AST*)NULL, tenv.var(at(i))));
}
AType* retT = tenv.var();
-
- at(0)->constrain(tenv, c);
c.constrain(tenv, at(0), new AType(at(0)->loc, tenv.penv.sym("Fn"), argsT, retT, 0));
c.constrain(tenv, this, retT);
}
void
-ADefinition::constrain(TEnv& tenv, Constraints& c)
+ADefinition::constrain(TEnv& tenv, Constraints& c) const
{
if (size() != 3) throw Error("`def' requires exactly 2 arguments", loc);
if (!dynamic_cast<const ASymbol*>(at(1)))
throw Error("`def' name is not a symbol", loc);
- AType* tvar = tenv.var(this);
+ AType* tvar = tenv.var(at(2));
tenv.def(at(1), tvar);
at(2)->constrain(tenv, c);
- c.constrain(tenv, at(2), tvar);
+ c.constrain(tenv, this, tvar);
}
void
-AIf::constrain(TEnv& tenv, Constraints& c)
+AIf::constrain(TEnv& tenv, Constraints& c) const
{
if (size() < 3) throw Error("`if' requires exactly 3 arguments", loc);
if (size() % 2 != 0) throw Error("`if' missing final else clause", loc);
@@ -165,7 +171,7 @@ AIf::constrain(TEnv& tenv, Constraints& c)
}
void
-APrimitive::constrain(TEnv& tenv, Constraints& c)
+APrimitive::constrain(TEnv& tenv, Constraints& c) const
{
const string n = dynamic_cast<ASymbol*>(at(0))->str();
enum { ARITHMETIC, BINARY, LOGICAL, COMPARISON } type;
@@ -215,7 +221,7 @@ APrimitive::constrain(TEnv& tenv, Constraints& c)
}
void
-AConsCall::constrain(TEnv& tenv, Constraints& c)
+AConsCall::constrain(TEnv& tenv, Constraints& c) const
{
if (size() != 3) throw Error("`cons' requires exactly 2 arguments", loc);
AType* t = new AType(loc, tenv.penv.sym("Pair"), 0);
@@ -227,7 +233,7 @@ AConsCall::constrain(TEnv& tenv, Constraints& c)
}
void
-ACarCall::constrain(TEnv& tenv, Constraints& c)
+ACarCall::constrain(TEnv& tenv, Constraints& c) const
{
if (size() != 2) throw Error("`car' requires exactly 1 argument", loc);
at(1)->constrain(tenv, c);
@@ -238,7 +244,7 @@ ACarCall::constrain(TEnv& tenv, Constraints& c)
}
void
-ACdrCall::constrain(TEnv& tenv, Constraints& c)
+ACdrCall::constrain(TEnv& tenv, Constraints& c) const
{
if (size() != 2) throw Error("`cdr' requires exactly 1 argument", loc);
at(1)->constrain(tenv, c);