aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDavid Robillard <d@drobilla.net>2010-12-29 22:50:20 +0000
committerDavid Robillard <d@drobilla.net>2010-12-29 22:50:20 +0000
commit2eda11ca28589991471ff3251cccc2471424770e (patch)
treea763edd57dcdcabba3f5eccf5af23f52d90b1eca
parent0076c050fb12c92a35b673d63fca82d5cff63bdb (diff)
downloadresp-2eda11ca28589991471ff3251cccc2471424770e.tar.gz
resp-2eda11ca28589991471ff3251cccc2471424770e.tar.bz2
resp-2eda11ca28589991471ff3251cccc2471424770e.zip
Destructuring (i.e. working `match').
git-svn-id: http://svn.drobilla.net/resp/resp@374 ad02d1e2-f140-0410-9f75-f8b11f17cedd
-rw-r--r--src/c.cpp7
-rw-r--r--src/compile.cpp25
-rw-r--r--src/constrain.cpp12
-rw-r--r--src/lift.cpp2
-rw-r--r--src/llvm.cpp7
-rw-r--r--src/resp.hpp5
-rw-r--r--src/simplify.cpp48
-rwxr-xr-xtest.sh2
-rw-r--r--test/match.resp14
9 files changed, 100 insertions, 22 deletions
diff --git a/src/c.cpp b/src/c.cpp
index 7d5b125..ff4f09a 100644
--- a/src/c.cpp
+++ b/src/c.cpp
@@ -47,6 +47,7 @@ struct CEngine : public Engine {
void eraseFn(CEnv& cenv, CFunc f);
CVal compileCall(CEnv& cenv, CFunc f, const ATuple* funcT, const vector<CVal>& args);
+ CVal compileCast(CEnv& cenv, CVal v, const AST* t);
CVal compileCons(CEnv& cenv, const ATuple* type, CVal rtti, const vector<CVal>& fields);
CVal compileDot(CEnv& cenv, CVal tup, int32_t index);
CVal compileGlobalSet(CEnv& cenv, const string& s, CVal v, const AST* t);
@@ -139,6 +140,12 @@ CEngine::compileCall(CEnv& cenv, CFunc func, const ATuple* funcT, const vector<C
}
CVal
+CEngine::compileCast(CEnv& cenv, CVal v, const AST* t)
+{
+ return v;
+}
+
+CVal
CEngine::compileCons(CEnv& cenv, const ATuple* type, CVal rtti, const vector<CVal>& fields)
{
return NULL;
diff --git a/src/compile.cpp b/src/compile.cpp
index d2f0530..90e31ba 100644
--- a/src/compile.cpp
+++ b/src/compile.cpp
@@ -51,6 +51,14 @@ compile_literal_symbol(CEnv& cenv, const ASymbol* sym) throw()
}
static CVal
+compile_cast(CEnv& cenv, const ATuple* cast) throw()
+{
+ return cenv.engine()->compileCast(cenv,
+ resp_compile(cenv, cast->frst()),
+ cenv.type(cast));
+}
+
+static CVal
compile_type(CEnv& cenv, const AST* type) throw()
{
return compile_literal_symbol(cenv, type->as_tuple()->fst()->as_symbol());
@@ -78,6 +86,7 @@ compile_dot(CEnv& cenv, const ATuple* dot) throw()
const ALiteral<int32_t>* index = (ALiteral<int32_t>*)(*++i);
assert(index->tag() == T_INT32);
CVal tupVal = resp_compile(cenv, tup);
+ assert((unsigned)index->val < cenv.type(dot->frst())->as_tuple()->list_len());
return cenv.engine()->compileDot(cenv, tupVal, index->val);
}
@@ -175,14 +184,24 @@ compile_quote(CEnv& cenv, const ATuple* quote) throw()
static CVal
compile_call(CEnv& cenv, const ATuple* call) throw()
{
+ const ATuple* protT = cenv.type(call->fst())->as_tuple()->prot();
CFunc f = resp_compile(cenv, call->fst());
if (!f)
f = cenv.currentFn; // Recursive call (callee defined as a stub)
vector<CVal> args;
- for (ATuple::const_iterator e = call->iter_at(1); e != call->end(); ++e)
- args.push_back(resp_compile(cenv, *e));
+ ATuple::const_iterator t = protT->iter_at(0);
+ for (ATuple::const_iterator e = call->iter_at(1); e != call->end(); ++e, ++t) {
+ CVal arg = resp_compile(cenv, *e);
+ if ((*e)->to_symbol()) {
+ if (cenv.type(*e) != cenv.type(*t)) {
+ args.push_back(cenv.engine()->compileCast(cenv, arg, *t));
+ continue;
+ }
+ }
+ args.push_back(arg);
+ }
return cenv.engine()->compileCall(cenv, f, cenv.type(call->fst())->as_tuple(), args);
}
@@ -211,6 +230,8 @@ resp_compile(CEnv& cenv, const AST* ast) throw()
const std::string form = sym ? sym->sym() : "";
if (is_primitive(cenv.penv, call))
return cenv.engine()->compilePrimitive(cenv, ast->as_tuple());
+ else if (form == "cast")
+ return compile_cast(cenv, call);
else if (form == "cons" || isupper(form[0]))
return compile_cons(cenv, call);
else if (form == ".")
diff --git a/src/constrain.cpp b/src/constrain.cpp
index a4a1ad3..78accf6 100644
--- a/src/constrain.cpp
+++ b/src/constrain.cpp
@@ -241,6 +241,8 @@ constrain_match(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error)
THROW_IF(!pattern, exp->loc, "pattern expression expected");
const ASymbol* name = (*pattern->begin())->to_symbol();
THROW_IF(!name, (*pattern->begin())->loc, "pattern does not start with a symbol");
+ THROW_IF(!tenv.ref(name), name->loc,
+ (format("undefined constructor `%1%'") % name->sym()).str());
const AST* consT = *tenv.ref(name);
@@ -251,8 +253,16 @@ constrain_match(TEnv& tenv, Constraints& c, const ATuple* call) throw(Error)
THROW_IF(i == call->end(), pattern->loc, "missing pattern body");
const AST* body = *i++;
+
+ TEnv::Frame frame;
+ ATuple::const_iterator ti = consT->as_tuple()->iter_at(2);
+ for (ATuple::const_iterator pi = pattern->iter_at(1); pi != pattern->end(); ++pi)
+ frame.push_back(make_pair((*pi)->as_symbol()->sym(), *ti++));
+
+ tenv.push(frame);
resp_constrain(tenv, c, body);
- c.constrain(tenv, body, retT);
+ c.constrain(tenv, body, retT);
+ tenv.pop();
}
c.constrain(tenv, call, retT);
c.constrain(tenv, matchee, matcheeT);
diff --git a/src/lift.cpp b/src/lift.cpp
index 8a6c4b0..3e785b9 100644
--- a/src/lift.cpp
+++ b/src/lift.cpp
@@ -271,7 +271,7 @@ resp_lift(CEnv& cenv, Code& code, const AST* ast) throw()
return lift_fn(cenv, code, call);
else if (form == "if")
return lift_args(cenv, code, call);
- else if (form == "quote")
+ else if (form == "quote" || form == "cast")
return call;
else
return lift_call(cenv, code, call);
diff --git a/src/llvm.cpp b/src/llvm.cpp
index cc4bf47..28d6f36 100644
--- a/src/llvm.cpp
+++ b/src/llvm.cpp
@@ -62,6 +62,7 @@ struct LLVMEngine : public Engine {
void eraseFn(CEnv& cenv, CFunc f);
CVal compileCall(CEnv& cenv, CFunc f, const ATuple* funcT, const vector<CVal>& args);
+ CVal compileCast(CEnv& cenv, CVal v, const AST* t);
CVal compileCons(CEnv& cenv, const ATuple* type, CVal rtti, const vector<CVal>& fields);
CVal compileDot(CEnv& cenv, CVal tup, int32_t index);
CVal compileGlobalSet(CEnv& cenv, const string& s, CVal v, const AST* t);
@@ -207,6 +208,12 @@ LLVMEngine::compileCall(CEnv& cenv, CFunc f, const ATuple* funcT, const vector<C
}
CVal
+LLVMEngine::compileCast(CEnv& cenv, CVal v, const AST* t)
+{
+ return builder.CreateBitCast(llVal(v), llType(t), "cast");
+}
+
+CVal
LLVMEngine::compileCons(CEnv& cenv, const ATuple* type, CVal rtti, const vector<CVal>& fields)
{
// Find size of memory required
diff --git a/src/resp.hpp b/src/resp.hpp
index 57880f5..ba0ae4b 100644
--- a/src/resp.hpp
+++ b/src/resp.hpp
@@ -528,7 +528,8 @@ struct Subst : public list<Constraint> {
List out;
for (ATuple::const_iterator i = in->as_tuple()->begin(); i != in->as_tuple()->end(); ++i)
out.push_back(apply((*i)));
- out.head->loc = in->loc;
+ if (out.head)
+ out.head->loc = in->loc;
return out.head;
} else {
const_iterator i = find(in);
@@ -656,6 +657,7 @@ struct Engine {
virtual void eraseFn(CEnv& cenv, CFunc f) = 0;
virtual CVal compileCall(CEnv& cenv, CFunc f, const ATuple* fT, CVals& args) = 0;
+ virtual CVal compileCast(CEnv& cenv, CVal v, const AST* t) = 0;
virtual CVal compileCons(CEnv& cenv, const ATuple* t, CVal rtti, CVals& f) = 0;
virtual CVal compileDot(CEnv& cenv, CVal tup, int32_t index) = 0;
virtual CVal compileGlobalSet(CEnv& cenv, const string& s, CVal v, const AST* t) = 0;
@@ -726,7 +728,6 @@ struct CEnv {
return rec ? *rec : ast;
}
void setType(const AST* ast, const AST* type) {
- //assert(!ast->to_symbol());
const AST* tvar = tenv.var();
tenv.vars.insert(make_pair(ast, tvar));
tsubst.add(tvar, type);
diff --git a/src/simplify.cpp b/src/simplify.cpp
index c5b2566..d6188ad 100644
--- a/src/simplify.cpp
+++ b/src/simplify.cpp
@@ -56,12 +56,12 @@ simplify_if(CEnv& cenv, const ATuple* aif) throw()
static const AST*
simplify_match(CEnv& cenv, const ATuple* match) throw()
{
+ const AST* const obj = resp_simplify(cenv, match->list_ref(1));
+
// Dot expression to get tag. Note index is -1 to compensate for the lift phase
// which adds 1 to skip the RTTI, which we don't want here (FIXME: ick...)
- List tval;
- tval.push_back(cenv.penv.sym("."));
- tval.push_back(resp_simplify(cenv, match->list_ref(1)));
- tval.push_back(new ALiteral<int32_t>(T_INT32, -1, Cursor()));
+ const AST* index = new ALiteral<int32_t>(T_INT32, -1, Cursor());
+ List tval(Cursor(), cenv.penv.sym("."), obj, index, 0);
const ASymbol* tsym = cenv.penv.gensym("__tag");
@@ -78,11 +78,47 @@ simplify_match(CEnv& cenv, const ATuple* match) throw()
const_cast<ASymbol*>(consTag)->tag(T_LITSYM);
cenv.setType(consTag, cenv.tenv.named("Symbol"));
+ const ATuple* texp = cenv.tenv.named(consTag->sym())->as_tuple();
+
+ // Append condition for this case
List cond(Cursor(), cenv.penv.sym("="), tsym, consTag, 0);
cenv.setType(cond, cenv.tenv.named("Bool"));
-
copyIf.push_back(cond);
- copyIf.push_back(resp_simplify(cenv, body));
+
+ // If constructor has no variables, append body and continue
+ // (don't generate pointless fn)
+ if (texp->list_len() == 2) {
+ copyIf.push_back(body);
+ continue;
+ }
+
+ // Build fn for the body of this case
+ const ASymbol* osym = cenv.penv.gensym("__obj");
+ const ATuple* prot = new ATuple(osym, 0, Cursor());
+ const ATuple* protT = new ATuple(texp->rst(), 0, Cursor());
+
+ List fn(Cursor(), cenv.penv.sym("fn"), prot, 0);
+ int idx = 0;
+ ATuple::const_iterator ti = texp->iter_at(2);
+ for (ATuple::const_iterator j = pat->iter_at(1); j != pat->end(); ++j, ++ti, ++idx) {
+ const AST* index = new ALiteral<int32_t>(T_INT32, idx, Cursor());
+ const AST* dot = tup(Cursor(), cenv.penv.sym("."), osym, index, 0);
+ const AST* def = tup(Cursor(), cenv.penv.sym("def"), *j, dot);
+ fn.push_back(def);
+ }
+
+ fn.push_back(resp_simplify(cenv, body));
+
+ List fnT(Cursor(), cenv.tenv.Fn, protT, cenv.type(match), 0);
+ assert(fnT.head->list_ref(1));
+ cenv.setType(fn, fnT);
+
+ const ATuple* cast = tup(Cursor(), cenv.penv.sym("cast"), obj, 0);
+ cenv.setType(cast, texp->rst());
+
+ List call(Cursor(), fn, cast, 0);
+ cenv.setTypeSameAs(call, match);
+ copyIf.push_back(call);
}
copyIf.push_back(cenv.penv.sym("__unreachable"));
cenv.setTypeSameAs(copyIf, match);
diff --git a/test.sh b/test.sh
index fa01f80..fa66812 100755
--- a/test.sh
+++ b/test.sh
@@ -33,6 +33,6 @@ run './test/let-over-fn.resp' '2 : Int'
run './test/let.resp' '5 : Int'
# Algebraic data types
-run './test/match.resp' '"Hello, rectangle!" : String'
+run './test/match.resp' '12.0000 : Float'
#run './test/poly.resp' '#t : Bool'
diff --git a/test/match.resp b/test/match.resp
index b50257a..0d490ec 100644
--- a/test/match.resp
+++ b/test/match.resp
@@ -1,14 +1,10 @@
(def-type (Shape)
(Circle Float)
- (Square Float)
(Rectangle Float Float))
-(def c1 (Circle 1.0))
-(def c2 (Circle 2.0))
-(def s1 (Square 1.0))
-(def r1 (Rectangle 1.0 1.0))
+(def s1 (Circle 2.0))
+(def s2 (Rectangle 3.0 4.0))
-(match r1
- (Circle r) "Hello, circle!"
- (Square w) "Hello, square!"
- (Rectangle w h) "Hello, rectangle!")
+(match s2
+ (Rectangle w h) (* w h)
+ (Circle r) (* 3.14159 r))