diff options
-rw-r--r-- | llvm.cpp | 179 | ||||
-rw-r--r-- | tuplr.cpp | 12 | ||||
-rw-r--r-- | tuplr.hpp | 210 | ||||
-rw-r--r-- | typing.cpp | 231 |
4 files changed, 432 insertions, 200 deletions
@@ -46,10 +46,11 @@ struct LLVMEngine { }; static const Type* -lltype(AType* t) +lltype(const AType* t) { switch (t->kind) { case AType::VAR: + throw Error((format("non-compilable type `%1%'") % t->str()).str(), t->loc); return NULL; case AType::PRIM: if (t->at(0)->str() == "Bool") return Type::Int1Ty; @@ -117,7 +118,7 @@ CValue CEnv::compile(AST* obj) { CValue* v = vals.ref(obj); - return (v) ? *v : vals.def(obj, obj->compile(*this)); + return (v && *v) ? *v : vals.def(obj, obj->compile(*this)); } void @@ -138,7 +139,7 @@ CEnv::write(std::ostream& os) template<> CValue \ ALiteral<CT>::compile(CEnv& cenv) { return (COMPILED); } \ template<> void \ -ALiteral<CT>::constrain(TEnv& tenv) const { tenv.constrain(this, tenv.named(NAME)); } +ALiteral<CT>::constrain(TEnv& tenv, Constraints& c) { c.constrain(tenv, this, tenv.named(NAME)); } /// Literal template instantiations LITERAL(int32_t, "Int", ConstantInt::get(Type::Int32Ty, val, true)) @@ -146,14 +147,15 @@ LITERAL(float, "Float", ConstantFP::get(Type::FloatTy, val)) LITERAL(bool, "Bool", ConstantInt::get(Type::Int1Ty, val, false)) static Function* -compileFunction(CEnv& cenv, const std::string& name, const Type* retT, const ATuple& prot, +compileFunction(CEnv& cenv, const std::string& name, const Type* retT, const ATuple& protT, const vector<string> argNames=vector<string>()) { Function::LinkageTypes linkage = Function::ExternalLinkage; vector<const Type*> cprot; - for (size_t i = 0; i < prot.size(); ++i) { - AType* at = cenv.tenv.type(prot.at(i)); + for (size_t i = 0; i < protT.size(); ++i) { + AType* at = dynamic_cast<AType*>(protT.at(i)); + if (!at) throw Error("function parameter type isn't"); if (!lltype(at)) throw Error("function parameter is untyped"); cprot.push_back(lltype(at)); } @@ -170,11 +172,8 @@ compileFunction(CEnv& cenv, const std::string& name, const Type* retT, const ATu // Set argument names in generated code Function::arg_iterator a = f->arg_begin(); if (!argNames.empty()) - for (size_t i = 0; i != prot.size(); ++a, ++i) + for (size_t i = 0; i != protT.size(); ++a, ++i) a->setName(argNames.at(i)); - else - for (size_t i = 0; i != prot.size(); ++a, ++i) - a->setName(prot.at(i)->str()); BasicBlock* bb = BasicBlock::Create("entry", f); llengine(cenv)->builder.SetInsertPoint(bb); @@ -187,53 +186,110 @@ compileFunction(CEnv& cenv, const std::string& name, const Type* retT, const ATu * Code Generation * ***************************************************************************/ -void -ASymbol::lift(CEnv& cenv) -{ - if (!cenv.code.ref(this)) - throw Error((string("undefined symbol `") + cppstr + "'").c_str(), loc); -} - CValue ASymbol::compile(CEnv& cenv) { - return cenv.compile(*cenv.code.ref(this)); + return cenv.vals.ref(this); } void AClosure::lift(CEnv& cenv) { - AType* type = cenv.tenv.type(this); + AType* type = cenv.type(this); if (!type->concrete() || funcs.find(type)) return; - cenv.push(); - // Write function declaration string name = this->name == "" ? cenv.gensym("_fn") : this->name; - Function* f = compileFunction(cenv, name, lltype(cenv.tenv.type(at(2))), *prot()); + ATuple* protT = dynamic_cast<ATuple*>(type->at(1)); + assert(protT); + Function* f = compileFunction(cenv, name, + lltype(dynamic_cast<AType*>(type->at(type->size() - 1))), + *protT); + + cenv.push(); + Subst oldSubst = cenv.tsubst; + cenv.tsubst = Subst::compose(cenv.tsubst, *subst); // Bind argument values in CEnv vector<Value*> args; const_iterator p = prot()->begin(); - for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p) + size_t i = 0; + for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p) { + cenv.tenv.def(*p, dynamic_cast<AType*>(protT->at(i++))); cenv.vals.def(dynamic_cast<ASymbol*>(*p), &*a); + } // Write function body try { - cenv.precompile(this, f); // Define our value first for recursion + // Define value first for recursion + cenv.precompile(this, f); + funcs.push_back(make_pair(type, f)); + CValue retVal = cenv.compile(at(2)); llengine(cenv)->builder.CreateRet(LLVal(retVal)); // Finish function cenv.optimise(LLFunc(f)); - funcs.push_back(make_pair(type, f)); + } catch (Error& e) { f->eraseFromParent(); // Error reading body, remove function + cenv.pop(); throw e; } - + cenv.tsubst = oldSubst; cenv.pop(); } +void +AClosure::liftPoly(CEnv& cenv, const vector<AType*>& argsT) +{ + if (type->concrete()) + return; + + throw Error("No polymorphism"); + +#if 0 + //Subst tsubst; + assert(argsT.size() == prot()->size()); + for (size_t i = 0; i < argsT.size(); ++i) { + cenv.err << " " << argsT.at(i)->str(); + //tsubst[*cenv.tenv.ref(prot()->at(i))] = argsT.at(i); + } + cenv.err << endl; +#endif +} + +CValue +AClosure::compile(CEnv& cenv) +{ + /* + cenv.err << "***********************************************" << endl; + cenv.err << cenv.type(this) << endl; + + cenv.err << "COMPILING FOR TYPE:"; + Subst tsubst; + assert(cenv.code.front().size() == prot()->size()); + for (size_t i = 0; i < cenv.code.front().size(); ++i) { + cenv.err << " (" << cenv.type(prot()->at(i))->str() + << " -> " << cenv.type(cenv.code.front().at(i).second)->str() << ")"; + tsubst[cenv.tenv.types[prot()->at(i)]] = + cenv.type(cenv.code.front().at(i).second); + } + cenv.err << endl; + + Subst subst = Subst::compose(tsubst, cenv.tsubst); + AType* concreteType = subst.apply(type); + if (!concreteType->concrete()) + throw Error("compiled function has non-concrete type", loc); + + cenv.err << "*********** CONCRETE TYPE: " << concreteType->str() << endl; + */ + + //CValue ret = funcs.find(concreteType); + //cenv.err << "VALUE FOR TYPE " << concreteType->str() << " : " << ret << endl; + //return ret; + return NULL; +} + template<typename T> T checked_cast(AST* ast) @@ -250,27 +306,23 @@ AST* maybeLookup(CEnv& cenv, AST* ast) { ASymbol* s = dynamic_cast<ASymbol*>(ast); - if (s) { - AST** val = cenv.code.ref(s); - if (val) return *val; - } + if (s) + return cenv.code.deref(s->addr); return ast; } -CValue -AClosure::compile(CEnv& cenv) -{ - return funcs.find(cenv.tenv.type(this)); -} - void ACall::lift(CEnv& cenv) { AClosure* c = dynamic_cast<AClosure*>(maybeLookup(cenv, at(0))); + vector<AType*> argsT; + // Lift arguments - for (size_t i = 1; i < size(); ++i) + for (size_t i = 1; i < size(); ++i) { at(i)->lift(cenv); + argsT.push_back(cenv.type(at(i))); + } if (!c) return; // Primitive @@ -284,15 +336,30 @@ ACall::lift(CEnv& cenv) for (size_t i = 1; i < size(); ++i) cenv.code.def(checked_cast<ASymbol*>(c->prot()->at(i-1)), at(i)); - c->lift(cenv); // Lift called closure + c->liftPoly(cenv, argsT); // Lift called closure cenv.pop(); // Restore environment } CValue ACall::compile(CEnv& cenv) { - AST* c = maybeLookup(cenv, at(0)); - Function* f = dynamic_cast<Function*>(LLVal(cenv.compile(c))); + AClosure* c = dynamic_cast<AClosure*>(maybeLookup(cenv, at(0))); + + if (!c) return NULL; // Primitive + + if (c->prot()->size() < size() - 1) + throw Error((format("too many arguments to function `%1%'") % at(0)->str()).str(), loc); + if (c->prot()->size() > size() - 1) + throw Error((format("too few arguments to function `%1%'") % at(0)->str()).str(), loc); + + AType* protT = new AType(loc, NULL); + for (size_t i = 1; i < size(); ++i) + protT->push_back(cenv.type(at(i))); + + AType* polyT = c->type; + AType* fnT = new AType(loc, cenv.penv.sym("Fn"), protT, polyT->at(2), 0); + + Function* f = (Function*)c->funcs.find(fnT); if (!f) throw Error("callee failed to compile", loc); vector<Value*> params(size() - 1); @@ -305,7 +372,7 @@ ACall::compile(CEnv& cenv) void ADefinition::lift(CEnv& cenv) { - if (cenv.code.ref(checked_cast<ASymbol*>(at(1)))) + if (cenv.code.lookup(checked_cast<ASymbol*>(at(1)))) throw Error(string("`") + at(1)->str() + "' redefined", loc); cenv.code.def((ASymbol*)at(1), at(2)); // Define first for recursion at(2)->lift(cenv); @@ -353,7 +420,7 @@ AIf::compile(CEnv& cenv) // Emit merge block (Phi node) parent->getBasicBlockList().push_back(mergeBB); llengine(cenv)->builder.SetInsertPoint(mergeBB); - PHINode* pn = llengine(cenv)->builder.CreatePHI(lltype(cenv.tenv.type(this)), "ifval"); + PHINode* pn = llengine(cenv)->builder.CreatePHI(lltype(cenv.type(this)), "ifval"); FOREACH(Branches::iterator, i, branches) pn->addIncoming(i->first, i->second); @@ -366,7 +433,7 @@ APrimitive::compile(CEnv& cenv) { Value* a = LLVal(cenv.compile(at(1))); Value* b = LLVal(cenv.compile(at(2))); - bool isInt = cenv.tenv.type(at(1))->str() == "Int"; + bool isInt = cenv.type(at(1))->str() == "Int"; const string n = dynamic_cast<ASymbol*>(at(0))->str(); // Binary arithmetic operations @@ -407,9 +474,9 @@ APrimitive::compile(CEnv& cenv) AType* AConsCall::functionType(CEnv& cenv) { - ATuple* protTypes = new ATuple(loc, cenv.tenv.type(at(1)), cenv.tenv.type(at(2)), 0); + ATuple* protTypes = new ATuple(loc, cenv.type(at(1)), cenv.type(at(2)), 0); AType* cellType = new AType(loc, - cenv.penv.sym("Pair"), cenv.tenv.type(at(1)), cenv.tenv.type(at(2)), 0); + cenv.penv.sym("Pair"), cenv.type(at(1)), cenv.type(at(2)), 0); return new AType(at(0)->loc, cenv.penv.sym("Fn"), protTypes, cellType, 0); } @@ -427,7 +494,7 @@ AConsCall::lift(CEnv& cenv) vector<const Type*> types; size_t sz = 0; for (size_t i = 1; i < size(); ++i) { - const Type* t = lltype(cenv.tenv.type(at(i))); + const Type* t = lltype(cenv.type(at(i))); types.push_back(t); sz += t->getPrimitiveSizeInBits(); } @@ -520,15 +587,16 @@ eval(CEnv& cenv, const string& name, istream& is) list< pair<SExp, AST*> > exprs; Cursor cursor(name); try { + Constraints c; while (true) { SExp exp = readExpression(cursor, is); if (exp.type == SExp::LIST && exp.list.empty()) break; result = cenv.penv.parse(exp); // Parse input - result->constrain(cenv.tenv); // Constrain types - cenv.tenv.solve(); // Solve and apply type constraints - resultType = cenv.tenv.type(result); + result->constrain(cenv.tenv, c); // Constrain types + cenv.tsubst = TEnv::unify(c); // Solve type constraints + resultType = cenv.type(result); result->lift(cenv); // Lift functions exprs.push_back(make_pair(exp, result)); } @@ -562,24 +630,27 @@ eval(CEnv& cenv, const string& name, istream& is) int repl(CEnv& cenv) { + Constraints c; while (1) { cenv.out << "() "; cenv.out.flush(); Cursor cursor("(stdin)"); try { + SExp exp = readExpression(cursor, std::cin); if (exp.type == SExp::LIST && exp.list.empty()) break; AST* body = cenv.penv.parse(exp); // Parse input - body->constrain(cenv.tenv); // Constrain types - cenv.tenv.solve(); // Solve and apply type constraints + body->constrain(cenv.tenv, c); // Constrain types + + cenv.tsubst = Subst::compose(cenv.tsubst, TEnv::unify(c)); // Solve type constraints - AType* bodyT = cenv.tenv.type(body); + AType* bodyT = cenv.type(body); if (!bodyT) throw Error("call to untyped body", cursor); body->lift(cenv); - + if (lltype(bodyT)) { // Create anonymous function to insert code into Function* f = compileFunction(cenv, cenv.gensym("_repl"), lltype(bodyT), ATuple()); @@ -595,7 +666,7 @@ repl(CEnv& cenv) } else { cenv.out << "; " << cenv.compile(body); } - cenv.out << " : " << cenv.tenv.type(body) << endl; + cenv.out << " : " << cenv.type(body) << endl; } catch (Error& e) { cenv.err << e.what() << endl; } @@ -123,10 +123,15 @@ parseLiteral(PEnv& penv, const SExp& exp, void* arg) inline AST* parseFn(PEnv& penv, const SExp& exp, void* arg) { + if (exp.list.size() < 2) + throw Error("Missing function parameters and body", exp.loc); + else if (exp.list.size() < 3) + throw Error("Missing function body", exp.loc); SExp::List::const_iterator a = exp.list.begin(); ++a; - return new AClosure(exp.loc, penv.sym("fn"), - new ATuple(penv.parseTuple(*a++)), - penv.parse(*a++)); + AClosure* ret = new AClosure(exp.loc, penv.sym("fn"), new ATuple(penv.parseTuple(*a++))); + while (a != exp.list.end()) + ret->push_back(penv.parse(*a++)); + return ret; } @@ -200,6 +205,7 @@ main(int argc, char** argv) initLang(penv, tenv); CEnv* cenv = newCenv(penv, tenv); + cenv->push(); map<string,string> args; list<string> files; @@ -67,25 +67,53 @@ struct Exp { List list; }; +/// Lexical Address +struct LAddr { + LAddr(unsigned u=0, unsigned o=0) : up(u), over(o) {} + operator bool() const { return !(up == 0 && over == 0); } + unsigned up, over; +}; + +inline ostream& operator<<(ostream& out, const LAddr& addr) { + out << addr.up << ":" << addr.over; + return out; +} + /// Generic Lexical Environment template<typename K, typename V> -struct Env : public list< map<K,V> > { - typedef map<K,V> Frame; +struct Env : public list< vector< pair<K,V> > > { + typedef vector< pair<K,V> > Frame; Env() : list<Frame>(1) {} - void push() { list<Frame>::push_front(Frame()); } - void pop() { assert(!this->empty()); list<Frame>::pop_front(); } + void push(Frame f=Frame()) { list<Frame>::push_front(f); } + void pop() { assert(!this->empty()); list<Frame>::pop_front(); } const V& def(const K& k, const V& v) { - typename Frame::iterator existing = this->front().find(k); - if (existing != this->front().end() && existing->second != v) - throw Error("redefinition"); - return (this->front()[k] = v); + for (typename Frame::iterator b = this->begin()->begin(); b != this->begin()->end(); ++b) + if (b->first == k) + return (b->second = v); + this->front().push_back(make_pair(k, v)); + return v; + } + V* ref(const K& key) { + for (typename Env::iterator f = this->begin(); f != this->end(); ++f) + for (typename Frame::iterator b = f->begin(); b != f->end(); ++b) + if (b->first == key) + return &b->second; + return NULL; } - V* ref(const K& name) { - typename Frame::iterator s; - for (typename Env::iterator i = this->begin(); i != this->end(); ++i) - if ((s = i ->find(name)) != i->end()) - return &s->second; - return 0; + LAddr lookup(const K& key) const { + unsigned up = 0; + for (typename Env::const_iterator f = this->begin(); f != this->end(); ++f, ++up) + for (unsigned over = 0; over < f->size(); ++over) + if ((*f)[over].first == key) + return LAddr(up + 1, over + 1); + return LAddr(); + } + V& deref(LAddr addr) { + assert(addr); + typename Env::iterator f = this->begin(); + for (unsigned u = 1; u < addr.up; ++u, ++f) { assert(f != this->end()); } + assert(f->size() > addr.over - 1); + return (*f)[addr.over - 1].second; } }; @@ -112,10 +140,13 @@ typedef void* CEngine; ///< Compiler Engine (opaque) * Abstract Syntax Tree * ***************************************************************************/ -struct TEnv; ///< Type-Time Environment -struct CEnv; ///< Compile-Time Environment - +struct Constraint; ///< Type Constraint +struct TEnv; ///< Type-Time Environment +struct CEnv; ///< Compile-Time Environment struct AST; +struct Constraints; +struct Subst; + extern ostream& operator<<(ostream& out, const AST* ast); /// Base class for all AST nodes @@ -124,7 +155,7 @@ struct AST { virtual ~AST() {} virtual bool operator==(const AST& o) const = 0; virtual bool contains(const AST* child) const { return false; } - virtual void constrain(TEnv& tenv) const {} + virtual void constrain(TEnv& tenv, Constraints& c) {} virtual void lift(CEnv& cenv) {} string str() const { ostringstream ss; ss << this; return ss.str(); } Cursor loc; @@ -141,7 +172,7 @@ struct ALiteral : public AST { const ALiteral<VT>* r = dynamic_cast<const ALiteral<VT>*>(&rhs); return (r && (val == r->val)); } - void constrain(TEnv& tenv) const; + void constrain(TEnv& tenv, Constraints& c); CValue compile(CEnv& cenv); const VT val; }; @@ -149,13 +180,16 @@ struct ALiteral : public AST { /// Symbol, e.g. "a" struct ASymbol : public AST { bool operator==(const AST& rhs) const { return this == &rhs; } - void lift(CEnv& cenv); + void lookup(TEnv& tenv); + void constrain(TEnv& tenv, Constraints& c); CValue compile(CEnv& cenv); + LAddr addr; private: friend class PEnv; ASymbol(const string& s, Cursor c) : AST(c), cppstr(s) {} friend ostream& operator<<(ostream&, const AST*); const string cppstr; + }; /// Tuple (heterogeneous sequence of fixed length), e.g. "(a b c)" @@ -163,6 +197,8 @@ struct ATuple : public AST, public vector<AST*> { ATuple(const vector<AST*>& t=vector<AST*>(), Cursor c=Cursor()) : AST(c), vector<AST*>(t) {} ATuple(size_t size, Cursor c) : AST(c), vector<AST*>(size) {} ATuple(Cursor c, AST* ast, ...) : AST(c) { + if (!ast) + return; va_list args; va_start(args, ast); push_back(ast); for (AST* a = va_arg(args, AST*); a; a = va_arg(args, AST*)) @@ -189,23 +225,23 @@ struct ATuple : public AST, public vector<AST*> { return true; return false; } - void constrain(TEnv& tenv) const; + void constrain(TEnv& tenv, Constraints& c); CValue compile(CEnv& cenv) { throw Error("tuple compiled"); } }; /// Type Expression, e.g. "Int", "(Fn (Int Int) Float)" struct AType : public ATuple { - AType(unsigned i, Cursor c=Cursor()) : ATuple(0, c), kind(VAR), id(i) {} + AType(unsigned i, LAddr a, Cursor c=Cursor()) : ATuple(0, c), kind(VAR), addr(a), id(i) {} AType(ASymbol* s) : ATuple(0, s->loc), kind(PRIM), id(0) { push_back(s); } AType(const ATuple& t, Cursor c) : ATuple(t, c), kind(EXPR), id(0) {} AType(Cursor c, AST* ast, ...) : ATuple(0, c), kind(EXPR), id(0) { va_list args; va_start(args, ast); + if (!ast) return; push_back(ast); for (AST* a = va_arg(args, AST*); a; a = va_arg(args, AST*)) push_back(a); va_end(args); } - void constrain(TEnv& tenv) const {} CValue compile(CEnv& cenv) { return NULL; } bool var() const { return kind == VAR; } bool concrete() const { @@ -234,6 +270,7 @@ struct AType : public ATuple { return false; // never reached } enum { VAR, PRIM, EXPR } kind; + LAddr addr; unsigned id; }; @@ -249,22 +286,25 @@ struct Funcs : public list< pair<AType*, CFunction> > { /// Closure (first-class function with captured lexical bindings) struct AClosure : public ATuple { - AClosure(Cursor c, ASymbol* fn, ATuple* p, AST* b, const string& n="") - : ATuple(c, fn, p, b, NULL), name(n) {} + AClosure(Cursor c, ASymbol* fn, ATuple* p, const string& n="") + : ATuple(c, fn, p, NULL), type(0), subst(0), name(n) {} bool operator==(const AST& rhs) const { return this == &rhs; } - void constrain(TEnv& tenv) const; + void constrain(TEnv& tenv, Constraints& c); void lift(CEnv& cenv); + void liftPoly(CEnv& cenv, const vector<AType*>& argsT); CValue compile(CEnv& cenv); ATuple* prot() const { return dynamic_cast<ATuple*>(at(1)); } -private: + AType* type; Funcs funcs; + Subst* subst; +private: string name; }; /// Function call/application, e.g. "(func arg1 arg2)" struct ACall : public ATuple { ACall(const SExp& e, const ATuple& t) : ATuple(t, e.loc) {} - void constrain(TEnv& tenv) const; + void constrain(TEnv& tenv, Constraints& c); void lift(CEnv& cenv); CValue compile(CEnv& cenv); }; @@ -272,7 +312,8 @@ struct ACall : public ATuple { /// Definition special form, e.g. "(def x 2)" struct ADefinition : public ACall { ADefinition(const SExp& e, const ATuple& t) : ACall(e, t) {} - void constrain(TEnv& tenv) const; + ASymbol* sym() const { return dynamic_cast<ASymbol*>(at(1)); } + void constrain(TEnv& tenv, Constraints& c); void lift(CEnv& cenv); CValue compile(CEnv& cenv); }; @@ -280,14 +321,14 @@ struct ADefinition : public ACall { /// Conditional special form, e.g. "(if cond thenexp elseexp)" struct AIf : public ACall { AIf(const SExp& e, const ATuple& t) : ACall(e, t) {} - void constrain(TEnv& tenv) const; + void constrain(TEnv& tenv, Constraints& c); CValue compile(CEnv& cenv); }; /// Primitive (builtin arithmetic function), e.g. "(+ 2 3)" struct APrimitive : public ACall { APrimitive(const SExp& e, const ATuple& t) : ACall(e, t) {} - void constrain(TEnv& tenv) const; + void constrain(TEnv& tenv, Constraints& c); CValue compile(CEnv& cenv); }; @@ -295,7 +336,7 @@ struct APrimitive : public ACall { struct AConsCall : public ACall { AConsCall(const SExp& e, const ATuple& t) : ACall(e, t) {} AType* functionType(CEnv& cenv); - void constrain(TEnv& tenv) const; + void constrain(TEnv& tenv, Constraints& c); void lift(CEnv& cenv); CValue compile(CEnv& cenv); static Funcs funcs; @@ -304,14 +345,14 @@ struct AConsCall : public ACall { /// Car special form, e.g. "(car p)" struct ACarCall : public ACall { ACarCall(const SExp& e, const ATuple& t) : ACall(e, t) {} - void constrain(TEnv& tenv) const; + void constrain(TEnv& tenv, Constraints& c); CValue compile(CEnv& cenv); }; /// Cdr special form, e.g. "(cdr p)" struct ACdrCall : public ACall { ACdrCall(const SExp& e, const ATuple& t) : ACall(e, t) {} - void constrain(TEnv& tenv) const; + void constrain(TEnv& tenv, Constraints& c); CValue compile(CEnv& cenv); }; @@ -375,33 +416,70 @@ struct PEnv : private map<const string, ASymbol*> { * Typing * ***************************************************************************/ +struct Constraint : public pair<AType*,AType*> { + Constraint(AType* a, AType* b, Cursor c) : pair<AType*,AType*>(a, b), loc(c) {} + Cursor loc; +}; + +struct Constraints : public list<Constraint> { + void constrain(TEnv& tenv, const AST* o, AType* t); +}; + +inline ostream& operator<<(ostream& out, const Constraints& c) { + for (Constraints::const_iterator i = c.begin(); i != c.end(); ++i) + out << i->first << " : " << i->second << endl; + return out; +} + +struct Subst : public map<const AType*,AType*> { + Subst(AType* s=0, AType* t=0) { if (s && t) insert(make_pair(s, t)); } + static Subst compose(const Subst& delta, const Subst& gamma); + AST* apply(AST* ast) const { + AType* in = dynamic_cast<AType*>(ast); + if (!in) return ast; + if (in->kind == AType::EXPR) { + AType* out = new AType(in->loc, NULL); + for (size_t i = 0; i < in->size(); ++i) + out->push_back(apply(in->at(i))); + return out; + } else { + const_iterator i; + while ((i = find(in)) != end()) + in = i->second; + return in; + } + } +}; + /// Type-Time Environment struct TEnv : public Env<const AST*,AType*> { TEnv(PEnv& p) : penv(p), varID(1) {} - struct Constraint : public pair<AType*,AType*> { - Constraint(AType* a, AType* b, Cursor c) : pair<AType*,AType*>(a, b), loc(c) {} - Cursor loc; - }; - struct Subst : public map<AType*, AType*> { - Subst(AType* s=0, AType* t=0) { if (s && t) insert(make_pair(s, t)); } - }; - typedef list<Constraint> Constraints; - AType* var(Cursor c=Cursor()) { return new AType(varID++, c); } - AType* type(const AST* ast) { - AType** t = ref(ast); - return t ? *t : def(ast, var(ast->loc)); + + AType* fresh(const ASymbol* sym) { + assert(sym); + return def(sym, new AType(varID++, LAddr(), sym->loc)); + } + AType* var(const AST* ast=0) { + const ASymbol* sym = dynamic_cast<const ASymbol*>(ast); + if (sym) + return deref(lookup(sym)); + + map<const AST*, AType*>::iterator v = vars.find(ast); + if (v != vars.end()) + return v->second; + + AType* ret = new AType(varID++, LAddr(), ast ? ast->loc : Cursor()); + if (ast) + vars[ast] = ret; + + return ret; } AType* named(const string& name) { return *ref(penv.sym(name)); } - void constrain(const AST* o, AType* t) { - assert(!dynamic_cast<const AType*>(o)); - constraints.push_back(Constraint(type(o), t, o->loc)); - } - void solve() { apply(unify(constraints)); } - void apply(const Subst& substs); static Subst unify(const Constraints& c); + map<const AST*, AType*> vars; PEnv& penv; Constraints constraints; unsigned varID; @@ -422,21 +500,29 @@ struct CEnv { CEngine engine(); string gensym(const char* s="_") { return (format("%s%d") % s % symID++).str(); } - void push() { code.push(); vals.push(); } - void pop() { code.pop(); vals.pop(); } + void push() { code.push(); vals.push(); tenv.push(); } + void pop() { code.pop(); vals.pop(); tenv.pop(); } void precompile(AST* obj, CValue value) { vals.def(obj, value); } CValue compile(AST* obj); void optimise(CFunction f); void write(std::ostream& os); + AType* type(AST* ast, const Subst& subst = Subst()) const { + ASymbol* sym = dynamic_cast<ASymbol*>(ast); + if (sym) + return tenv.deref(sym->addr); + return dynamic_cast<AType*>(tsubst.apply(subst.apply(tenv.vars[ast]))); + } - ostream& out; - ostream& err; - PEnv& penv; - TEnv& tenv; - Code code; - Vals vals; - unsigned symID; - CFunction alloc; + ostream& out; + ostream& err; + PEnv& penv; + TEnv& tenv; + Code code; + Vals vals; + + unsigned symID; + CFunction alloc; + Subst tsubst; private: struct PImpl; ///< Private Implementation @@ -15,78 +15,157 @@ * along with Tuplr. If not, see <http://www.gnu.org/licenses/>. */ +#include <set> #include "tuplr.hpp" +void +Constraints::constrain(TEnv& tenv, const AST* o, AType* t) +{ + assert(!dynamic_cast<const AType*>(o)); + push_back(Constraint(tenv.var(o), t, o->loc)); +} + /*************************************************************************** * AST Type Constraints * ***************************************************************************/ void -ATuple::constrain(TEnv& tenv) const +ASymbol::lookup(TEnv& tenv) +{ + addr = tenv.lookup(this); + if (!addr) + throw Error((format("undefined symbol `%1%'") % cppstr).str(), loc); +} + +void +ASymbol::constrain(TEnv& tenv, Constraints& c) +{ + lookup(tenv); + AType* t = tenv.deref(addr); + if (!t) + throw Error((format("unresolved symbol `%1%'") % cppstr).str(), loc); + c.push_back(Constraint(tenv.var(this), tenv.deref(addr), loc)); +} + +void +ATuple::constrain(TEnv& tenv, Constraints& c) { AType* t = new AType(ATuple(), loc); - FOREACH(const_iterator, p, *this) { - (*p)->constrain(tenv); - t->push_back(tenv.type(*p)); + FOREACH(iterator, p, *this) { + (*p)->constrain(tenv, c); + t->push_back(tenv.var(*p)); } - tenv.constrain(this, t); + c.push_back(Constraint(tenv.var(this), t, loc)); } void -AClosure::constrain(TEnv& tenv) const +AClosure::constrain(TEnv& tenv, Constraints& c) { - at(1)->constrain(tenv); - at(2)->constrain(tenv); - AType* protT = tenv.type(at(1)); - AType* bodyT = tenv.type(at(2)); - tenv.constrain(this, new AType(loc, tenv.penv.sym("Fn"), protT, bodyT, 0)); + set<ASymbol*> defined; + TEnv::Frame frame; + + // Add parameters to environment frame + for (size_t i = 0; i < prot()->size(); ++i) { + ASymbol* sym = dynamic_cast<ASymbol*>(prot()->at(i)); + if (!sym) + throw Error("parameter name is not a symbol", prot()->at(i)->loc); + if (defined.find(sym) != defined.end()) + throw Error((format("duplicate parameter `%1%'") % sym->str()).str(), sym->loc); + defined.insert(sym); + frame.push_back(make_pair(sym, (AType*)NULL)); + } + + // Add internal definitions to environment frame + size_t e = 2; + for (; e < size(); ++e) { + AST* exp = at(e); + ADefinition* def = dynamic_cast<ADefinition*>(exp); + if (def) { + ASymbol* sym = def->sym(); + if (defined.find(sym) != defined.end()) + throw Error((format("`%1%' defined twice") % sym->str()).str(), def->loc); + defined.insert(def->sym()); + frame.push_back(make_pair(def->sym(), (AType*)NULL)); + } + } + + tenv.push(frame); + + Constraints cp; + cp.push_back(Constraint(tenv.var(this), tenv.var(), loc)); + + AType* protT = new AType(ATuple(), loc); + for (size_t i = 0; i < prot()->size(); ++i) { + AType* tvar = tenv.fresh(dynamic_cast<ASymbol*>(prot()->at(i))); + protT->push_back(tvar); + assert(frame[i].first == prot()->at(i)); + frame[i].second = tvar; + } + c.push_back(Constraint(tenv.var(at(1)), protT, at(1)->loc)); + + for (size_t i = 2; i < size(); ++i) + at(i)->constrain(tenv, cp); + + AType* bodyT = tenv.var(at(e-1)); + Subst tsubst = TEnv::unify(cp); + type = new AType(loc, tenv.penv.sym("Fn"), tsubst.apply(protT), tsubst.apply(bodyT), 0); + + tenv.pop(); + + c.constrain(tenv, this, type); + subst = new Subst(tsubst); } void -ACall::constrain(TEnv& tenv) const +ACall::constrain(TEnv& tenv, Constraints& c) { - FOREACH(const_iterator, p, *this) - (*p)->constrain(tenv); - AType* retT = tenv.type(this); AType* argsT = new AType(ATuple(), loc); - for (size_t i = 1; i < size(); ++i) - argsT->push_back(tenv.type(at(i))); - tenv.constrain(at(0), new AType(at(0)->loc, tenv.penv.sym("Fn"), argsT, retT, 0)); + TEnv::Frame frame; + for (size_t i = 1; i < size(); ++i) { + at(i)->constrain(tenv, c); + argsT->push_back(tenv.var(at(i))); + frame.push_back(make_pair((AST*)NULL, tenv.var(at(i)))); + } + AType* retT = tenv.var(); + + at(0)->constrain(tenv, c); + + c.constrain(tenv, at(0), new AType(at(0)->loc, tenv.penv.sym("Fn"), argsT, retT, 0)); + c.constrain(tenv, this, retT); } void -ADefinition::constrain(TEnv& tenv) const +ADefinition::constrain(TEnv& tenv, Constraints& c) { if (size() != 3) throw Error("`def' requires exactly 2 arguments", loc); if (!dynamic_cast<const ASymbol*>(at(1))) throw Error("`def' name is not a symbol", loc); - FOREACH(const_iterator, p, *this) - (*p)->constrain(tenv); - AType* tvar = tenv.type(this); - tenv.constrain(at(1), tvar); - tenv.constrain(at(2), tvar); + AType* tvar = tenv.var(this); + tenv.def(at(1), tvar); + at(2)->constrain(tenv, c); + c.constrain(tenv, at(2), tvar); } void -AIf::constrain(TEnv& tenv) const +AIf::constrain(TEnv& tenv, Constraints& c) { if (size() < 3) throw Error("`if' requires exactly 3 arguments", loc); if (size() % 2 != 0) throw Error("`if' missing final else clause", loc); - FOREACH(const_iterator, p, *this) - (*p)->constrain(tenv); - AType* retT = tenv.type(this); + for (size_t i = 1; i < size(); ++i) + at(i)->constrain(tenv, c); + AType* retT = tenv.var(this); for (size_t i = 1; i < size(); i += 2) { if (i == size() - 1) { - tenv.constrain(at(i), retT); + c.constrain(tenv, at(i), retT); } else { - tenv.constrain(at(i), tenv.named("Bool")); - tenv.constrain(at(i+1), retT); + c.constrain(tenv, at(i), tenv.named("Bool")); + c.constrain(tenv, at(i+1), retT); } } } void -APrimitive::constrain(TEnv& tenv) const +APrimitive::constrain(TEnv& tenv, Constraints& c) { const string n = dynamic_cast<ASymbol*>(at(0))->str(); enum { ARITHMETIC, BINARY, LOGICAL, COMPARISON } type; @@ -101,34 +180,34 @@ APrimitive::constrain(TEnv& tenv) const else throw Error((format("unknown primitive `%1%'") % n).str(), loc); - FOREACH(const_iterator, p, *this) - (*p)->constrain(tenv); + for (size_t i = 1; i < size(); ++i) + at(i)->constrain(tenv, c); switch (type) { case ARITHMETIC: if (size() < 3) throw Error((format("`%1%' requires at least 2 arguments") % n).str(), loc); for (size_t i = 1; i < size(); ++i) - tenv.constrain(at(i), tenv.type(this)); + c.constrain(tenv, at(i), tenv.var(this)); break; case BINARY: if (size() != 3) throw Error((format("`%1%' requires exactly 2 arguments") % n).str(), loc); - tenv.constrain(at(1), tenv.type(this)); - tenv.constrain(at(2), tenv.type(this)); + c.constrain(tenv, at(1), tenv.var(this)); + c.constrain(tenv, at(2), tenv.var(this)); break; case LOGICAL: if (size() != 3) throw Error((format("`%1%' requires exactly 2 arguments") % n).str(), loc); - tenv.constrain(this, tenv.named("Bool")); - tenv.constrain(at(1), tenv.named("Bool")); - tenv.constrain(at(2), tenv.named("Bool")); + c.constrain(tenv, this, tenv.named("Bool")); + c.constrain(tenv, at(1), tenv.named("Bool")); + c.constrain(tenv, at(2), tenv.named("Bool")); break; case COMPARISON: if (size() != 3) throw Error((format("`%1%' requires exactly 2 arguments") % n).str(), loc); - tenv.constrain(this, tenv.named("Bool")); - tenv.constrain(at(1), tenv.type(at(2))); + c.constrain(tenv, this, tenv.named("Bool")); + c.constrain(tenv, at(1), tenv.var(at(2))); break; default: throw Error((format("unknown primitive `%1%'") % n).str(), loc); @@ -136,37 +215,37 @@ APrimitive::constrain(TEnv& tenv) const } void -AConsCall::constrain(TEnv& tenv) const +AConsCall::constrain(TEnv& tenv, Constraints& c) { if (size() != 3) throw Error("`cons' requires exactly 2 arguments", loc); AType* t = new AType(loc, tenv.penv.sym("Pair"), 0); for (size_t i = 1; i < size(); ++i) { - at(i)->constrain(tenv); - t->push_back(tenv.type(at(i))); + at(i)->constrain(tenv, c); + t->push_back(tenv.var(at(i))); } - tenv.constrain(this, t); + c.constrain(tenv, this, t); } void -ACarCall::constrain(TEnv& tenv) const +ACarCall::constrain(TEnv& tenv, Constraints& c) { if (size() != 2) throw Error("`car' requires exactly 1 argument", loc); - at(1)->constrain(tenv); - AType* carT = tenv.type(this); + at(1)->constrain(tenv, c); + AType* carT = tenv.var(this); AType* pairT = new AType(at(1)->loc, tenv.penv.sym("Pair"), carT, tenv.var(), 0); - tenv.constrain(at(1), pairT); - tenv.constrain(this, carT); + c.constrain(tenv, at(1), pairT); + c.constrain(tenv, this, carT); } void -ACdrCall::constrain(TEnv& tenv) const +ACdrCall::constrain(TEnv& tenv, Constraints& c) { if (size() != 2) throw Error("`cdr' requires exactly 1 argument", loc); - at(1)->constrain(tenv); - AType* cdrT = tenv.type(this); + at(1)->constrain(tenv, c); + AType* cdrT = tenv.var(this); AType* pairT = new AType(at(1)->loc, tenv.penv.sym("Pair"), tenv.var(), cdrT, 0); - tenv.constrain(at(1), pairT); - tenv.constrain(this, cdrT); + c.constrain(tenv, at(1), pairT); + c.constrain(tenv, this, cdrT); } @@ -175,25 +254,26 @@ ACdrCall::constrain(TEnv& tenv) const ***************************************************************************/ static void -substitute(ATuple* tup, AST* from, AST* to) +substitute(ATuple* tup, const AST* from, AST* to) { if (!tup) return; for (size_t i = 0; i < tup->size(); ++i) if (*tup->at(i) == *from) tup->at(i) = to; - else + else if (tup->at(i) != to) substitute(dynamic_cast<ATuple*>(tup->at(i)), from, to); } -TEnv::Subst -compose(const TEnv::Subst& delta, const TEnv::Subst& gamma) // TAPL 22.1.1 + +Subst +Subst::compose(const Subst& delta, const Subst& gamma) // TAPL 22.1.1 { - TEnv::Subst r; - for (TEnv::Subst::const_iterator g = gamma.begin(); g != gamma.end(); ++g) { - TEnv::Subst::const_iterator d = delta.find(g->second); + Subst r; + for (Subst::const_iterator g = gamma.begin(); g != gamma.end(); ++g) { + Subst::const_iterator d = delta.find(g->second); r.insert(make_pair(g->first, ((d != delta.end()) ? d : g)->second)); } - for (TEnv::Subst::const_iterator d = delta.begin(); d != delta.end(); ++d) { + for (Subst::const_iterator d = delta.begin(); d != delta.end(); ++d) { if (gamma.find(d->first) == gamma.end()) r.insert(*d); } @@ -201,10 +281,10 @@ compose(const TEnv::Subst& delta, const TEnv::Subst& gamma) // TAPL 22.1.1 } void -substConstraints(TEnv::Constraints& constraints, AType* s, AType* t) +substConstraints(Constraints& constraints, AType* s, AType* t) { - for (TEnv::Constraints::iterator c = constraints.begin(); c != constraints.end();) { - TEnv::Constraints::iterator next = c; ++next; + for (Constraints::iterator c = constraints.begin(); c != constraints.end();) { + Constraints::iterator next = c; ++next; if (*c->first == *s) c->first = t; if (*c->second == *s) c->second = t; substitute(c->first, s, t); @@ -213,7 +293,7 @@ substConstraints(TEnv::Constraints& constraints, AType* s, AType* t) } } -TEnv::Subst +Subst TEnv::unify(const Constraints& constraints) // TAPL 22.4 { if (constraints.empty()) return Subst(); @@ -226,10 +306,10 @@ TEnv::unify(const Constraints& constraints) // TAPL 22.4 return unify(cp); } else if (s->var() && !t->contains(s)) { substConstraints(cp, s, t); - return compose(unify(cp), Subst(s, t)); + return Subst::compose(unify(cp), Subst(s, t)); } else if (t->var() && !s->contains(t)) { substConstraints(cp, t, s); - return compose(unify(cp), Subst(t, s)); + return Subst::compose(unify(cp), Subst(t, s)); } else if (s->kind == AType::EXPR && s->kind == t->kind && s->size() == t->size()) { for (size_t i = 0; i < s->size(); ++i) { AType* si = dynamic_cast<AType*>(s->at(i)); @@ -244,14 +324,3 @@ TEnv::unify(const Constraints& constraints) // TAPL 22.4 } } -void -TEnv::apply(const TEnv::Subst& substs) -{ - FOREACH(Subst::const_iterator, s, substs) - FOREACH(Frame::iterator, t, front()) - if (*t->second == *s->first) - t->second = s->second; - else - substitute(t->second, s->first, s->second); -} - |