diff options
author | David Robillard <d@drobilla.net> | 2009-03-15 01:09:19 +0000 |
---|---|---|
committer | David Robillard <d@drobilla.net> | 2009-03-15 01:09:19 +0000 |
commit | 117b9fbb3d0737f44cf3f8f3f1a3f964b1f9e777 (patch) | |
tree | 56c089fcfa6b1f6a510cdbd013a38363bea19fd0 | |
parent | 053a6e588b143bdd765113cdcd53ae7ef39c1a6c (diff) | |
download | resp-117b9fbb3d0737f44cf3f8f3f1a3f964b1f9e777.tar.gz resp-117b9fbb3d0737f44cf3f8f3f1a3f964b1f9e777.tar.bz2 resp-117b9fbb3d0737f44cf3f8f3f1a3f964b1f9e777.zip |
Fix recursion.
git-svn-id: http://svn.drobilla.net/resp/tuplr@89 ad02d1e2-f140-0410-9f75-f8b11f17cedd
-rw-r--r-- | llvm.cpp | 158 | ||||
-rw-r--r-- | tuplr.hpp | 96 | ||||
-rw-r--r-- | typing.cpp | 148 |
3 files changed, 222 insertions, 180 deletions
@@ -124,6 +124,7 @@ CEnv::compile(AST* obj) void CEnv::optimise(CFunction f) { + _pimpl->module->dump(); verifyFunction(*static_cast<Function*>(f)); _pimpl->opt.run(*static_cast<Function*>(f)); } @@ -139,13 +140,24 @@ CEnv::write(std::ostream& os) template<> CValue \ ALiteral<CT>::compile(CEnv& cenv) { return (COMPILED); } \ template<> void \ -ALiteral<CT>::constrain(TEnv& tenv, Constraints& c) { c.constrain(tenv, this, tenv.named(NAME)); } +ALiteral<CT>::constrain(TEnv& tenv, Constraints& c) const { c.constrain(tenv, this, tenv.named(NAME)); } /// Literal template instantiations LITERAL(int32_t, "Int", ConstantInt::get(Type::Int32Ty, val, true)) LITERAL(float, "Float", ConstantFP::get(Type::FloatTy, val)) LITERAL(bool, "Bool", ConstantInt::get(Type::Int1Ty, val, false)) +template<typename T> +T +checked_cast(AST* ast) +{ + T t = dynamic_cast<T>(ast); + if (!t) + throw Error((format("internal error: `%1%' should be a `%2%'") + % typeid(ast).name() % typeid(T).name()).str(), ast->loc); + return t; +} + static Function* compileFunction(CEnv& cenv, const std::string& name, const Type* retT, const ATuple& protT, const vector<string> argNames=vector<string>()) @@ -154,8 +166,7 @@ compileFunction(CEnv& cenv, const std::string& name, const Type* retT, const ATu vector<const Type*> cprot; for (size_t i = 0; i < protT.size(); ++i) { - AType* at = dynamic_cast<AType*>(protT.at(i)); - if (!at) throw Error("function parameter type isn't"); + AType* at = checked_cast<AType*>(protT.at(i)); if (!lltype(at)) throw Error("function parameter is untyped"); cprot.push_back(lltype(at)); } @@ -215,10 +226,8 @@ AClosure::lift(CEnv& cenv) vector<Value*> args; const_iterator p = prot()->begin(); size_t i = 0; - for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p) { - cenv.tenv.def(*p, dynamic_cast<AType*>(protT->at(i++))); - cenv.vals.def(dynamic_cast<ASymbol*>(*p), &*a); - } + for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p) + cenv.def(checked_cast<ASymbol*>(*p), *p, checked_cast<AType*>(protT->at(i++)), &*a); // Write function body try { @@ -240,65 +249,71 @@ AClosure::lift(CEnv& cenv) } void -AClosure::liftPoly(CEnv& cenv, const vector<AType*>& argsT) +AClosure::liftCall(CEnv& cenv, const vector<AType*>& argsT) { - if (type->concrete()) - return; - - throw Error("No polymorphism"); + TEnv::GenericTypes::const_iterator gt = cenv.tenv.genericTypes.find(this); + assert(gt != cenv.tenv.genericTypes.end()); + AType* genericType = gt->second; -#if 0 - //Subst tsubst; + // Find type and build substitution assert(argsT.size() == prot()->size()); - for (size_t i = 0; i < argsT.size(); ++i) { - cenv.err << " " << argsT.at(i)->str(); - //tsubst[*cenv.tenv.ref(prot()->at(i))] = argsT.at(i); - } - cenv.err << endl; -#endif -} + Subst argsSubst; + ATuple* genericProtT = dynamic_cast<ATuple*>(genericType->at(1)); + assert(genericProtT); + for (size_t i = 0; i < argsT.size(); ++i) + argsSubst[dynamic_cast<AType*>(genericProtT->at(i))] = dynamic_cast<AType*>(argsT.at(i)); + + AType* thisType = new AType(*dynamic_cast<ATuple*>(argsSubst.apply(genericType)), loc); + cenv.err << "THIS TYPE: " << thisType << endl; -CValue -AClosure::compile(CEnv& cenv) -{ - /* - cenv.err << "***********************************************" << endl; - cenv.err << cenv.type(this) << endl; + //if (!thisType->concrete()) + // throw Error("unable to resolve concrete type for function", loc); - cenv.err << "COMPILING FOR TYPE:"; - Subst tsubst; - assert(cenv.code.front().size() == prot()->size()); - for (size_t i = 0; i < cenv.code.front().size(); ++i) { - cenv.err << " (" << cenv.type(prot()->at(i))->str() - << " -> " << cenv.type(cenv.code.front().at(i).second)->str() << ")"; - tsubst[cenv.tenv.types[prot()->at(i)]] = - cenv.type(cenv.code.front().at(i).second); - } - cenv.err << endl; + if (funcs.find(thisType)) + return; - Subst subst = Subst::compose(tsubst, cenv.tsubst); - AType* concreteType = subst.apply(type); - if (!concreteType->concrete()) - throw Error("compiled function has non-concrete type", loc); + ATuple* protT = dynamic_cast<ATuple*>(thisType->at(1)); - cenv.err << "*********** CONCRETE TYPE: " << concreteType->str() << endl; - */ + // Write function declaration + string name = this->name == "" ? cenv.gensym("_fn") : this->name; + Function* f = compileFunction(cenv, name, + lltype(dynamic_cast<AType*>(thisType->at(thisType->size() - 1))), + *protT); - //CValue ret = funcs.find(concreteType); - //cenv.err << "VALUE FOR TYPE " << concreteType->str() << " : " << ret << endl; - //return ret; - return NULL; + cenv.push(); + Subst oldSubst = cenv.tsubst; + cenv.tsubst = Subst::compose(cenv.tsubst, Subst::compose(argsSubst, *subst)); + + // Bind argument values in CEnv + vector<Value*> args; + const_iterator p = prot()->begin(); + size_t i = 0; + for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p) + cenv.def(checked_cast<ASymbol*>(*p), *p, checked_cast<AType*>(protT->at(i++)), &*a); + + // Write function body + try { + // Define value first for recursion + cenv.precompile(this, f); + funcs.push_back(make_pair(thisType, f)); + + CValue retVal = cenv.compile(at(2)); + llengine(cenv)->builder.CreateRet(LLVal(retVal)); // Finish function + cenv.optimise(LLFunc(f)); + + } catch (Error& e) { + f->eraseFromParent(); // Error reading body, remove function + cenv.pop(); + throw e; + } + cenv.tsubst = oldSubst; + cenv.pop(); } -template<typename T> -T -checked_cast(AST* ast) +CValue +AClosure::compile(CEnv& cenv) { - T t = dynamic_cast<T>(ast); - if (!t) - throw Error((format("internal error: `%1%' should be a `%2%'") - % typeid(ast).name() % typeid(T).name()).str(), ast->loc); - return t; + return NULL; } static @@ -306,7 +321,7 @@ AST* maybeLookup(CEnv& cenv, AST* ast) { ASymbol* s = dynamic_cast<ASymbol*>(ast); - if (s) + if (s && s->addr) return cenv.code.deref(s->addr); return ast; } @@ -325,18 +340,19 @@ ACall::lift(CEnv& cenv) } if (!c) return; // Primitive - - // Extend environment with bound and typed parameters - cenv.push(); + if (c->prot()->size() < size() - 1) throw Error((format("too many arguments to function `%1%'") % at(0)->str()).str(), loc); if (c->prot()->size() > size() - 1) throw Error((format("too few arguments to function `%1%'") % at(0)->str()).str(), loc); + // Extend environment with bound and typed parameters + cenv.push(); + for (size_t i = 1; i < size(); ++i) - cenv.code.def(checked_cast<ASymbol*>(c->prot()->at(i-1)), at(i)); + cenv.def(checked_cast<ASymbol*>(c->prot()->at(i-1)), at(i), cenv.type(at(i)), NULL); - c->liftPoly(cenv, argsT); // Lift called closure + c->liftCall(cenv, argsT); // Lift called closure cenv.pop(); // Restore environment } @@ -347,16 +363,13 @@ ACall::compile(CEnv& cenv) if (!c) return NULL; // Primitive - if (c->prot()->size() < size() - 1) - throw Error((format("too many arguments to function `%1%'") % at(0)->str()).str(), loc); - if (c->prot()->size() > size() - 1) - throw Error((format("too few arguments to function `%1%'") % at(0)->str()).str(), loc); - AType* protT = new AType(loc, NULL); for (size_t i = 1; i < size(); ++i) protT->push_back(cenv.type(at(i))); - - AType* polyT = c->type; + + TEnv::GenericTypes::const_iterator gt = cenv.tenv.genericTypes.find(c); + assert(gt != cenv.tenv.genericTypes.end()); + AType* polyT = gt->second; AType* fnT = new AType(loc, cenv.penv.sym("Fn"), protT, polyT->at(2), 0); Function* f = (Function*)c->funcs.find(fnT); @@ -374,7 +387,8 @@ ADefinition::lift(CEnv& cenv) { if (cenv.code.lookup(checked_cast<ASymbol*>(at(1)))) throw Error(string("`") + at(1)->str() + "' redefined", loc); - cenv.code.def((ASymbol*)at(1), at(2)); // Define first for recursion + // Define first for recursion + cenv.def(checked_cast<ASymbol*>(at(1)), at(2), cenv.type(at(2)), NULL); at(2)->lift(cenv); } @@ -630,11 +644,11 @@ eval(CEnv& cenv, const string& name, istream& is) int repl(CEnv& cenv) { - Constraints c; while (1) { cenv.out << "() "; cenv.out.flush(); Cursor cursor("(stdin)"); + Constraints c; try { SExp exp = readExpression(cursor, std::cin); @@ -643,6 +657,10 @@ repl(CEnv& cenv) AST* body = cenv.penv.parse(exp); // Parse input body->constrain(cenv.tenv, c); // Constrain types + + for (TEnv::GenericTypes::const_iterator i = cenv.tenv.genericTypes.begin(); + i != cenv.tenv.genericTypes.end(); ++i) + c.push_back(Constraint(cenv.tenv.var(i->first), i->second, i->first->loc)); cenv.tsubst = Subst::compose(cenv.tsubst, TEnv::unify(c)); // Solve type constraints @@ -84,8 +84,8 @@ template<typename K, typename V> struct Env : public list< vector< pair<K,V> > > { typedef vector< pair<K,V> > Frame; Env() : list<Frame>(1) {} - void push(Frame f=Frame()) { list<Frame>::push_front(f); } - void pop() { assert(!this->empty()); list<Frame>::pop_front(); } + virtual void push(Frame f=Frame()) { list<Frame>::push_front(f); } + virtual void pop() { assert(!this->empty()); list<Frame>::pop_front(); } const V& def(const K& k, const V& v) { for (typename Frame::iterator b = this->begin()->begin(); b != this->begin()->end(); ++b) if (b->first == k) @@ -112,6 +112,9 @@ struct Env : public list< vector< pair<K,V> > > { assert(addr); typename Env::iterator f = this->begin(); for (unsigned u = 1; u < addr.up; ++u, ++f) { assert(f != this->end()); } + if (!(f->size() > addr.over - 1)) { + std::cerr << "WTF: " << addr << " : " << this->size() << "." << f->size() << endl; + } assert(f->size() > addr.over - 1); return (*f)[addr.over - 1].second; } @@ -155,7 +158,7 @@ struct AST { virtual ~AST() {} virtual bool operator==(const AST& o) const = 0; virtual bool contains(const AST* child) const { return false; } - virtual void constrain(TEnv& tenv, Constraints& c) {} + virtual void constrain(TEnv& tenv, Constraints& c) const {} virtual void lift(CEnv& cenv) {} string str() const { ostringstream ss; ss << this; return ss.str(); } Cursor loc; @@ -172,7 +175,7 @@ struct ALiteral : public AST { const ALiteral<VT>* r = dynamic_cast<const ALiteral<VT>*>(&rhs); return (r && (val == r->val)); } - void constrain(TEnv& tenv, Constraints& c); + void constrain(TEnv& tenv, Constraints& c) const; CValue compile(CEnv& cenv); const VT val; }; @@ -180,10 +183,9 @@ struct ALiteral : public AST { /// Symbol, e.g. "a" struct ASymbol : public AST { bool operator==(const AST& rhs) const { return this == &rhs; } - void lookup(TEnv& tenv); - void constrain(TEnv& tenv, Constraints& c); + void constrain(TEnv& tenv, Constraints& c) const; CValue compile(CEnv& cenv); - LAddr addr; + mutable LAddr addr; private: friend class PEnv; ASymbol(const string& s, Cursor c) : AST(c), cppstr(s) {} @@ -225,7 +227,7 @@ struct ATuple : public AST, public vector<AST*> { return true; return false; } - void constrain(TEnv& tenv, Constraints& c); + void constrain(TEnv& tenv, Constraints& c) const; CValue compile(CEnv& cenv) { throw Error("tuple compiled"); } }; @@ -287,16 +289,15 @@ struct Funcs : public list< pair<AType*, CFunction> > { /// Closure (first-class function with captured lexical bindings) struct AClosure : public ATuple { AClosure(Cursor c, ASymbol* fn, ATuple* p, const string& n="") - : ATuple(c, fn, p, NULL), type(0), subst(0), name(n) {} + : ATuple(c, fn, p, NULL), subst(0), name(n) {} bool operator==(const AST& rhs) const { return this == &rhs; } - void constrain(TEnv& tenv, Constraints& c); + void constrain(TEnv& tenv, Constraints& c) const; void lift(CEnv& cenv); - void liftPoly(CEnv& cenv, const vector<AType*>& argsT); + void liftCall(CEnv& cenv, const vector<AType*>& argsT); CValue compile(CEnv& cenv); ATuple* prot() const { return dynamic_cast<ATuple*>(at(1)); } - AType* type; - Funcs funcs; - Subst* subst; + Funcs funcs; + mutable Subst* subst; private: string name; }; @@ -304,7 +305,7 @@ private: /// Function call/application, e.g. "(func arg1 arg2)" struct ACall : public ATuple { ACall(const SExp& e, const ATuple& t) : ATuple(t, e.loc) {} - void constrain(TEnv& tenv, Constraints& c); + void constrain(TEnv& tenv, Constraints& c) const; void lift(CEnv& cenv); CValue compile(CEnv& cenv); }; @@ -313,7 +314,7 @@ struct ACall : public ATuple { struct ADefinition : public ACall { ADefinition(const SExp& e, const ATuple& t) : ACall(e, t) {} ASymbol* sym() const { return dynamic_cast<ASymbol*>(at(1)); } - void constrain(TEnv& tenv, Constraints& c); + void constrain(TEnv& tenv, Constraints& c) const; void lift(CEnv& cenv); CValue compile(CEnv& cenv); }; @@ -321,14 +322,14 @@ struct ADefinition : public ACall { /// Conditional special form, e.g. "(if cond thenexp elseexp)" struct AIf : public ACall { AIf(const SExp& e, const ATuple& t) : ACall(e, t) {} - void constrain(TEnv& tenv, Constraints& c); + void constrain(TEnv& tenv, Constraints& c) const; CValue compile(CEnv& cenv); }; /// Primitive (builtin arithmetic function), e.g. "(+ 2 3)" struct APrimitive : public ACall { APrimitive(const SExp& e, const ATuple& t) : ACall(e, t) {} - void constrain(TEnv& tenv, Constraints& c); + void constrain(TEnv& tenv, Constraints& c) const; CValue compile(CEnv& cenv); }; @@ -336,7 +337,7 @@ struct APrimitive : public ACall { struct AConsCall : public ACall { AConsCall(const SExp& e, const ATuple& t) : ACall(e, t) {} AType* functionType(CEnv& cenv); - void constrain(TEnv& tenv, Constraints& c); + void constrain(TEnv& tenv, Constraints& c) const; void lift(CEnv& cenv); CValue compile(CEnv& cenv); static Funcs funcs; @@ -345,14 +346,14 @@ struct AConsCall : public ACall { /// Car special form, e.g. "(car p)" struct ACarCall : public ACall { ACarCall(const SExp& e, const ATuple& t) : ACall(e, t) {} - void constrain(TEnv& tenv, Constraints& c); + void constrain(TEnv& tenv, Constraints& c) const; CValue compile(CEnv& cenv); }; /// Cdr special form, e.g. "(cdr p)" struct ACdrCall : public ACall { ACdrCall(const SExp& e, const ATuple& t) : ACall(e, t) {} - void constrain(TEnv& tenv, Constraints& c); + void constrain(TEnv& tenv, Constraints& c) const; CValue compile(CEnv& cenv); }; @@ -432,7 +433,7 @@ inline ostream& operator<<(ostream& out, const Constraints& c) { } struct Subst : public map<const AType*,AType*> { - Subst(AType* s=0, AType* t=0) { if (s && t) insert(make_pair(s, t)); } + Subst(AType* s=0, AType* t=0) { if (s && t) { assert(s != t); insert(make_pair(s, t)); } } static Subst compose(const Subst& delta, const Subst& gamma); AST* apply(AST* ast) const { AType* in = dynamic_cast<AType*>(ast); @@ -443,9 +444,13 @@ struct Subst : public map<const AType*,AType*> { out->push_back(apply(in->at(i))); return out; } else { - const_iterator i; - while ((i = find(in)) != end()) + Subst copy(*this); + iterator i; + while ((i = copy.find(in)) != copy.end()) { + cerr << "IN: " << in << endl; in = i->second; + copy.erase(i); + } return in; } } @@ -460,11 +465,15 @@ struct TEnv : public Env<const AST*,AType*> { return def(sym, new AType(varID++, LAddr(), sym->loc)); } AType* var(const AST* ast=0) { + /*GenericTypes::iterator g = genericTypes.find(dynamic_casdt<AClosure*>(ast)); + if (g != vars.end()) + return g->second;*/ + const ASymbol* sym = dynamic_cast<const ASymbol*>(ast); if (sym) return deref(lookup(sym)); - map<const AST*, AType*>::iterator v = vars.find(ast); + Vars::iterator v = vars.find(ast); if (v != vars.end()) return v->second; @@ -479,10 +488,13 @@ struct TEnv : public Env<const AST*,AType*> { } static Subst unify(const Constraints& c); - map<const AST*, AType*> vars; - PEnv& penv; - Constraints constraints; - unsigned varID; + typedef map<const AST*, AType*> Vars; + typedef map<const AClosure*, AType*> GenericTypes; + Vars vars; + GenericTypes genericTypes; + PEnv& penv; + Constraints constraints; + unsigned varID; }; @@ -512,17 +524,23 @@ struct CEnv { return tenv.deref(sym->addr); return dynamic_cast<AType*>(tsubst.apply(subst.apply(tenv.vars[ast]))); } + void def(ASymbol* sym, AST* c, AType* t, CValue v) { + code.def(sym, c); + tenv.def(sym, t); + vals.def(sym, v); + } - ostream& out; - ostream& err; - PEnv& penv; - TEnv& tenv; - Code code; - Vals vals; - - unsigned symID; - CFunction alloc; - Subst tsubst; + ostream& out; + ostream& err; + PEnv& penv; + TEnv& tenv; + Code code; + Vals vals; + + + unsigned symID; + CFunction alloc; + Subst tsubst; private: struct PImpl; ///< Private Implementation @@ -30,28 +30,24 @@ Constraints::constrain(TEnv& tenv, const AST* o, AType* t) ***************************************************************************/ void -ASymbol::lookup(TEnv& tenv) +ASymbol::constrain(TEnv& tenv, Constraints& c) const { addr = tenv.lookup(this); if (!addr) throw Error((format("undefined symbol `%1%'") % cppstr).str(), loc); -} - -void -ASymbol::constrain(TEnv& tenv, Constraints& c) -{ - lookup(tenv); AType* t = tenv.deref(addr); if (!t) throw Error((format("unresolved symbol `%1%'") % cppstr).str(), loc); - c.push_back(Constraint(tenv.var(this), tenv.deref(addr), loc)); + AType* var = tenv.deref(addr); + var->addr = addr; + c.push_back(Constraint(tenv.var(this), var, loc)); } void -ATuple::constrain(TEnv& tenv, Constraints& c) +ATuple::constrain(TEnv& tenv, Constraints& c) const { AType* t = new AType(ATuple(), loc); - FOREACH(iterator, p, *this) { + FOREACH(const_iterator, p, *this) { (*p)->constrain(tenv, c); t->push_back(tenv.var(*p)); } @@ -59,95 +55,105 @@ ATuple::constrain(TEnv& tenv, Constraints& c) } void -AClosure::constrain(TEnv& tenv, Constraints& c) +AClosure::constrain(TEnv& tenv, Constraints& c) const { - set<ASymbol*> defined; - TEnv::Frame frame; - - // Add parameters to environment frame - for (size_t i = 0; i < prot()->size(); ++i) { - ASymbol* sym = dynamic_cast<ASymbol*>(prot()->at(i)); - if (!sym) - throw Error("parameter name is not a symbol", prot()->at(i)->loc); - if (defined.find(sym) != defined.end()) - throw Error((format("duplicate parameter `%1%'") % sym->str()).str(), sym->loc); - defined.insert(sym); - frame.push_back(make_pair(sym, (AType*)NULL)); - } - - // Add internal definitions to environment frame - size_t e = 2; - for (; e < size(); ++e) { - AST* exp = at(e); - ADefinition* def = dynamic_cast<ADefinition*>(exp); - if (def) { - ASymbol* sym = def->sym(); + AType* genericType; + TEnv::GenericTypes::const_iterator gt = tenv.genericTypes.find(this); + if (gt != tenv.genericTypes.end()) { + genericType = gt->second; + } else { + set<ASymbol*> defined; + TEnv::Frame frame; + + // Add parameters to environment frame + for (size_t i = 0; i < prot()->size(); ++i) { + ASymbol* sym = dynamic_cast<ASymbol*>(prot()->at(i)); + if (!sym) + throw Error("parameter name is not a symbol", prot()->at(i)->loc); if (defined.find(sym) != defined.end()) - throw Error((format("`%1%' defined twice") % sym->str()).str(), def->loc); - defined.insert(def->sym()); - frame.push_back(make_pair(def->sym(), (AType*)NULL)); + throw Error((format("duplicate parameter `%1%'") % sym->str()).str(), sym->loc); + defined.insert(sym); + frame.push_back(make_pair(sym, (AType*)NULL)); } - } - - tenv.push(frame); - - Constraints cp; - cp.push_back(Constraint(tenv.var(this), tenv.var(), loc)); - - AType* protT = new AType(ATuple(), loc); - for (size_t i = 0; i < prot()->size(); ++i) { - AType* tvar = tenv.fresh(dynamic_cast<ASymbol*>(prot()->at(i))); - protT->push_back(tvar); - assert(frame[i].first == prot()->at(i)); - frame[i].second = tvar; - } - c.push_back(Constraint(tenv.var(at(1)), protT, at(1)->loc)); - for (size_t i = 2; i < size(); ++i) - at(i)->constrain(tenv, cp); + // Add internal definitions to environment frame + size_t e = 2; + for (; e < size(); ++e) { + AST* exp = at(e); + ADefinition* def = dynamic_cast<ADefinition*>(exp); + if (def) { + ASymbol* sym = def->sym(); + if (defined.find(sym) != defined.end()) + throw Error((format("`%1%' defined twice") % sym->str()).str(), def->loc); + defined.insert(def->sym()); + frame.push_back(make_pair(def->sym(), (AType*)NULL)); + } + } - AType* bodyT = tenv.var(at(e-1)); - Subst tsubst = TEnv::unify(cp); - type = new AType(loc, tenv.penv.sym("Fn"), tsubst.apply(protT), tsubst.apply(bodyT), 0); + tenv.push(frame); - tenv.pop(); + Constraints cp; + cp.push_back(Constraint(tenv.var(this), tenv.var(), loc)); - c.constrain(tenv, this, type); - subst = new Subst(tsubst); + AType* protT = new AType(ATuple(), loc); + for (size_t i = 0; i < prot()->size(); ++i) { + AType* tvar = tenv.fresh(dynamic_cast<ASymbol*>(prot()->at(i))); + protT->push_back(tvar); + assert(frame[i].first == prot()->at(i)); + frame[i].second = tvar; + } + c.push_back(Constraint(tenv.var(at(1)), protT, at(1)->loc)); + + for (size_t i = 2; i < size(); ++i) + at(i)->constrain(tenv, cp); + + AType* bodyT = tenv.var(at(e-1)); + Subst tsubst = TEnv::unify(cp); + genericType = new AType(loc, tenv.penv.sym("Fn"), + tsubst.apply(protT), tsubst.apply(bodyT), 0); + tenv.genericTypes.insert(make_pair(this, genericType)); + tenv.def(this, genericType); + + tenv.pop(); + subst = new Subst(tsubst); + } + + c.constrain(tenv, this, genericType); + //for (Constraints::const_iterator i = cp.begin(); i != cp.end(); ++i) + // c.push_back(*i); } void -ACall::constrain(TEnv& tenv, Constraints& c) +ACall::constrain(TEnv& tenv, Constraints& c) const { + std::cerr << "CONSTRAIN CALL" << endl; + + at(0)->constrain(tenv, c); AType* argsT = new AType(ATuple(), loc); - TEnv::Frame frame; for (size_t i = 1; i < size(); ++i) { at(i)->constrain(tenv, c); argsT->push_back(tenv.var(at(i))); - frame.push_back(make_pair((AST*)NULL, tenv.var(at(i)))); } AType* retT = tenv.var(); - - at(0)->constrain(tenv, c); c.constrain(tenv, at(0), new AType(at(0)->loc, tenv.penv.sym("Fn"), argsT, retT, 0)); c.constrain(tenv, this, retT); } void -ADefinition::constrain(TEnv& tenv, Constraints& c) +ADefinition::constrain(TEnv& tenv, Constraints& c) const { if (size() != 3) throw Error("`def' requires exactly 2 arguments", loc); if (!dynamic_cast<const ASymbol*>(at(1))) throw Error("`def' name is not a symbol", loc); - AType* tvar = tenv.var(this); + AType* tvar = tenv.var(at(2)); tenv.def(at(1), tvar); at(2)->constrain(tenv, c); - c.constrain(tenv, at(2), tvar); + c.constrain(tenv, this, tvar); } void -AIf::constrain(TEnv& tenv, Constraints& c) +AIf::constrain(TEnv& tenv, Constraints& c) const { if (size() < 3) throw Error("`if' requires exactly 3 arguments", loc); if (size() % 2 != 0) throw Error("`if' missing final else clause", loc); @@ -165,7 +171,7 @@ AIf::constrain(TEnv& tenv, Constraints& c) } void -APrimitive::constrain(TEnv& tenv, Constraints& c) +APrimitive::constrain(TEnv& tenv, Constraints& c) const { const string n = dynamic_cast<ASymbol*>(at(0))->str(); enum { ARITHMETIC, BINARY, LOGICAL, COMPARISON } type; @@ -215,7 +221,7 @@ APrimitive::constrain(TEnv& tenv, Constraints& c) } void -AConsCall::constrain(TEnv& tenv, Constraints& c) +AConsCall::constrain(TEnv& tenv, Constraints& c) const { if (size() != 3) throw Error("`cons' requires exactly 2 arguments", loc); AType* t = new AType(loc, tenv.penv.sym("Pair"), 0); @@ -227,7 +233,7 @@ AConsCall::constrain(TEnv& tenv, Constraints& c) } void -ACarCall::constrain(TEnv& tenv, Constraints& c) +ACarCall::constrain(TEnv& tenv, Constraints& c) const { if (size() != 2) throw Error("`car' requires exactly 1 argument", loc); at(1)->constrain(tenv, c); @@ -238,7 +244,7 @@ ACarCall::constrain(TEnv& tenv, Constraints& c) } void -ACdrCall::constrain(TEnv& tenv, Constraints& c) +ACdrCall::constrain(TEnv& tenv, Constraints& c) const { if (size() != 2) throw Error("`cdr' requires exactly 1 argument", loc); at(1)->constrain(tenv, c); |