aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--ll.cpp86
1 files changed, 52 insertions, 34 deletions
diff --git a/ll.cpp b/ll.cpp
index 368a3cd..114bab8 100644
--- a/ll.cpp
+++ b/ll.cpp
@@ -271,6 +271,7 @@ struct ASTCall : public ASTTuple {
struct ASTDefinition : public ASTCall {
ASTDefinition(const TupV& t) : ASTCall(t) {}
void constrain(TEnv& tenv) const;
+ void lift(CEnv& cenv);
Value* compile(CEnv& cenv);
};
@@ -378,7 +379,7 @@ parseFn(PEnv& penv, const SExp::List& c, UD)
/***************************************************************************
- * Lexical Environment *
+ * Generic Lexical Environment *
***************************************************************************/
template<typename K, typename V>
@@ -386,10 +387,10 @@ struct Env : public list< map<K,V> > {
typedef map<K,V> Frame;
Env() : list<Frame>(1) {}
void push_front() { list<Frame>::push_front(Frame()); }
- void def(const K& k, const V& v) {
+ const V& def(const K& k, const V& v) {
if (this->front().find(k) != this->front().end())
throw SyntaxError("Redefinition");
- this->front()[k] = v;
+ return (this->front()[k] = v);
}
V* ref(const K& name) {
typename Frame::iterator s;
@@ -433,8 +434,8 @@ struct TEnv {
void constrain(const AST* o, AType* t) {
constraints.push_back(make_pair(type(o), t));
}
- void solve() { apply(unify(constraints)); }
- void apply(const TSubst& substs);
+ void solve() { apply(unify(constraints)); }
+ void apply(const TSubst& substs);
static TSubst unify(const Constraints& c);
PEnv& penv;
Types types;
@@ -481,6 +482,10 @@ ASTCall::constrain(TEnv& tenv) const
void
ASTDefinition::constrain(TEnv& tenv) const
{
+ if (tup.size() != 3)
+ throw SyntaxError("\"def\" not passed 2 arguments");
+ if (!dynamic_cast<const ASTSymbol*>(tup[1]))
+ throw SyntaxError("\"def\" name is not a symbol");
FOREACH(TupV::const_iterator, p, tup)
(*p)->constrain(tenv);
AType* tvar = tenv.type(this);
@@ -631,8 +636,18 @@ struct CEnv {
string gensym(const char* base="_") {
ostringstream s; s << base << symID++; return s.str();
}
- typedef Env<const AST*, AST*> Code;
- typedef Env<const ASTSymbol*, Value*> Vals;
+ void push() { code.push_front(); vals.push_front(); }
+ void pop() { code.pop_front(); vals.pop_front(); }
+ Value* compile(AST* obj) {
+ Value** v = vals.ref(obj);
+ return (v) ? *v : vals.def(obj, obj->compile(*this));
+ }
+ void precompile(AST* obj, Value* value) {
+ assert(!vals.ref(obj));
+ vals.def(obj, value);
+ }
+ typedef Env<const AST*, AST*> Code;
+ typedef Env<const AST*, Value*> Vals;
PEnv& penv;
TEnv tenv;
IRBuilder<> builder;
@@ -693,7 +708,7 @@ ASTSymbol::compile(CEnv& cenv)
AST** c = cenv.code.ref(this);
if (c) {
- Value* v = (*c)->compile(cenv);
+ Value* v = cenv.compile(*c);
cenv.vals.def(this, v);
return v;
}
@@ -701,17 +716,17 @@ ASTSymbol::compile(CEnv& cenv)
throw SyntaxError((string("Undefined symbol '") + cppstr + "'").c_str());
}
+void
+ASTDefinition::lift(CEnv& cenv)
+{
+ cenv.code.def((ASTSymbol*)tup[1], tup[2]); // Define first for recursion
+ tup[2]->lift(cenv);
+}
+
Value*
ASTDefinition::compile(CEnv& cenv)
{
- if (tup.size() != 3) throw SyntaxError("\"def\" takes exactly 2 arguments");
- const ASTSymbol* sym = dynamic_cast<const ASTSymbol*>(tup[1]);
- if (!sym) throw SyntaxError("Definition name is not a symbol");
-
- Value* val = tup[2]->compile(cenv);
- cenv.code.def(sym, tup[2]);
- cenv.vals.def(sym, val);
- return val;
+ return cenv.compile(tup[2]);
}
void
@@ -730,7 +745,7 @@ ASTCall::lift(CEnv& cenv)
if (!c) return;
// Extend environment with bound and typed parameters
- cenv.code.push_front();
+ cenv.push();
if (c->prot->tup.size() != tup.size() - 1)
throw CompileError("Call to closure with mismatched arguments");
@@ -738,7 +753,7 @@ ASTCall::lift(CEnv& cenv)
cenv.code.def(c->prot->tup[i-1], tup[i]);
tup[0]->lift(cenv); // Lift called closure
- cenv.code.pop_front(); // Restore environment
+ cenv.pop(); // Restore environment
}
Value*
@@ -749,16 +764,16 @@ ASTCall::compile(CEnv& cenv)
AST** val = cenv.code.ref(tup[0]);
c = (val) ? dynamic_cast<ASTClosure*>(*val) : c;
}
-
+
if (!c) throw CompileError("Call to non-closure");
- Value* v = c->compile(cenv);
+ Value* v = cenv.compile(c);
if (!v) throw CompileError("Callee failed to compile");
- Function* f = dynamic_cast<Function*>(c->compile(cenv));
+ Function* f = dynamic_cast<Function*>(cenv.compile(c));
if (!f) throw CompileError("Callee compiled to non-function");
vector<Value*> params;
for (size_t i = 1; i < tup.size(); ++i)
- params.push_back(tup[i]->compile(cenv));
+ params.push_back(cenv.compile(tup[i]));
return cenv.builder.CreateCall(f, params.begin(), params.end(), "calltmp");
}
@@ -766,7 +781,7 @@ ASTCall::compile(CEnv& cenv)
Value*
ASTIf::compile(CEnv& cenv)
{
- Value* condV = tup[1]->compile(cenv);
+ Value* condV = cenv.compile(tup[1]);
Function* parent = cenv.builder.GetInsertBlock()->getParent();
// Create blocks for the then and else cases.
@@ -779,14 +794,14 @@ ASTIf::compile(CEnv& cenv)
// Emit then block
cenv.builder.SetInsertPoint(thenBB);
- Value* thenV = tup[2]->compile(cenv); // Can change current block, so...
+ Value* thenV = cenv.compile(tup[2]); // Can change current block, so...
cenv.builder.CreateBr(mergeBB);
thenBB = cenv.builder.GetInsertBlock(); // ... update thenBB afterwards
// Emit else block
parent->getBasicBlockList().push_back(elseBB);
cenv.builder.SetInsertPoint(elseBB);
- Value* elseV = tup[3]->compile(cenv); // Can change current block, so...
+ Value* elseV = cenv.compile(tup[3]); // Can change current block, so...
cenv.builder.CreateBr(mergeBB);
elseBB = cenv.builder.GetInsertBlock(); // ... update elseBB afterwards
@@ -810,7 +825,7 @@ ASTClosure::lift(CEnv& cenv)
return;
assert(!func);
- cenv.code.push_front();
+ cenv.push();
// Write function declaration
Function* f = compileFunction(cenv, cenv.gensym("_fn"), *prot, cenv.tenv.type(tup[2])->ctype);
@@ -825,7 +840,8 @@ ASTClosure::lift(CEnv& cenv)
// Write function body
try {
- Value* retVal = tup[2]->compile(cenv);
+ cenv.precompile(this, f); // Define our value first for recursion
+ Value* retVal = cenv.compile(tup[2]);
cenv.builder.CreateRet(retVal); // Finish function
verifyFunction(*f); // Validate generated code
cenv.fpm.run(*f); // Optimize function
@@ -835,12 +851,14 @@ ASTClosure::lift(CEnv& cenv)
throw e;
}
- cenv.code.pop_front();
+ assert(func);
+ cenv.pop();
}
Value*
ASTClosure::compile(CEnv& cenv)
{
+ assert(func);
return func; // Function was already compiled in the lifting pass
}
@@ -848,16 +866,16 @@ Value*
ASTPrimitive::compile(CEnv& cenv)
{
if (tup.size() < 3) throw SyntaxError("Too few arguments");
- Value* a = tup[1]->compile(cenv);
- Value* b = tup[2]->compile(cenv);
+ Value* a = cenv.compile(tup[1]);
+ Value* b = cenv.compile(tup[2]);
if (OP_IS_A(op, Instruction::BinaryOps)) {
const Instruction::BinaryOps bo = (Instruction::BinaryOps)op;
if (tup.size() == 2)
- return tup[1]->compile(cenv);
+ return cenv.compile(tup[1]);
Value* val = cenv.builder.CreateBinOp(bo, a, b);
for (size_t i = 3; i < tup.size(); ++i)
- val = cenv.builder.CreateBinOp(bo, val, tup[i]->compile(cenv));
+ val = cenv.builder.CreateBinOp(bo, val, cenv.compile(tup[i]));
return val;
} else if (op == Instruction::ICmp) {
bool isInt = cenv.tenv.type(tup[1])->str() == "(Int)";
@@ -943,7 +961,7 @@ main()
BasicBlock* bb = BasicBlock::Create("entry", f);
cenv.builder.SetInsertPoint(bb);
try {
- Value* retVal = body->compile(cenv);
+ Value* retVal = cenv.compile(body);
cenv.builder.CreateRet(retVal); // Finish function
verifyFunction(*f); // Validate generated code
cenv.fpm.run(*f); // Optimize function
@@ -959,7 +977,7 @@ main()
else if (bodyT->ctype == Type::Int1Ty)
std::cout << "; " << ((bool (*)())fp)();
} else {
- Value* val = body->compile(cenv);
+ Value* val = cenv.compile(body);
std::cout << "; " << val;
}
std::cout << " : " << cenv.tenv.type(body)->str() << endl;