aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--cps.cpp103
-rw-r--r--llvm.cpp2
-rw-r--r--tuplr.cpp10
-rw-r--r--tuplr.hpp20
-rw-r--r--write.cpp26
5 files changed, 125 insertions, 36 deletions
diff --git a/cps.cpp b/cps.cpp
index c1bddf2..e1a419f 100644
--- a/cps.cpp
+++ b/cps.cpp
@@ -22,55 +22,118 @@
* CPS Conversion *
***************************************************************************/
-/** (cps x) => (cont x) */
+/** (cps x cont) => (cont x) */
AST*
AST::cps(TEnv& tenv, AST* cont)
{
return tup<ACall>(loc, cont, this, 0);
}
-AST*
-ATuple::cps(TEnv& tenv, AST* cont)
-{
- ATuple* copy = tup<ATuple>(loc, NULL);
- FOREACH(const_iterator, p, *this)
- copy->push_back((*p)->cps(tenv, cont));
- return copy;
-}
-
-/** (cps (fn (a ...) body ...)) => */
+/** (cps (fn (a ...) body) cont) => (cont (fn (a ... k) (cps body k))*/
AST*
AFn::cps(TEnv& tenv, AST* cont)
{
- AFn* copy = tup<AFn>(loc, tenv.penv.sym("fn"), prot(), 0);
+ ATuple* copyProt = new ATuple(prot()->loc, *prot());
+ ASymbol* contArg = tenv.penv.gensym("_k");
+ copyProt->push_back(contArg);
+ AFn* copy = tup<AFn>(loc, tenv.penv.sym("fn"), copyProt, 0);
const_iterator p = begin();
++(++p);
for (; p != end(); ++p)
- copy->push_back((*p)->cps(tenv, cont));
- return copy;
+ copy->push_back((*p)->cps(tenv, contArg));
+ return tup<ACall>(loc, cont, copy, 0);
+}
+
+AST*
+APrimitive::cps(TEnv& tenv, AST* cont)
+{
+ return value() ? tup<ACall>(loc, cont, this, 0) : ACall::cps(tenv, cont);
}
/** (cps (f a b ...)) => (a (fn (x) (b (fn (y) ... (cont (f x y ...)) */
AST*
ACall::cps(TEnv& tenv, AST* cont)
{
- return tup<ACall>(loc, cont, this, 0);
+ std::vector< std::pair<AFn*, AST*> > funcs;
+ AFn* fn = NULL;
+ ASymbol* arg = NULL;
+
+ // Make a continuation for each element (operator and arguments)
+ ssize_t firstFn = -1;
+ ssize_t lastFn = -1;
+ for (size_t i = 0; i < size(); ++i) {
+ if (!at(i)->to<ATuple*>()) {
+ funcs.push_back(make_pair((AFn*)NULL, at(i)));
+ } else {
+ arg = tenv.penv.gensym("a");
+
+ if (firstFn == -1)
+ firstFn = i;
+
+ AFn* thisFn = tup<AFn>(loc, tenv.penv.sym("fn"),
+ tup<ATuple>(at(i)->loc, arg, tenv.penv.gensym("_k"), 0),
+ 0);
+
+ if (lastFn != -1)
+ fn->push_back(at(lastFn)->cps(tenv, thisFn));
+
+ funcs.push_back(make_pair(thisFn, arg));
+ fn = thisFn;
+ lastFn = i;
+ }
+ }
+
+ if (firstFn != -1) {
+ // Call our callee in the last argument's evaluation function
+ ACall* call = tup<ACall>(loc, 0);
+ assert(funcs.size() == size());
+ for (size_t i = 0; i < funcs.size(); ++i)
+ call->push_back(funcs[i].second);
+ if (!to<APrimitive*>())
+ call->push_back(cont);
+ else
+ call = tup<ACall>(loc, cont, call, 0);
+
+ assert(fn);
+ fn->push_back(call);
+ return at(firstFn)->cps(tenv, funcs[firstFn].first);
+ } else {
+ assert(at(0)->value());
+ ACall* ret = tup<ACall>(loc, 0);
+ for (size_t i = 0; i < size(); ++i)
+ ret->push_back(at(i));
+ if (!to<APrimitive*>())
+ ret->push_back(cont);
+ return ret;
+ }
}
/** (cps (def x y)) => (y (fn (x) (cont))) */
AST*
ADef::cps(TEnv& tenv, AST* cont)
{
- return tup<ADef>(loc, tenv.penv.sym("def"), sym(), at(2)->cps(tenv, cont), 0);
+ AST* val = at(2)->cps(tenv, cont);
+ ACall* valCall = val->to<ACall*>();
+ assert(valCall);
+ return tup<ADef>(loc, tenv.penv.sym("def"), sym(), valCall->at(1), 0);
}
/** (cps (if c t ... e)) => */
AST*
AIf::cps(TEnv& tenv, AST* cont)
{
- AFn* contFn = tup<AFn>(loc, tenv.penv.sym("if-fn"),
- new ATuple(at(1)->loc, cont, 0), 0);
- ACall* condCall = tup<ACall>(loc, contFn, 0);
- return condCall;
+ ASymbol* argSym = tenv.penv.gensym("c");
+ if (at(1)->value()) {
+ return tup<AIf>(loc, tenv.penv.sym("if"), at(1),
+ at(2)->cps(tenv, cont),
+ at(3)->cps(tenv, cont), 0);
+ } else {
+ AFn* contFn = tup<AFn>(loc, tenv.penv.sym("fn"),
+ tup<ATuple>(at(1)->loc, argSym, tenv.penv.gensym("_k"), 0),
+ tup<AIf>(loc, tenv.penv.sym("if"), argSym,
+ at(2)->cps(tenv, cont),
+ at(3)->cps(tenv, cont), 0));
+ return at(1)->cps(tenv, contFn);
+ }
}
diff --git a/llvm.cpp b/llvm.cpp
index 4fa9ca9..1f4c58e 100644
--- a/llvm.cpp
+++ b/llvm.cpp
@@ -254,7 +254,7 @@ AFn::liftCall(CEnv& cenv, const AType& argsT)
}
// Write function declaration
- const string name = (this->name == "") ? cenv.gensym("_fn") : this->name;
+ const string name = (this->name == "") ? cenv.penv.gensymstr("_fn") : this->name;
Function* f = llFunc(cenv.engine()->startFunction(cenv, name,
thisType->at(thisType->size()-1)->to<AType*>(),
*protT, argNames));
diff --git a/tuplr.cpp b/tuplr.cpp
index 5b1a357..52d8b7d 100644
--- a/tuplr.cpp
+++ b/tuplr.cpp
@@ -265,11 +265,13 @@ eval(CEnv& cenv, const string& name, istream& is)
// Create function for top-level of program
CFunction f = cenv.engine()->startFunction(cenv, "main", resultType, ATuple(cursor));
-
+
// Print CPS form
CValue val = NULL;
- /*for (list< pair<SExp, AST*> >::const_iterator i = exprs.begin(); i != exprs.end(); ++i)
- cout << "CPS: " << i->second->cps(cenv.tenv, cenv.penv.sym("cont")) << endl;*/
+ /*for (list< pair<SExp, AST*> >::const_iterator i = exprs.begin(); i != exprs.end(); ++i) {
+ cout << "CPS: " << endl;
+ pprint(cout, i->second->cps(cenv.tenv, cenv.penv.sym("cont")));
+ }*/
// Compile all expressions into it
for (list< pair<SExp, AST*> >::const_iterator i = exprs.begin(); i != exprs.end(); ++i)
@@ -319,7 +321,7 @@ repl(CEnv& cenv)
CFunction f = NULL;
try {
// Create anonymous function to insert code into
- f = cenv.engine()->startFunction(cenv, cenv.gensym("_repl"), bodyT, ATuple(cursor));
+ f = cenv.engine()->startFunction(cenv, cenv.penv.gensymstr("_repl"), bodyT, ATuple(cursor));
CValue retVal = cenv.compile(body);
cenv.engine()->finishFunction(cenv, f, retVal);
cenv.out << cenv.engine()->call(cenv, f, bodyT);
diff --git a/tuplr.hpp b/tuplr.hpp
index b54f72a..37a09db 100644
--- a/tuplr.hpp
+++ b/tuplr.hpp
@@ -193,6 +193,7 @@ extern ostream& operator<<(ostream& out, const AST* ast);
struct AST : public Object {
AST(Cursor c=Cursor()) : loc(c) {}
virtual ~AST() {}
+ virtual bool value() const { return true; }
virtual bool operator==(const AST& o) const = 0;
virtual bool contains(const AST* child) const { return false; }
virtual void constrain(TEnv& tenv, Constraints& c) const {}
@@ -254,6 +255,7 @@ struct ATuple : public AST, public vector<AST*> {
for (AST* a = va_arg(args, AST*); a; a = va_arg(args, AST*))
push_back(a);
}
+ bool value() const { return false; }
bool operator==(const AST& rhs) const {
const ATuple* rt = rhs.to<const ATuple*>();
if (!rt || rt->size() != size()) return false;
@@ -271,7 +273,6 @@ struct ATuple : public AST, public vector<AST*> {
return false;
}
void constrain(TEnv& tenv, Constraints& c) const;
- AST* cps(TEnv& tenv, AST* cont);
void lift(CEnv& cenv) { FOREACH(iterator, t, *this) (*t)->lift(cenv); }
CValue compile(CEnv& cenv) { throw Error(loc, "tuple compiled"); }
@@ -408,6 +409,7 @@ struct ADef : public ACall {
/// Conditional special form, e.g. "(if cond thenexp elseexp)"
struct AIf : public ACall {
AIf(const SExp& e, const ATuple& t) : ACall(e, t) {}
+ AIf(Cursor c, AST* ast, va_list args) : ACall(c, ast, args) {}
void constrain(TEnv& tenv, Constraints& c) const;
AST* cps(TEnv& tenv, AST* cont);
CValue compile(CEnv& cenv);
@@ -416,7 +418,14 @@ struct AIf : public ACall {
/// Primitive (builtin arithmetic function), e.g. "(+ 2 3)"
struct APrimitive : public ACall {
APrimitive(const SExp& e, const ATuple& t) : ACall(e, t) {}
+ bool value() const {
+ for (size_t i = 1; i < size(); ++i)
+ if (!at(i)->value())
+ return false;;
+ return true;
+ }
void constrain(TEnv& tenv, Constraints& c) const;
+ AST* cps(TEnv& tenv, AST* cont);
CValue compile(CEnv& cenv);
};
@@ -427,6 +436,7 @@ struct APrimitive : public ACall {
/// Parse Time Environment (really just a symbol table)
struct PEnv : private map<const string, ASymbol*> {
+ PEnv() : symID(0) {}
typedef AST* (*PF)(PEnv&, const SExp&, void*); ///< Parse Function
typedef SExp (*MF)(PEnv&, const SExp&); ///< Macro Function
struct Handler { Handler(PF f, void* a=0) : func(f), arg(a) {} PF func; void* arg; };
@@ -448,6 +458,8 @@ struct PEnv : private map<const string, ASymbol*> {
map<string, MF>::const_iterator i = macros.find(s);
return (i != macros.end()) ? i->second : NULL;
}
+ string gensymstr(const char* s="_") { return (format("%s%d") % s % symID++).str(); }
+ ASymbol* gensym(const char* s="_") { return sym(gensymstr(s)); }
ASymbol* sym(const string& s, Cursor c=Cursor()) {
const const_iterator i = find(s);
if (i != end()) {
@@ -489,6 +501,7 @@ struct PEnv : private map<const string, ASymbol*> {
}
return sym(exp.atom, exp.loc);
}
+ unsigned symID;
};
@@ -542,6 +555,7 @@ struct TEnv : public Env< const ASymbol*, pair<AST*, AType*> > {
ASymbol* sym = ast->to<ASymbol*>();
return (sym && sym->addr) ? ref(sym)->first : ast;
}
+
static Subst unify(const Constraints& c);
typedef map<const AST*, AType*> Vars;
@@ -575,7 +589,7 @@ void tuplr_free_engine(Engine* engine);
/// Compile-Time Environment
struct CEnv {
CEnv(PEnv& p, TEnv& t, Engine* e, ostream& os=std::cout, ostream& es=std::cerr)
- : out(os), err(es), penv(p), tenv(t), symID(0), _engine(e)
+ : out(os), err(es), penv(p), tenv(t), _engine(e)
{}
~CEnv() { Object::pool.collect(GC::Roots()); }
@@ -584,7 +598,6 @@ struct CEnv {
typedef Env<const AST*, CValue> Vals;
Engine* engine() { return _engine; }
- string gensym(const char* s="_") { return (format("%s%d") % s % symID++).str(); }
void push() { tenv.push(); vals.push(); }
void pop() { tenv.pop(); vals.pop(); }
void precompile(AST* obj, CValue value) { vals.def(obj, value); }
@@ -610,7 +623,6 @@ struct CEnv {
TEnv& tenv;
Vals vals;
- unsigned symID;
Subst tsubst;
map<string,string> args;
diff --git a/write.cpp b/write.cpp
index 95079ad..ba000f3 100644
--- a/write.cpp
+++ b/write.cpp
@@ -59,18 +59,30 @@ operator<<(ostream& out, const AST* ast)
void
pprint_internal(ostream& out, const AST* ast, unsigned indent)
{
- out << string().insert(0, indent, ' ');
const ATuple* tup = ast->to<const ATuple*>();
- if (tup) {
+ if (tup && tup->size() > 0) {
const string head = tup->at(0)->str();
- out << "(" << head;
- if (tup->size() > 1)
- out << " " << tup->at(1);
- for (size_t i = 2; i != tup->size(); ++i) {
- out << endl;
+ ASymbol* headSym = tup->at(0)->to<ASymbol*>();
+ out << "(";
+ pprint_internal(out, tup->at(0), indent);
+ if (tup->size() > 1) {
+ out << " ";
+ if (headSym && headSym->cppstr == "fn")
+ out << tup->at(1) << " ";
+ else
+ pprint_internal(out, tup->at(1), indent + head.length() + 1);
+ }
+ for (size_t i = 2; i < tup->size(); ++i) {
+ //if (!headSym || headSym->cppstr != "def")
+ out << endl;
+ //else
+ // out << " ";
+ out << string().insert(0, indent, ' ');
pprint_internal(out, tup->at(i), indent + head.length() + 2);
}
out << ")";
+ if (headSym && headSym->cppstr == "fn")
+ out << endl << string().insert(0, indent, ' ');
} else {
out << ast;
}