From ceebe5e8bd7593b99d2a4c8b8fa733a85e0eae81 Mon Sep 17 00:00:00 2001
From: David Robillard <d@drobilla.net>
Date: Sun, 9 Jan 2011 16:47:59 +0000
Subject: Partially resurrect CPS translation pass.

git-svn-id: http://svn.drobilla.net/resp/trunk@405 ad02d1e2-f140-0410-9f75-f8b11f17cedd
---
 src/cps.cpp  | 205 ++++++++++++++++++++++++++++++++++-------------------------
 src/repl.cpp |  70 ++++++++++----------
 src/resp.cpp |   4 +-
 src/resp.hpp |   1 +
 4 files changed, 157 insertions(+), 123 deletions(-)

(limited to 'src')

diff --git a/src/cps.cpp b/src/cps.cpp
index 88ab425..c55f728 100644
--- a/src/cps.cpp
+++ b/src/cps.cpp
@@ -25,123 +25,156 @@
 
 #include "resp.hpp"
 
-/** (cps x cont) => (cont x) */
-static const AST*
-cps_value(TEnv& tenv, AST* cont) const
+static bool
+is_value(CEnv& cenv, const AST* exp)
 {
-	return tup(loc, cont, this, 0);
+	const ATuple* const call = exp->to_tuple();
+	if (!call)
+		return true; // Atom
+
+	if (!is_primitive(cenv.penv, exp))
+		return false; // Non-primitive fn call
+	
+	for (ATuple::const_iterator i = call->iter_at(1); i != call->end(); ++i)
+		if (!is_value(cenv, *i))
+			return false; // Primitive with non-value argument
+
+	return true; // Primitive with all value arguments
 }
 
-/** (cps (fn (a ...) body) cont) => (cont (fn (a ... k) (cps body k)) */
+/** [v]k => (k v) */
 static const AST*
-cps_fn(TEnv& tenv, AST* cont) const
+cps_value(CEnv& cenv, const AST* v, const AST* k)
 {
-	ATuple*  copyProt = new ATuple(*prot());
-	ASymbol* contArg  = tenv.penv.gensym("_k");
-	copyProt->push_back(contArg);
-	AFn* copy = tup(loc, tenv.penv.sym("fn"), copyProt, 0);
-	const_iterator p = begin();
-	++(++p);
-	for (; p != end(); ++p)
-		copy->push_back((*p)->(tenv, contArg));
-	return tup(loc, cont, copy, 0);
+	return tup(v->loc, k, v, 0);
 }
 
+/** [(fn (a ...) r)]k => (k (fn (a ... k2) [r]k2)) */
 static const AST*
-cps_primitive(TEnv& tenv, AST* cont) const
+cps_fn(CEnv& cenv, const ATuple* fn, const AST* cont)
 {
-	return value() ? tup(loc, cont, this, 0) : ATuple::(tenv, cont);
-}
+	const ASymbol* k2 = cenv.penv.gensym("__k");
 
