From 2eda11ca28589991471ff3251cccc2471424770e Mon Sep 17 00:00:00 2001 From: David Robillard Date: Wed, 29 Dec 2010 22:50:20 +0000 Subject: Destructuring (i.e. working `match'). git-svn-id: http://svn.drobilla.net/resp/resp@374 ad02d1e2-f140-0410-9f75-f8b11f17cedd --- src/c.cpp | 7 +++++++ src/compile.cpp | 25 +++++++++++++++++++++++-- src/constrain.cpp | 12 +++++++++++- src/lift.cpp | 2 +- src/llvm.cpp | 7 +++++++ src/resp.hpp | 5 +++-- src/simplify.cpp | 48 ++++++++++++++++++++++++++++++++++++++++++------ test.sh | 2 +- test/match.resp | 14 +++++--------- 9 files changed, 100 insertions(+), 22 deletions(-) diff --git a/src/c.cpp b/src/c.cpp index 7d5b125..ff4f09a 100644 --- a/src/c.cpp +++ b/src/c.cpp @@ -47,6 +47,7 @@ struct CEngine : public Engine { void eraseFn(CEnv& cenv, CFunc f); CVal compileCall(CEnv& cenv, CFunc f, const ATuple* funcT, const vector& args); + CVal compileCast(CEnv& cenv, CVal v, const AST* t); CVal compileCons(CEnv& cenv, const ATuple* type, CVal rtti, const vector& fields); CVal compileDot(CEnv& cenv, CVal tup, int32_t index); CVal compileGlobalSet(CEnv& cenv, const string& s, CVal v, const AST* t); @@ -138,6 +139,12 @@ CEngine::compileCall(CEnv& cenv, CFunc func, const ATuple* funcT, const vector& fields) { diff --git a/src/compile.cpp b/src/compile.cpp index d2f0530..90e31ba 100644 --- a/src/compile.cpp +++ b/src/compile.cpp @@ -50,6 +50,14 @@ compile_literal_symbol(CEnv& cenv, const ASymbol* sym) throw() } +static CVal +compile_cast(CEnv& cenv, const ATuple* cast) throw() +{ + return cenv.engine()->compileCast(cenv, + resp_compile(cenv, cast->frst()), + cenv.type(cast)); +} + static CVal compile_type(CEnv& cenv, const AST* type) throw() { @@ -78,6 +86,7 @@ compile_dot(CEnv& cenv, const ATuple* dot) throw() const ALiteral* index = (ALiteral*)(*++i); assert(index->tag() == T_INT32); CVal tupVal = resp_compile(cenv, tup); + assert((unsigned)index->val < cenv.type(dot->frst())->as_tuple()->list_len()); return cenv.engine()->compileDot(cenv, tupVal, index->val); } @@ -175,14 +184,24 @@ compile_quote(CEnv& cenv, const ATuple* quote) throw() static CVal compile_call(CEnv& cenv, const ATuple* call) throw() { + const ATuple* protT = cenv.type(call->fst())->as_tuple()->prot(); CFunc f = resp_compile(cenv, call->fst()); if (!f) f = cenv.currentFn; // Recursive call (callee defined as a stub) vector args; - for (ATuple::const_iterator e = call->iter_at(1); e != call->end(); ++e) - args.push_back(resp_compile(cenv, *e)); + ATuple::const_iterator t = protT->iter_at(0); + for (ATuple::const_iterator e = call->iter_at(1); e != call->end(); ++e, ++t) { + CVal arg = resp_compile(cenv, *e); + if ((*e)->to_symbol()) { + if (cenv.type(*e) != cenv.type(*t)) { + args.push_back(cenv.engine()->compileCast(cenv, arg, *t)); + continue; + } + } + args.push_back(arg); + } return cenv.engine()->compileCall(cenv, f, cenv.type(call->fst())->as_tuple(), args); } @@ -211,6 +230,8 @@ resp_compile(CEnv& cenv, const AST* ast) throw() const std::string form = sym ? sym->sym() : ""; if (is_primitive(cenv.penv, call)) return cenv.engine()->compilePrimitive(cenv, ast->as_tuple()); + else if (form == "cast") + return compile_cast(cenv, call); else if (form == "cons" || isupper(form[0])) return compile_cons(cenv, call); else if (form == ".") diff --git a/src/constrain.cpp b/src/constrain.cpp index a4a1ad3..78accf6 100644 --- a/src/constrain.cpp +++ b/src/constrain.cpp @@ -241,6 +241,8 @@ constrain_match(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error) THROW_IF(!pattern, exp->loc, "pattern expression expected"); const ASymbol* name = (*pattern->begin())->to_symbol(); THROW_IF(!name, (*pattern->begin())->loc, "pattern does not start with a symbol"); + THROW_IF(!tenv.ref(name), name->loc, + (format("undefined constructor `%1%'") % name->sym()).str()); const AST* consT = *tenv.ref(name); @@ -251,8 +253,16 @@ constrain_match(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error) THROW_IF(i == call->end(), pattern->loc, "missing pattern body"); const AST* body = *i++; + + TEnv::Frame frame; + ATuple::const_iterator ti = consT->as_tuple()->iter_at(2); + for (ATuple::const_iterator pi = pattern->iter_at(1); pi != pattern->end(); ++pi) + frame.push_back(make_pair((*pi)->as_symbol()->sym(), *ti++)); + + tenv.push(frame); resp_constrain(tenv, c, body); - c.constrain(tenv, body, retT); + c.constrain(tenv, body, retT); + tenv.pop(); } c.constrain(tenv, call, retT); c.constrain(tenv, matchee, matcheeT); diff --git a/src/lift.cpp b/src/lift.cpp index 8a6c4b0..3e785b9 100644 --- a/src/lift.cpp +++ b/src/lift.cpp @@ -271,7 +271,7 @@ resp_lift(CEnv& cenv, Code& code, const AST* ast) throw() return lift_fn(cenv, code, call); else if (form == "if") return lift_args(cenv, code, call); - else if (form == "quote") + else if (form == "quote" || form == "cast") return call; else return lift_call(cenv, code, call); diff --git a/src/llvm.cpp b/src/llvm.cpp index cc4bf47..28d6f36 100644 --- a/src/llvm.cpp +++ b/src/llvm.cpp @@ -62,6 +62,7 @@ struct LLVMEngine : public Engine { void eraseFn(CEnv& cenv, CFunc f); CVal compileCall(CEnv& cenv, CFunc f, const ATuple* funcT, const vector& args); + CVal compileCast(CEnv& cenv, CVal v, const AST* t); CVal compileCons(CEnv& cenv, const ATuple* type, CVal rtti, const vector& fields); CVal compileDot(CEnv& cenv, CVal tup, int32_t index); CVal compileGlobalSet(CEnv& cenv, const string& s, CVal v, const AST* t); @@ -206,6 +207,12 @@ LLVMEngine::compileCall(CEnv& cenv, CFunc f, const ATuple* funcT, const vector& fields) { diff --git a/src/resp.hpp b/src/resp.hpp index 57880f5..ba0ae4b 100644 --- a/src/resp.hpp +++ b/src/resp.hpp @@ -528,7 +528,8 @@ struct Subst : public list { List out; for (ATuple::const_iterator i = in->as_tuple()->begin(); i != in->as_tuple()->end(); ++i) out.push_back(apply((*i))); - out.head->loc = in->loc; + if (out.head) + out.head->loc = in->loc; return out.head; } else { const_iterator i = find(in); @@ -656,6 +657,7 @@ struct Engine { virtual void eraseFn(CEnv& cenv, CFunc f) = 0; virtual CVal compileCall(CEnv& cenv, CFunc f, const ATuple* fT, CVals& args) = 0; + virtual CVal compileCast(CEnv& cenv, CVal v, const AST* t) = 0; virtual CVal compileCons(CEnv& cenv, const ATuple* t, CVal rtti, CVals& f) = 0; virtual CVal compileDot(CEnv& cenv, CVal tup, int32_t index) = 0; virtual CVal compileGlobalSet(CEnv& cenv, const string& s, CVal v, const AST* t) = 0; @@ -726,7 +728,6 @@ struct CEnv { return rec ? *rec : ast; } void setType(const AST* ast, const AST* type) { - //assert(!ast->to_symbol()); const AST* tvar = tenv.var(); tenv.vars.insert(make_pair(ast, tvar)); tsubst.add(tvar, type); diff --git a/src/simplify.cpp b/src/simplify.cpp index c5b2566..d6188ad 100644 --- a/src/simplify.cpp +++ b/src/simplify.cpp @@ -56,12 +56,12 @@ simplify_if(CEnv& cenv, const ATuple* aif) throw() static const AST* simplify_match(CEnv& cenv, const ATuple* match) throw() { + const AST* const obj = resp_simplify(cenv, match->list_ref(1)); + // Dot expression to get tag. Note index is -1 to compensate for the lift phase // which adds 1 to skip the RTTI, which we don't want here (FIXME: ick...) - List tval; - tval.push_back(cenv.penv.sym(".")); - tval.push_back(resp_simplify(cenv, match->list_ref(1))); - tval.push_back(new ALiteral(T_INT32, -1, Cursor())); + const AST* index = new ALiteral(T_INT32, -1, Cursor()); + List tval(Cursor(), cenv.penv.sym("."), obj, index, 0); const ASymbol* tsym = cenv.penv.gensym("__tag"); @@ -78,11 +78,47 @@ simplify_match(CEnv& cenv, const ATuple* match) throw() const_cast(consTag)->tag(T_LITSYM); cenv.setType(consTag, cenv.tenv.named("Symbol")); + const ATuple* texp = cenv.tenv.named(consTag->sym())->as_tuple(); + + // Append condition for this case List cond(Cursor(), cenv.penv.sym("="), tsym, consTag, 0); cenv.setType(cond, cenv.tenv.named("Bool")); - copyIf.push_back(cond); - copyIf.push_back(resp_simplify(cenv, body)); + + // If constructor has no variables, append body and continue + // (don't generate pointless fn) + if (texp->list_len() == 2) { + copyIf.push_back(body); + continue; + } + + // Build fn for the body of this case + const ASymbol* osym = cenv.penv.gensym("__obj"); + const ATuple* prot = new ATuple(osym, 0, Cursor()); + const ATuple* protT = new ATuple(texp->rst(), 0, Cursor()); + + List fn(Cursor(), cenv.penv.sym("fn"), prot, 0); + int idx = 0; + ATuple::const_iterator ti = texp->iter_at(2); + for (ATuple::const_iterator j = pat->iter_at(1); j != pat->end(); ++j, ++ti, ++idx) { + const AST* index = new ALiteral(T_INT32, idx, Cursor()); + const AST* dot = tup(Cursor(), cenv.penv.sym("."), osym, index, 0); + const AST* def = tup(Cursor(), cenv.penv.sym("def"), *j, dot); + fn.push_back(def); + } + + fn.push_back(resp_simplify(cenv, body)); + + List fnT(Cursor(), cenv.tenv.Fn, protT, cenv.type(match), 0); + assert(fnT.head->list_ref(1)); + cenv.setType(fn, fnT); + + const ATuple* cast = tup(Cursor(), cenv.penv.sym("cast"), obj, 0); + cenv.setType(cast, texp->rst()); + + List call(Cursor(), fn, cast, 0); + cenv.setTypeSameAs(call, match); + copyIf.push_back(call); } copyIf.push_back(cenv.penv.sym("__unreachable")); cenv.setTypeSameAs(copyIf, match); diff --git a/test.sh b/test.sh index fa01f80..fa66812 100755 --- a/test.sh +++ b/test.sh @@ -33,6 +33,6 @@ run './test/let-over-fn.resp' '2 : Int' run './test/let.resp' '5 : Int' # Algebraic data types -run './test/match.resp' '"Hello, rectangle!" : String' +run './test/match.resp' '12.0000 : Float' #run './test/poly.resp' '#t : Bool' diff --git a/test/match.resp b/test/match.resp index b50257a..0d490ec 100644 --- a/test/match.resp +++ b/test/match.resp @@ -1,14 +1,10 @@ (def-type (Shape) (Circle Float) - (Square Float) (Rectangle Float Float)) -(def c1 (Circle 1.0)) -(def c2 (Circle 2.0)) -(def s1 (Square 1.0)) -(def r1 (Rectangle 1.0 1.0)) +(def s1 (Circle 2.0)) +(def s2 (Rectangle 3.0 4.0)) -(match r1 - (Circle r) "Hello, circle!" - (Square w) "Hello, square!" - (Rectangle w h) "Hello, rectangle!") +(match s2 + (Rectangle w h) (* w h) + (Circle r) (* 3.14159 r)) -- cgit v1.2.1