From 745430d5e71cf268f7fa5111dec03e73e602714f Mon Sep 17 00:00:00 2001
From: David Robillard <d@drobilla.net>
Date: Wed, 28 Jan 2009 03:39:03 +0000
Subject: Compilation of recursive function definitions.

git-svn-id: http://svn.drobilla.net/resp/llvm-lisp@27 ad02d1e2-f140-0410-9f75-f8b11f17cedd
---
 ll.cpp | 86 ++++++++++++++++++++++++++++++++++++++++--------------------------
 1 file changed, 52 insertions(+), 34 deletions(-)

(limited to 'll.cpp')

diff --git a/ll.cpp b/ll.cpp
index 368a3cd..114bab8 100644
--- a/ll.cpp
+++ b/ll.cpp
@@ -271,6 +271,7 @@ struct ASTCall : public ASTTuple {
 struct ASTDefinition : public ASTCall {
 	ASTDefinition(const TupV& t) : ASTCall(t) {}
 	void   constrain(TEnv& tenv) const;
+	void   lift(CEnv& cenv);
 	Value* compile(CEnv& cenv);
 };
 
@@ -378,7 +379,7 @@ parseFn(PEnv& penv, const SExp::List& c, UD)
 
 
 /***************************************************************************
- * Lexical Environment                                                     *
+ * Generic Lexical Environment                                             *
  ***************************************************************************/
 
 template<typename K, typename V>
@@ -386,10 +387,10 @@ struct Env : public list< map<K,V> > {
 	typedef map<K,V> Frame;
 	Env() : list<Frame>(1) {}
 	void push_front() { list<Frame>::push_front(Frame()); }
-	void def(const K& k, const V& v) {
+	const V& def(const K& k, const V& v) {
 		if (this->front().find(k) != this->front().end())
 			throw SyntaxError("Redefinition");
-		this->front()[k] = v;
+		return (this->front()[k] = v);
 	}
 	V* ref(const K& name) {
 		typename Frame::iterator s;
@@ -433,8 +434,8 @@ struct TEnv {
 	void constrain(const AST* o, AType* t) {
 		constraints.push_back(make_pair(type(o), t));
 	}
-	void           solve() { apply(unify(constraints)); }
-	void           apply(const TSubst& substs);
+	void          solve() { apply(unify(constraints)); }
+	void          apply(const TSubst& substs);
 	static TSubst unify(const Constraints& c);
 	PEnv&       penv;
 	Types       types;
@@ -481,6 +482,10 @@ ASTCall::constrain(TEnv& tenv) const
 void
 ASTDefinition::constrain(TEnv& tenv) const
 {
+	if (tup.size() != 3)
+		throw SyntaxError("\"def\" not passed 2 arguments");
+	if (!dynamic_cast<const ASTSymbol*>(tup[1]))
+		throw SyntaxError("\"def\" name is not a symbol");
 	FOREACH(TupV::const_iterator, p, tup)
 		(*p)->constrain(tenv);
 	AType* tvar = tenv.type(this);
@@ -631,8 +636,18 @@ struct CEnv {
 	string gensym(const char* base="_") {
 		ostringstream s; s << base << symID++; return s.str();
 	}
-	typedef Env<const AST*, AST*>         Code;
-	typedef Env<const ASTSymbol*, Value*> Vals;
+	void push() { code.push_front(); vals.push_front(); }
+	void pop()  { code.pop_front();  vals.pop_front(); }
+	Value* compile(AST* obj) {
+		Value** v = vals.ref(obj);
+		return (v) ? *v : vals.def(obj, obj->compile(*this));
+	}
+	void precompile(AST* obj, Value* value) {
+		assert(!vals.ref(obj));
+		vals.def(obj, value);
+	}
+	typedef Env<const AST*, AST*>   Code;
+	typedef Env<const AST*, Value*> Vals;
 	PEnv&                  penv;
 	TEnv                   tenv;
 	IRBuilder<>            builder;
@@ -693,7 +708,7 @@ ASTSymbol::compile(CEnv& cenv)
 
 	AST** c = cenv.code.ref(this);
 	if (c) {
-		Value* v = (*c)->compile(cenv);
+		Value* v = cenv.compile(*c);
 		cenv.vals.def(this, v);
 		return v;
 	}
@@ -701,17 +716,17 @@ ASTSymbol::compile(CEnv& cenv)
 	throw SyntaxError((string("Undefined symbol '") + cppstr + "'").c_str());
 }
 
+void
+ASTDefinition::lift(CEnv& cenv)
+{
+	cenv.code.def((ASTSymbol*)tup[1], tup[2]); // Define first for recursion
+	tup[2]->lift(cenv);
+}
+
 Value*
 ASTDefinition::compile(CEnv& cenv)
 {
-	if (tup.size() != 3) throw SyntaxError("\"def\" takes exactly 2 arguments");
-	const ASTSymbol* sym = dynamic_cast<const ASTSymbol*>(tup[1]);
-	if (!sym) throw SyntaxError("Definition name is not a symbol");
-
-	Value* val = tup[2]->compile(cenv);
-	cenv.code.def(sym, tup[2]);
-	cenv.vals.def(sym, val);
-	return val;
+	return cenv.compile(tup[2]);
 }
 
 void
@@ -730,7 +745,7 @@ ASTCall::lift(CEnv& cenv)
 	if (!c) return;
 
 	// Extend environment with bound and typed parameters
-	cenv.code.push_front();
+	cenv.push();
 	if (c->prot->tup.size() != tup.size() - 1)
 		throw CompileError("Call to closure with mismatched arguments");
 
@@ -738,7 +753,7 @@ ASTCall::lift(CEnv& cenv)
 		cenv.code.def(c->prot->tup[i-1], tup[i]);
 
 	tup[0]->lift(cenv); // Lift called closure
-	cenv.code.pop_front(); // Restore environment
+	cenv.pop(); // Restore environment
 }
 
 Value*
@@ -749,16 +764,16 @@ ASTCall::compile(CEnv& cenv)
 		AST** val = cenv.code.ref(tup[0]);
 		c = (val) ? dynamic_cast<ASTClosure*>(*val) : c;
 	}
-	
+
 	if (!c) throw CompileError("Call to non-closure");
-	Value* v = c->compile(cenv);
+	Value* v = cenv.compile(c);
 	if (!v) throw CompileError("Callee failed to compile");
-	Function* f = dynamic_cast<Function*>(c->compile(cenv));
+	Function* f = dynamic_cast<Function*>(cenv.compile(c));
 	if (!f) throw CompileError("Callee compiled to non-function");
 
 	vector<Value*> params;
 	for (size_t i = 1; i < tup.size(); ++i)
-		params.push_back(tup[i]->compile(cenv));
+		params.push_back(cenv.compile(tup[i]));
 
 	return cenv.builder.CreateCall(f, params.begin(), params.end(), "calltmp");
 }
@@ -766,7 +781,7 @@ ASTCall::compile(CEnv& cenv)
 Value*
 ASTIf::compile(CEnv& cenv)
 {
-	Value*    condV  = tup[1]->compile(cenv);
+	Value*    condV  = cenv.compile(tup[1]);
 	Function* parent = cenv.builder.GetInsertBlock()->getParent();
 
 	// Create blocks for the then and else cases.
@@ -779,14 +794,14 @@ ASTIf::compile(CEnv& cenv)
 
 	// Emit then block
 	cenv.builder.SetInsertPoint(thenBB);
-	Value* thenV = tup[2]->compile(cenv); // Can change current block, so...
+	Value* thenV = cenv.compile(tup[2]); // Can change current block, so...
 	cenv.builder.CreateBr(mergeBB);
 	thenBB = cenv.builder.GetInsertBlock(); // ... update thenBB afterwards
 
 	// Emit else block
 	parent->getBasicBlockList().push_back(elseBB);
 	cenv.builder.SetInsertPoint(elseBB);
-	Value* elseV = tup[3]->compile(cenv); // Can change current block, so...
+	Value* elseV = cenv.compile(tup[3]); // Can change current block, so...
 	cenv.builder.CreateBr(mergeBB);
 	elseBB = cenv.builder.GetInsertBlock(); // ... update elseBB afterwards
 
@@ -810,7 +825,7 @@ ASTClosure::lift(CEnv& cenv)
 			return;
 
 	assert(!func);
-	cenv.code.push_front();
+	cenv.push();
 
 	// Write function declaration
 	Function* f = compileFunction(cenv, cenv.gensym("_fn"), *prot, cenv.tenv.type(tup[2])->ctype);
@@ -825,7 +840,8 @@ ASTClosure::lift(CEnv& cenv)
 
 	// Write function body
 	try {
-		Value* retVal = tup[2]->compile(cenv);
+		cenv.precompile(this, f); // Define our value first for recursion
+		Value* retVal = cenv.compile(tup[2]);
 		cenv.builder.CreateRet(retVal); // Finish function
 		verifyFunction(*f); // Validate generated code
 		cenv.fpm.run(*f); // Optimize function
@@ -835,12 +851,14 @@ ASTClosure::lift(CEnv& cenv)
 		throw e;
 	}
 
-	cenv.code.pop_front();
+	assert(func);
+	cenv.pop();
 }
 
 Value*
 ASTClosure::compile(CEnv& cenv)
 {
+	assert(func);
 	return func; // Function was already compiled in the lifting pass
 }
 
@@ -848,16 +866,16 @@ Value*
 ASTPrimitive::compile(CEnv& cenv)
 {
 	if (tup.size() < 3) throw SyntaxError("Too few arguments");
-	Value* a = tup[1]->compile(cenv);
-	Value* b = tup[2]->compile(cenv);
+	Value* a = cenv.compile(tup[1]);
+	Value* b = cenv.compile(tup[2]);
 
 	if (OP_IS_A(op, Instruction::BinaryOps)) {
 		const Instruction::BinaryOps bo = (Instruction::BinaryOps)op;
 		if (tup.size() == 2)
-			return tup[1]->compile(cenv);
+			return cenv.compile(tup[1]);
 		Value* val = cenv.builder.CreateBinOp(bo, a, b);
 		for (size_t i = 3; i < tup.size(); ++i)
-			val = cenv.builder.CreateBinOp(bo, val, tup[i]->compile(cenv));
+			val = cenv.builder.CreateBinOp(bo, val, cenv.compile(tup[i]));
 		return val;
 	} else if (op == Instruction::ICmp) {
 		bool isInt = cenv.tenv.type(tup[1])->str() == "(Int)";
@@ -943,7 +961,7 @@ main()
 				BasicBlock* bb = BasicBlock::Create("entry", f);
 				cenv.builder.SetInsertPoint(bb);
 				try {
-					Value* retVal = body->compile(cenv);
+					Value* retVal = cenv.compile(body);
 					cenv.builder.CreateRet(retVal); // Finish function
 					verifyFunction(*f); // Validate generated code
 					cenv.fpm.run(*f); // Optimize function
@@ -959,7 +977,7 @@ main()
 				else if (bodyT->ctype == Type::Int1Ty)
 					std::cout << ";  " << ((bool (*)())fp)();
 			} else {
-				Value* val = body->compile(cenv);
+				Value* val = cenv.compile(body);
 				std::cout << ";   " << val;
 			}
 			std::cout << " : " << cenv.tenv.type(body)->str() << endl;
-- 
cgit v1.2.1