-/** (cps (f a b ...)) => (a (fn (x) (b (fn (y) ... (cont (f x y ...)) */
-static const AST*
-cps_tuple(TEnv& tenv, AST* cont) const
-{
-	std::vector< std::pair<AFn*, AST*> > funcs;
-	AFn*     fn  = NULL;
-	ASymbol* arg = NULL;
-
-	// Make a continuation for each element (operator and arguments)
-	// Argument evaluation continuations are not themselves in CPS.
-	// Each makes a tail call to the next, and the last makes a tail
-	// call to the continuation of this call
-	const_iterator firstFnIter = end();
-	AFn*           firstFn     = NULL;
-	ssize_t        index       = 0;
-	FOREACHP(const_iterator, i, this) {
-		if (!(*i)->to_tuple()) {
-			funcs.push_back(make_pair((AFn*)NULL, (*i)));
-		} else {
-			arg = tenv.penv.gensym("a");
+	List copyProt;
+	FOREACHP(ATuple::const_iterator, i, fn->prot())
+		copyProt.push_back(*i);
+	copyProt.push_back(k2);
 
-			AFn* thisFn = tup(loc, tenv.penv.sym("fn"),
-					tup((*i)->loc, arg, 0),
-					0);
+	assert(fn->fst());
+	assert(copyProt.head);
+	List copy;
+	copy.push_back(cenv.penv.sym("fn"));
+	copy.push_back(copyProt);
 
-			if (firstFnIter == end()) {
-				firstFnIter = i;
-				firstFn = thisFn;
-			}
+	for (ATuple::const_iterator i = fn->iter_at(2); i != fn->end(); ++i)
+		copy.push_back(resp_cps(cenv, *i, k2));
 
-			if (fn)
-				fn->push_back((*i)->(tenv, thisFn));
+	return copy;
+}
 
-			funcs.push_back(make_pair(thisFn, arg));
-			fn = thisFn;
+/** [(f a b ...)]k => [a](fn (__a) [b](fn (__b) ... (f __a __b ... k))) */
+static const AST*
+cps_call(CEnv& cenv, const ATuple* call, const AST* k)
+{
+	// Build innermost application first
+	List body;
+	typedef std::vector<const AST*> ExpVec;
+	ExpVec exprs;
+	ExpVec args;
+	FOREACHP(ATuple::const_iterator, i, call) {
+		exprs.push_back(*i);
+		if (is_value(cenv, *i)) {
+			body.push_back(*i);
+			args.push_back(*i);
+		} else {
+			const ASymbol* sym = cenv.penv.gensym("__a");
+			body.push_back(sym);
+			args.push_back(sym);
 		}
-		++index;
 	}
 
-	if (firstFnIter != end()) {
-		// Call this call's callee in the last argument evaluator
-		ATuple* call = tup(loc, 0);
-		assert(funcs.size() == size());
-		for (size_t i = 0; i < funcs.size(); ++i)
-			call->push_back(funcs[i].second);
-
-		assert(fn);
-		fn->push_back(call->(tenv, cont));
-		return (*firstFnIter)->(tenv, firstFn);
+	const AST* cont;
+	if (cenv.penv.primitives.find(call->fst()->str()) != cenv.penv.primitives.end()) {
+		cont = tup(Cursor(), k, body.head, 0);
 	} else {
-		assert(fst()->value());
-		ATuple* ret = tup(loc, 0);
-		FOREACHP(const_iterator, i, this)
-			ret->push_back((*i));
-		if (!is_primitive(this))
-			ret->push_back(cont);
-		return ret;
+		body.push_back(k);
+		cont = body;
 	}
+
+	// Wrap application in fns to evaluate parameters (from right to left)
+	std::vector<const AST*>::const_reverse_iterator a = args.rbegin();
+	for (ExpVec::const_reverse_iterator e = exprs.rbegin(); e != exprs.rend(); ++e, ++a) {
+		if (!is_value(cenv, *e)) {
+			cont = resp_cps(cenv, *e, tup(Cursor(), cenv.penv.sym("fn"),
+			                              tup(Cursor(), *a, 0),
+			                              cont,
+			                              0));
+		}
+	}
+
+	return cont;
 }
 
-/** (cps (def x y)) => (y (fn (x) (cont))) */
+/** [(def x y)]k => (def x [y]k) */
 static const AST*
-cps_def(TEnv& tenv, AST* cont) const
+cps_def(CEnv& cenv, const ATuple* def, const AST* k)
 {
+	List copy(def->loc, def->fst(), def->frst(), 0);
+	copy.push_back(resp_cps(cenv, def->list_ref(2), k));
+	return copy;
+	/*
 	AST*    val     = body()->(tenv, cont);
 	ATuple* valCall = val->to_tuple();
 	ATuple::iterator i = valCall->begin();
 	return tup(loc, tenv.penv.sym("def"), sym(), *++i, 0);
+	*/
 }
 
-/** (cps (if c t ... e)) => */
+/** [(if c t e)]k => [c](fn (__c) (if c [t]k [e]k)) */
 static const AST*
-cps_iff(TEnv& tenv, AST* cont) const
+cps_if(CEnv& cenv, const ATuple* aif, const AST* k)
 {
-	ASymbol* argSym = tenv.penv.gensym("c");
-	const_iterator i = begin();
-	AST* cond = *++i;
-	AST* exp  = *++i;
-	AST* next = *++i;
-	if (cond->value()) {
-		return tup(loc, tenv.penv.sym("if"), cond,
-			exp->(tenv, cont),
-			next->(tenv, cont), 0);
+	ATuple::const_iterator i = aif->begin();
+	const AST* const c = *++i;
+	const AST* const t = *++i;
+	const AST* const e = *++i;
+	if (is_value(cenv, c)) {
+		return tup(aif->loc, cenv.penv.sym("if"), c,
+		           resp_cps(cenv, t, k),
+		           resp_cps(cenv, e, k), 0);
 	} else {
-		AFn* contFn = tup(loc, tenv.penv.sym("fn"),
-				tup(cond->loc, argSym, tenv.penv.gensym("_k"), 0),
-				tup(loc, tenv.penv.sym("if"), argSym,
-					exp->(tenv, cont),
-					next->(tenv, cont), 0));
-		return cond->(tenv, contFn);
+		/*
+		  const ASymbol* const condSym = cenv.penv.gensym("c");
+		  const ATuple* contFn = tup(loc, tenv.penv.sym("fn"),
+		  tup(cond->loc, argSym, tenv.penv.gensym("_k"), 0),
+		  tup(loc, tenv.penv.sym("if"), argSym,
+		  exp->(tenv, cont),
+		  next->(tenv, cont), 0));
+		  return cond->(tenv, contFn);
+		*/
+		return aif;
 	}
 }
+
+const AST*
+resp_cps(CEnv& cenv, const AST* ast, const AST* k) throw()
+{
+	if (is_value(cenv, ast))
+		return cps_value(cenv, ast, k);
+
+	const ATuple* const call = ast->to_tuple();
+	if (call) {
+		const ASymbol* const sym  = call->fst()->to_symbol();
+		const std::string    form = sym ? sym->sym() : "";
+		if (form == "def")
+			return cps_def(cenv, call, k);
+		else if (form == "fn")
+			return cps_fn(cenv, call, k);
+		else if (form == "if")
+			return cps_if(cenv, call, k);
+		else
+			return cps_call(cenv, call, k);
+	}
+
+	return cps_value(cenv, ast, k);
+}
diff --git a/src/repl.cpp b/src/repl.cpp
index b70618e..c461526 100644
--- a/src/repl.cpp
+++ b/src/repl.cpp
@@ -72,6 +72,14 @@ callPrintCollect(CEnv& cenv, CFunc f, const AST* result, const AST* resultT, boo
 	Object::pool.collect(Object::pool.roots());
 }
 
+static inline int
+dump(CEnv& cenv, const Code& code)
+{
+	for (Code::const_iterator i = code.begin(); i != code.end(); ++i)
+		pprint(cout, *i, &cenv, (cenv.args.find("-a") != cenv.args.end()));
+	return 0;
+}
+
 /// Compile and evaluate code from @a is
 int
 eval(CEnv& cenv, Cursor& cursor, istream& is, bool execute)
@@ -79,51 +87,41 @@ eval(CEnv& cenv, Cursor& cursor, istream& is, bool execute)
 	const AST* exp = NULL;
 	const AST* ast = NULL;
 
-	typedef list<const AST*> Parsed;
-	Parsed parsed;
-
 	try {
+		// Parse and type all expressions
+		Code parsed;
 		while (readParseType(cenv, cursor, is, exp, ast))
 			parsed.push_back(ast);
-
-		if (cenv.args.find("-T") != cenv.args.end()) {
-			for (Parsed::const_iterator i = parsed.begin(); i != parsed.end(); ++i)
-				pprint(cout, *i, &cenv, (cenv.args.find("-a") != cenv.args.end()));
-			return 0;
-		}
+		if (cenv.args.find("-T") != cenv.args.end())
+			return dump(cenv, parsed);
 
 		// Simplify all expressions
 		Code simplified;
-		for (Parsed::const_iterator i = parsed.begin(); i != parsed.end(); ++i) {
-			const AST* l = resp_simplify(cenv, *i);
-			if (l)
-				simplified.push_back(l);
-		}
-
-		if (cenv.args.find("-R") != cenv.args.end()) {
-			for (Code::const_iterator i = simplified.begin(); i != simplified.end(); ++i)
-				pprint(cout, *i, &cenv, (cenv.args.find("-a") != cenv.args.end()));
-			return 0;
-		}
-
-		CVal  val = NULL;
-		CFunc f   = NULL;
-
+		for (Code::const_iterator i = parsed.begin(); i != parsed.end(); ++i)
+			if ((exp = resp_simplify(cenv, *i)))
+				simplified.push_back(exp);
+		if (cenv.args.find("-R") != cenv.args.end())
+			return dump(cenv, simplified);
+
+		// Convert to CPS
+		Code cps;
+		for (Code::const_iterator i = simplified.begin(); i != simplified.end(); ++i)
+			if ((exp = resp_cps(cenv, *i, cenv.penv.sym("display"))))
+				cps.push_back(exp);
+		if (cenv.args.find("-C") != cenv.args.end())
+			return dump(cenv, cps);
+		
 		// Lift all expressions
 		Code lifted;
-		for (Parsed::const_iterator i = simplified.begin(); i != simplified.end(); ++i) {
-			const AST* l = resp_lift(cenv, lifted, *i);
-			if (l)
-				lifted.push_back(l);
-		}
-
-		if (cenv.args.find("-L") != cenv.args.end()) {
-			for (Code::const_iterator i = lifted.begin(); i != lifted.end(); ++i)
-				pprint(cout, *i, &cenv, (cenv.args.find("-a") != cenv.args.end()));
-			return 0;
-		}
-
+		for (Code::const_iterator i = simplified.begin(); i != simplified.end(); ++i)
+			if ((exp = resp_lift(cenv, lifted, *i)))
+				lifted.push_back(exp);
+		if (cenv.args.find("-L") != cenv.args.end())
+			return dump(cenv, lifted);
+		
 		// Compile top-level (lifted) functions
+		CVal  val = NULL;
+		CFunc f   = NULL;
 		Code exprs;
 		for (Code::const_iterator i = lifted.begin(); i != lifted.end(); ++i) {
 			const ATuple* call = (*i)->to_tuple();
diff --git a/src/resp.cpp b/src/resp.cpp
index 2c26927..2b08fc3 100644
--- a/src/resp.cpp
+++ b/src/resp.cpp
@@ -78,6 +78,7 @@ print_usage(char* name, bool error)
 	os << "  -P               Parse only"                                << endl;
 	os << "  -T               Type check only"                           << endl;
 	os << "  -R               Reduce to simpler forms only"              << endl;
+	os << "  -C               Convert to CPS only"                       << endl;
 	os << "  -L               Lambda lift only"                          << endl;
 	os << "  -S               Compile to assembly only (do not execute)" << endl;
 
@@ -96,7 +97,8 @@ main(int argc, char** argv)
 			return print_usage(argv[0], false);
 		} else if (argv[i][0] != '-') {
 			files.push_back(argv[i]);
-		} else if (!strncmp(argv[i], "-L", 3)
+		} else if (!strncmp(argv[i], "-C", 3)
+		           || !strncmp(argv[i], "-L", 3)
 		           || !strncmp(argv[i], "-P", 3)
 		           || !strncmp(argv[i], "-R", 3)
 		           || !strncmp(argv[i], "-S", 3)
diff --git a/src/resp.hpp b/src/resp.hpp
index 7de26d8..41f1826 100644
--- a/src/resp.hpp
+++ b/src/resp.hpp
@@ -834,6 +834,7 @@ int  repl(CEnv& cenv);
 
 void       resp_constrain(TEnv& tenv, Constraints& c, const AST* ast) throw(Error);
 const AST* resp_simplify(CEnv& cenv, const AST* ast) throw();
+const AST* resp_cps(CEnv& cenv, const AST* ast, const AST* k) throw();
 const AST* resp_lift(CEnv& cenv, Code& code, const AST* ast) throw();
 CVal       resp_compile(CEnv& cenv, const AST* ast) throw();
 
-- 
cgit v1.2.1