aboutsummaryrefslogtreecommitdiffstats
path: root/ll.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'll.cpp')
-rw-r--r--ll.cpp665
1 files changed, 357 insertions, 308 deletions
diff --git a/ll.cpp b/ll.cpp
index 823f64a..f3082e1 100644
--- a/ll.cpp
+++ b/ll.cpp
@@ -123,9 +123,82 @@ readExpression(std::istream& in)
/***************************************************************************
+ * Environment *
+ ***************************************************************************/
+
+class AST;
+class ASTSymbol;
+
+/// Generic Recursive Environment (stack of key:value dictionaries)
+template<typename K, typename V>
+struct Env : public list< map<K,V> > {
+ Env() : list< map<K, V> >(1) {}
+ void push() { this->push_front(map<K,V>()); }
+ void push(const map<K,V>& frame) { this->push_front(frame); }
+ map<K,V>& pop() {
+ map<K,V>& front = this->front();
+ this->pop_front();
+ return front;
+ }
+ void def(const K& k, const V& v) {
+ if (this->front().find(k) != this->front().end())
+ throw SyntaxError("Redefinition");
+ this->front()[k] = v;
+ }
+ V* ref(const K& name) {
+ typename Env::iterator i = this->begin();
+ for (; i != this->end(); ++i) {
+ typename map<K,V>::iterator s = i->find(name);
+ if (s != i->end())
+ return &s->second;
+ }
+ return 0;
+ }
+};
+
+class PEnv;
+
+/// Compile-time environment
+struct CEnv {
+ CEnv(PEnv& p, Module* m, const TargetData* target)
+ : penv(p), module(m), emp(module), fpm(&emp), symID(0), tID(0)
+ {
+ // Set up the optimizer pipeline.
+ // Register info about how the target lays out data structures.
+ fpm.add(new TargetData(*target));
+ // Do simple "peephole" and bit-twiddling optimizations.
+ fpm.add(createInstructionCombiningPass());
+ // Reassociate expressions.
+ fpm.add(createReassociatePass());
+ // Eliminate Common SubExpressions.
+ fpm.add(createGVNPass());
+ // Simplify control flow graph (delete unreachable blocks, etc).
+ fpm.add(createCFGSimplificationPass());
+ }
+ string gensym(const char* base="_") {
+ ostringstream s; s << base << symID++; return s.str();
+ }
+ typedef Env<const AST*, AST*> Code;
+ typedef Env<const ASTSymbol*, Value*> Vals;
+
+ PEnv& penv;
+ IRBuilder<> builder;
+ Module* module;
+ ExistingModuleProvider emp;
+ FunctionPassManager fpm;
+ unsigned symID;
+ unsigned tID;
+ Code code;
+ Vals vals;
+};
+
+
+
+/***************************************************************************
* Abstract Syntax Tree *
***************************************************************************/
+struct AType;
struct TypeError : public Error { TypeError (const char* m) : Error(m) {} };
struct CEnv; ///< Compile Time Environment
@@ -133,148 +206,187 @@ struct CEnv; ///< Compile Time Environment
/// Base class for all AST nodes
struct AST {
virtual ~AST() {}
- virtual const Type* type(CEnv& cenv) const = 0;
- virtual Value* compile(CEnv& cenv) = 0;
+ virtual string str(CEnv& cenv) const = 0;
+ virtual AType* type(CEnv& cenv) = 0;
+ virtual Value* compile(CEnv& cenv) = 0;
+ virtual void lift(CEnv& cenv) {}
+};
+
+/// Symbol, e.g. "a"
+struct ASTSymbol : public AST {
+ ASTSymbol(const string& s) : cppstr(s) {}
+ std::string str(CEnv&) const { return cppstr; }
+ AType* type(CEnv& cenv);
+ Value* compile(CEnv& cenv);
+private:
+ const string cppstr;
+};
+
+/// Tuple (heterogeneous sequence of known length), e.g. "(a b c)"
+struct ASTTuple : public AST {
+ ASTTuple(vector<AST*> t=vector<AST*>()) : tup(t) {}
+ string str(CEnv& cenv) const {
+ string ret = "(";
+ for (size_t i = 0; i != tup.size(); ++i)
+ ret += tup[i]->str(cenv) + ((i != tup.size() - 1) ? " " : "");
+ ret.append(")");
+ return ret;
+ }
+ void lift(CEnv& cenv) {
+ FOREACH(vector<AST*>::iterator, t, tup)
+ (*t)->lift(cenv);
+ }
+ AType* type(CEnv& cenv);
+ Value* compile(CEnv& cenv) { return NULL; }
+ vector<AST*> tup;
+};
+
+/// TExpr ::= (TName TExpr*) | ?Num
+struct AType : public ASTTuple {
+ AType(unsigned i) : var(true), ctype(0), id(id) {}
+ AType(const string& n, const Type* t) : var(false), ctype(t) {
+ tup.push_back(new ASTSymbol(n));
+ }
+ AType(const vector<AST*>& t) : ASTTuple(t), var(false), ctype(0) {}
+ inline bool operator==(const AType& t) const { return tup[0] == t.tup[0]; }
+ inline bool operator!=(const AType& t) const { return tup[0] != t.tup[0]; }
+ string str(CEnv& cenv) const { return var ? "?" : ASTTuple::str(cenv); }
+ AType* type(CEnv& cenv) { return this; }
+ Value* compile(CEnv& cenv) { return NULL; }
+ bool var;
+ const Type* ctype;
+ unsigned id;
};
/// Literal
template<typename VT>
struct ASTLiteral : public AST {
ASTLiteral(VT v) : val(v) {}
- const Type* type(CEnv& cenv) const;
- Value* compile(CEnv& cenv);
+ string str(CEnv& env) const { return "(Literal)"; }
+ AType* type(CEnv& cenv);
+ Value* compile(CEnv& cenv);
const VT val;
};
-
-#define LITERAL(CT, VT, COMPILED) \
-template<> const Type* \
-ASTLiteral<CT>::type(CEnv& cenv) const { return VT; } \
- \
+#define LITERAL(CT, VT, NAME, COMPILED) \
+template<> string \
+ASTLiteral<CT>::str(CEnv& cenv) const { return NAME; } \
+template<> AType* \
+ASTLiteral<CT>::type(CEnv& cenv) { return new AType(NAME, VT); } \
template<> Value* \
ASTLiteral<CT>::compile(CEnv& cenv) { return (COMPILED); }
-/// Literal template specialisations
-LITERAL(int32_t, Type::Int32Ty, ConstantInt::get(type(cenv), val, true));
-LITERAL(float, Type::FloatTy, ConstantFP::get(type(cenv), val));
-LITERAL(bool, Type::Int1Ty, ConstantInt::get(type(cenv), val, false));
+/// Literal template instantiations
+LITERAL(int32_t, Type::Int32Ty, "Int", ConstantInt::get(Type::Int32Ty, val, true));
+LITERAL(float, Type::FloatTy, "Float", ConstantFP::get(Type::FloatTy, val));
+LITERAL(bool, Type::Int1Ty, "Bool", ConstantInt::get(Type::Int1Ty, val, false));
-/// Symbol, e.g. "a"
-struct ASTSymbol : public AST {
- ASTSymbol(const string& n) : name(n) {}
- virtual const Type* type(CEnv& cenv) const;
- virtual Value* compile(CEnv& cenv);
- const string name;
+typedef unsigned UD; // User Data passed to registered parse functions
+
+// Parse Time Environment (symbol table)
+struct PEnv : private map<const string, ASTSymbol*> {
+ typedef AST* (*PF)(PEnv&, const list<SExp>&, UD); // Parse Function
+ struct Parser { Parser(PF f, UD d) : pf(f), ud(d) {} PF pf; UD ud; };
+ map<const ASTSymbol*, Parser> parsers;
+ void reg(const ASTSymbol* s, const Parser& p) {
+ parsers.insert(make_pair(s, p));
+ }
+ const Parser* parser(const ASTSymbol* s) const {
+ map<const ASTSymbol*, Parser>::const_iterator i = parsers.find(s);
+ return (i != parsers.end()) ? &i->second : NULL;
+ }
+ ASTSymbol* sym(const string& s) {
+ const const_iterator i = find(s);
+ return ((i != end())
+ ? i->second
+ : insert(make_pair(s, new ASTSymbol(s))).first->second);
+ }
+};
+
+/// Closure (first-class function with captured lexical bindings)
+struct ASTClosure : public AST {
+ ASTClosure(ASTTuple* p, AST* b) : prot(p), body(b), func(0) {}
+ string str(CEnv& env) const { return "(fn)"; }
+ AType* type(CEnv& cenv) {
+ vector<AST*> texp(3);
+ texp[0] = cenv.penv.sym("Fn");
+ texp[1] = prot;
+ texp[2] = body;
+ return new AType(texp);
+ }
+ Value* compile(CEnv& cenv);
+ void lift(CEnv& cenv);
+ ASTTuple* const prot;
+ AST* const body;
+ vector<const ASTSymbol*> bindings;
+private:
+ Function* func;
};
/// Function call/application, e.g. "(func arg1 arg2)"
-struct ASTCall : public AST {
- ASTCall(const vector<AST*>& c) : code(c) {}
- virtual const Type* type(CEnv& cenv) const {
- AST* func = code[0];
- const FunctionType* ftype = dynamic_cast<const FunctionType*>(func->type(cenv));
- if (!ftype) throw TypeError(string("Call to non-function type :: ")
- .append(func->type(cenv)->getDescription()).c_str());
- return ftype->getReturnType();
+struct ASTCall : public ASTTuple {
+ ASTCall(const vector<AST*>& t) : ASTTuple(t) {}
+ AType* type(CEnv& cenv) {
+ AST* callee = tup[0];
+ ASTSymbol* sym = dynamic_cast<ASTSymbol*>(tup[0]);
+ if (sym) {
+ AST** val = cenv.code.ref(sym);
+ if (val)
+ callee = *val;
+ }
+ ASTClosure* c = dynamic_cast<ASTClosure*>(callee);
+ if (!c) throw TypeError("Call to non-closure");
+ return c->body->type(cenv);
}
- virtual Value* compile(CEnv& cenv);
- const vector<AST*> code;
+ void lift(CEnv& cenv);
+ Value* compile(CEnv& cenv);
};
/// Definition special form, e.g. "(def x 2)" or "(def (next y) (+ y 1))"
struct ASTDefinition : public ASTCall {
ASTDefinition(const vector<AST*>& c) : ASTCall(c) {}
- virtual const Type* type(CEnv& cenv) const { return code[2]->type(cenv); }
- virtual Value* compile(CEnv& cenv);
+ AType* type(CEnv& cenv) { return tup[2]->type(cenv); }
+ Value* compile(CEnv& cenv);
};
/// Conditional special form, e.g. "(if cond thenexp elseexp)"
struct ASTIf : public ASTCall {
ASTIf(const vector<AST*>& c) : ASTCall(c) {}
- virtual const Type* type(CEnv& cenv) const {
- const Type* cT = code[1]->type(cenv);
- const Type* tT = code[2]->type(cenv);
- const Type* eT = code[3]->type(cenv);
- if (cT != Type::Int1Ty) throw TypeError("If condition is not a boolean");
- if (tT != eT) throw TypeError("If branches have different types");
+ AType* type(CEnv& cenv) {
+ AType* cT = tup[1]->type(cenv);
+ AType* tT = tup[2]->type(cenv);
+ AType* eT = tup[3]->type(cenv);
+ if (cT->ctype != Type::Int1Ty) throw TypeError("If condition is not a boolean");
+ if (*tT != *eT) throw TypeError("If branches have different types");
return tT;
}
- virtual Value* compile(CEnv& cenv);
+ Value* compile(CEnv& cenv);
};
/// Primitive (builtin arithmetic function), e.g. "(+ 2 3)"
struct ASTPrimitive : public ASTCall {
ASTPrimitive(const vector<AST*>& c, Instruction::BinaryOps o) : ASTCall(c), op(o) {}
- virtual const Type* type(CEnv& cenv) const { return Type::FloatTy; }
- virtual Value* compile(CEnv& cenv);
- Instruction::BinaryOps op;
-};
-
-/// Function prototype (actual LLVM IR function prototype)
-struct ASTPrototype {
- ASTPrototype(vector<AST*> p=vector<AST*>()) : params(p) {}
- vector<const Type*> argsType(CEnv& cenv) {
- vector<const Type*> types;
- FOREACH(vector<AST*>::const_iterator, p, params)
- types.push_back((*p)->type(cenv));
- return types;
- }
- virtual const Type* type(CEnv& cenv) const { return NULL; }
- Function* compile(CEnv& cenv, FunctionType* type, const string& name);
- string name;
- vector<AST*> params;
-};
-
-/// Closure (first-class function with captured lexical bindings)
-struct ASTClosure : public AST {
- ASTClosure(ASTPrototype* p, AST* b) : prot(p), body(b), func(0) {}
- virtual const Type* type(CEnv& cenv) const {
- return FunctionType::get(body->type(cenv), prot->argsType(cenv), false);
+ AType* type(CEnv& cenv) {
+ if (tup.size() <= 1) throw SyntaxError("Primitive call with no arguments");
+ return tup[1]->type(cenv); // FIXME: Ensure argument types are equivalent
}
- virtual Value* compile(CEnv& cenv);
- virtual void lift(CEnv& cenv);
- ASTPrototype* const prot;
- AST* const body;
- vector<const ASTSymbol*> bindings;
-private:
- Function* func;
-};
-
-/// Function definition (actual LLVM IR function)
-struct ASTFunction {
- ASTFunction(ASTPrototype* p, AST* b) : prot(p), body(b) {}
- Function* compile(CEnv& cenv, const string& name);
- ASTPrototype* const prot;
- AST* const body;
+ Value* compile(CEnv& cenv);
+ Instruction::BinaryOps op;
};
+AType*
+ASTTuple::type(CEnv& cenv)
+{
+ vector<AST*> texp;
+ FOREACH(vector<AST*>::const_iterator, p, tup)
+ texp.push_back((*p)->type(cenv));
+ return new AType(texp);
+}
/***************************************************************************
* Parser - S-Expressions (SExp) -> AST Nodes (AST) *
***************************************************************************/
-typedef unsigned UD; // User Data passed to registered parse functions
-
-// Parse Time Environment (just a symbol table)
-struct PEnv : private map<const string, ASTSymbol*> {
- typedef AST* (*PF)(PEnv&, const list<SExp>&, UD); // Parse Function
- struct Parser { Parser(PF f, UD d) : pf(f), ud(d) {} PF pf; UD ud; };
- map<const ASTSymbol*, Parser> parsers;
- void reg(const ASTSymbol* s, const Parser& p) {
- parsers.insert(make_pair(s, p));
- }
- const Parser* parser(const ASTSymbol* s) const {
- map<const ASTSymbol*, Parser>::const_iterator i = parsers.find(s);
- return (i != parsers.end()) ? &i->second : NULL;
- }
- ASTSymbol* sym(const string& s) {
- const const_iterator i = find(s);
- return ((i != end())
- ? i->second
- : insert(make_pair(s, new ASTSymbol(s))).first->second);
- }
-};
-
/// The fundamental parser method
static AST*
parseExpression(PEnv& penv, const SExp& exp)
@@ -330,9 +442,9 @@ static AST*
parsePrim(PEnv& penv, const list<SExp>& c, UD data)
{ return new ASTPrimitive(pmap(penv, c), (Instruction::BinaryOps)data); }
-static ASTPrototype*
+static ASTTuple*
parsePrototype(PEnv& penv, const SExp& e, UD)
- { return new ASTPrototype(pmap(penv, e.list)); }
+ { return new ASTTuple(pmap(penv, e.list)); }
static AST*
parseFn(PEnv& penv, const list<SExp>& c, UD)
@@ -348,143 +460,126 @@ parseFn(PEnv& penv, const list<SExp>& c, UD)
* Code Generation *
***************************************************************************/
-/// Generic Recursive Environment (stack of key:value dictionaries)
-template<typename K, typename V>
-struct Env : public list< map<K,V> > {
- Env() : list< map<K, V> >(1) {}
- void push() { this->push_front(map<K,V>()); }
- void push(const map<K,V>& frame) { this->push_front(frame); }
- map<K,V>& pop() {
- map<K,V>& front = this->front();
- this->pop_front();
- return front;
- }
- void def(const K& k, const V& v) {
- if (this->front().find(k) != this->front().end())
- std::cerr << "WARNING: Redefinition: " << k << endl;
- this->front()[k] = v;
- }
- V* ref(const K& name) {
- typename Env::iterator i = this->begin();
- for (; i != this->end(); ++i) {
- typename map<K,V>::iterator s = i->find(name);
- if (s != i->end())
- return &s->second;
- }
- return 0;
- }
-};
+struct CompileError : public Error { CompileError(const char* m) : Error(m) {} };
-/// Compile-time environment
-struct CEnv {
- CEnv(Module* m, const TargetData* target)
- : module(m), provider(module), fpm(&provider), symID(0)
- {
- // Set up the optimizer pipeline.
- // Register info about how the target lays out data structures.
- fpm.add(new TargetData(*target));
- // Do simple "peephole" and bit-twiddling optimizations.
- fpm.add(createInstructionCombiningPass());
- // Reassociate expressions.
- fpm.add(createReassociatePass());
- // Eliminate Common SubExpressions.
- fpm.add(createGVNPass());
- // Simplify control flow graph (delete unreachable blocks, etc).
- fpm.add(createCFGSimplificationPass());
- }
- string gensym(const char* base="_") {
- ostringstream s; s << base << symID++; return s.str();
- }
- void def(const ASTSymbol* sym, AST* expr) {
- types.def(sym, expr->type(*this));
- code.def(sym, expr);
+static Function*
+compileFunction(CEnv& cenv, const std::string& name, ASTTuple& prot, const Type* retT)
+{
+ Function::LinkageTypes linkage = Function::ExternalLinkage;
+
+ const vector<AST*>& texp = prot.type(cenv)->tup;
+ vector<const Type*> cprot;
+ for (size_t i = 0; i < texp.size(); ++i) {
+ const Type* t = texp[i]->type(cenv)->ctype;
+ if (!t) throw CompileError("Function prototype contains NULL");
+ cprot.push_back(t);
}
- typedef Env<const ASTSymbol*, const Type*> Types;
- typedef Env<const ASTSymbol*, AST*> Code;
- typedef Env<const ASTSymbol*, Value*> Vals;
- IRBuilder<> builder;
- Module* module;
- ExistingModuleProvider provider;
- FunctionPassManager fpm;
- size_t symID;
- Types types;
- Code code;
- Vals vals;
-};
+ FunctionType* fT = FunctionType::get(retT, cprot, false);
+ Function* f = Function::Create(fT, linkage, name, cenv.module);
-static void
-lambdaLift(CEnv& env, AST* ast)
-{
- if (ASTClosure* closure = dynamic_cast<ASTClosure*>(ast)) {
- lambdaLift(env, closure->body);
- closure->lift(env);
- } else if (ASTCall* call = dynamic_cast<ASTCall*>(ast)) {
- FOREACH(vector<AST*>::const_iterator, a, call->code)
- lambdaLift(env, *a);
+ if (f->getName() != name) {
+ f->eraseFromParent();
+ throw CompileError("Function redefined");
}
+
+ // Set argument names in generated code
+ Function::arg_iterator a = f->arg_begin();
+ for (size_t i = 0; i != prot.tup.size(); ++a, ++i)
+ a->setName(prot.tup[i]->str(cenv));
+
+ return f;
}
-const Type*
-ASTSymbol::type(CEnv& cenv) const
+AType*
+ASTSymbol::type(CEnv& cenv)
{
- const Type** t = cenv.types.ref(this);
- if (t) {
- return *t;
- } else {
- //std::cerr << "WARNING: Untyped symbol: " << name << endl;
- return Type::FloatTy;
- }
+ AST** t = cenv.code.ref(this);
+ return t ? (*t)->type(cenv) : new AType(cenv.tID++);
}
Value*
ASTSymbol::compile(CEnv& cenv)
{
- Value*const* v = cenv.vals.ref(this);
+ Value** v = cenv.vals.ref(this);
if (v)
return *v;
- AST*const* c = cenv.code.ref(this);
+ AST** c = cenv.code.ref(this);
if (c) {
Value* v = (*c)->compile(cenv);
cenv.vals.def(this, v);
return v;
}
- throw SyntaxError((string("Undefined symbol '") + name + "'").c_str());
+ throw SyntaxError((string("Undefined symbol '") + cppstr + "'").c_str());
}
Value*
ASTDefinition::compile(CEnv& cenv)
{
- if (code.size() != 3) throw SyntaxError("\"def\" takes exactly 2 arguments");
- const ASTSymbol* sym = dynamic_cast<const ASTSymbol*>(code[1]);
+ if (tup.size() != 3) throw SyntaxError("\"def\" takes exactly 2 arguments");
+ const ASTSymbol* sym = dynamic_cast<const ASTSymbol*>(tup[1]);
if (!sym) throw SyntaxError("Definition name is not a symbol");
- Value* val = code[2]->compile(cenv);
- cenv.types.def(sym, code[2]->type(cenv));
- cenv.code.def(sym, code[2]);
+ Value* val = tup[2]->compile(cenv);
+ cenv.code.def(sym, tup[2]);
cenv.vals.def(sym, val);
return val;
}
+void
+ASTCall::lift(CEnv& cenv)
+{
+ ASTClosure* c = dynamic_cast<ASTClosure*>(tup[0]);
+ if (!c) {
+ AST** val = cenv.code.ref(tup[0]);
+ c = (val) ? dynamic_cast<ASTClosure*>(*val) : c;
+ }
+
+ if (!c) {
+ ASTTuple::lift(cenv);
+ return;
+ }
+
+ std::cout << "Lifting call to closure" << endl;
+
+ // Lift arguments
+ for (size_t i = 1; i < tup.size(); ++i)
+ tup[i]->lift(cenv);
+
+ // Extend environment with bound and typed parameters
+ cenv.code.push();
+ if (c->prot->tup.size() != tup.size() - 1)
+ throw CompileError("Call to closure with mismatched arguments");
+
+ for (size_t i = 1; i < tup.size(); ++i)
+ cenv.code.def(c->prot->tup[i-1], tup[i]);
+
+ // Lift callee closure
+ tup[0]->lift(cenv);
+
+ cenv.code.pop();
+}
+
Value*
ASTCall::compile(CEnv& cenv)
{
- AST* func = code[0];
- AST** closure = cenv.code.ref((ASTSymbol*)func);
- assert(closure);
- ASTClosure* c = dynamic_cast<ASTClosure*>(*closure);
- assert(c);
- Function* f = dynamic_cast<Function*>(func->compile(cenv));
- if (!f) throw SyntaxError("Call to non-function");
+ ASTClosure* c = dynamic_cast<ASTClosure*>(tup[0]);
+ if (!c) {
+ AST** val = cenv.code.ref(tup[0]);
+ c = (val) ? dynamic_cast<ASTClosure*>(*val) : c;
+ }
+
+ if (!c) throw CompileError("Call to non-closure");
+ Value* v = c->compile(cenv);
+ if (!v) throw SyntaxError("Callee failed to compile");
+ Function* f = dynamic_cast<Function*>(c->compile(cenv));
+ if (!f) throw SyntaxError("Callee compiled to non-function");
vector<Value*> params;
- for (size_t i = 1; i < code.size(); ++i)
- params.push_back(code[i]->compile(cenv));
-
- for (size_t i = 0; i < c->bindings.size(); ++i)
- std::cout << "BINDING: " << c->bindings[i]->name << endl;
+ for (size_t i = 1; i < tup.size(); ++i)
+ params.push_back(tup[i]->compile(cenv));
return cenv.builder.CreateCall(f, params.begin(), params.end(), "calltmp");
}
@@ -492,7 +587,7 @@ ASTCall::compile(CEnv& cenv)
Value*
ASTIf::compile(CEnv& cenv)
{
- Value* condV = code[1]->compile(cenv);
+ Value* condV = tup[1]->compile(cenv);
Function* parent = cenv.builder.GetInsertBlock()->getParent();
// Create blocks for the then and else cases.
@@ -505,7 +600,7 @@ ASTIf::compile(CEnv& cenv)
// Emit then value.
cenv.builder.SetInsertPoint(thenBB);
- Value* thenV = code[2]->compile(cenv);
+ Value* thenV = tup[2]->compile(cenv);
cenv.builder.CreateBr(mergeBB);
// compile of 'Then' can change the current block, update thenBB
@@ -514,7 +609,7 @@ ASTIf::compile(CEnv& cenv)
// Emit else block.
parent->getBasicBlockList().push_back(elseBB);
cenv.builder.SetInsertPoint(elseBB);
- Value* elseV = code[3]->compile(cenv);
+ Value* elseV = tup[3]->compile(cenv);
cenv.builder.CreateBr(mergeBB);
// compile of 'Else' can change the current block, update elseBB
@@ -523,7 +618,7 @@ ASTIf::compile(CEnv& cenv)
// Emit merge block.
parent->getBasicBlockList().push_back(mergeBB);
cenv.builder.SetInsertPoint(mergeBB);
- PHINode* pn = cenv.builder.CreatePHI(type(cenv), "iftmp");
+ PHINode* pn = cenv.builder.CreatePHI(type(cenv)->ctype, "iftmp");
pn->addIncoming(thenV, thenBB);
pn->addIncoming(elseV, elseBB);
@@ -535,20 +630,10 @@ ASTClosure::lift(CEnv& cenv)
{
assert(!func);
- //set<const ASTSymbol*> unbound;
- const ASTCall* call = dynamic_cast<const ASTCall*>(body);
- if (call) {
- std::cout << "LIFT CALL BODY\n";
- }
-
-#if 0
- Env<const ASTSymbol*, const AST*> paramsEnv;
- for (vector<const ASTSymbol*>::const_iterator p = prot->params.begin();
- p != prot->params.end(); ++p) {
- //paramsEnv[*p] = NULL;
- std::cout << "PARAM: " << (*p)->name << endl;
- }
-#endif
+ // Can't lift a closure with variable types (lift later when called)
+ for (size_t i = 0; i < prot->tup.size(); ++i)
+ if (prot->tup[i]->type(cenv)->var)
+ return;
cenv.code.push();
@@ -557,21 +642,19 @@ ASTClosure::lift(CEnv& cenv)
AST** obj = cenv.code.ref(sym);
if (!obj) {
std::cout << "UNDEFINED SYMBOL BODY\n";
- prot->params.push_back(sym);
+ prot->tup.push_back(sym);
bindings.push_back(sym);
}
}
// Write function declaration
- Function* f = prot->compile(cenv,
- FunctionType::get(body->type(cenv), prot->argsType(cenv), false),
- cenv.gensym("_fn"));
+ Function* f = compileFunction(cenv, cenv.gensym("_fn"), *prot, body->type(cenv)->ctype);
BasicBlock* bb = BasicBlock::Create("entry", f);
cenv.builder.SetInsertPoint(bb);
-
+
// Bind argument values in CEnv
vector<Value*> args;
- vector<AST*>::const_iterator p = prot->params.begin();
+ vector<AST*>::const_iterator p = prot->tup.begin();
for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p)
cenv.vals.def(dynamic_cast<ASTSymbol*>(*p), &*a);
@@ -582,7 +665,7 @@ ASTClosure::lift(CEnv& cenv)
verifyFunction(*f); // Validate generated code
cenv.fpm.run(*f); // Optimize function
func = f;
- } catch (SyntaxError e) {
+ } catch (exception e) {
f->eraseFromParent(); // Error reading body, remove function
throw e;
}
@@ -594,77 +677,16 @@ Value*
ASTClosure::compile(CEnv& cenv)
{
// Function was already compiled in the lifting pass
- assert(func);
return func;
}
-Function*
-ASTPrototype::compile(CEnv& cenv, FunctionType* FT, const std::string& n)
-{
- name = n;
- Function::LinkageTypes linkage = Function::ExternalLinkage;
- Function* f = Function::Create(FT, linkage, name, cenv.module);
-
- // If F conflicted, there was already something named 'Name'.
- // If it has a body, don't allow redefinition.
- if (f->getName() != name) {
- // Delete the one we just made and get the existing one.
- f->eraseFromParent();
- f = cenv.module->getFunction(name);
-
- // If F already has a body, reject this.
- if (!f->empty()) throw SyntaxError("Function redefined");
-
- // If F took a different number of args, reject.
- if (f->arg_size() != params.size())
- throw SyntaxError("Function redefined with mismatched arguments");
- }
-
- // Set argument names in generated code
- Function::arg_iterator a = f->arg_begin();
- for (size_t i = 0; i != params.size(); ++a, ++i) {
- assert(params[i]);
- ASTSymbol* sym = dynamic_cast<ASTSymbol*>(params[i]);
- a->setName(sym ? sym->name : cenv.gensym("_a"));
- }
-
- return f;
-}
-
-Function*
-ASTFunction::compile(CEnv& cenv, const string& name)
-{
- const Type* bodyT = body->type(cenv);
- if (dynamic_cast<const FunctionType*>(bodyT)) {
- std::cout << "First class function alert" << endl;
- bodyT = PointerType::get(bodyT, 0);
- }
- FunctionType* fT = FunctionType::get(bodyT, prot->argsType(cenv), false);
- Function* f = prot->compile(cenv, fT, name);
-
- // Create a new basic block to start insertion into.
- BasicBlock* bb = BasicBlock::Create("entry", f);
- cenv.builder.SetInsertPoint(bb);
-
- try {
- Value* retVal = body->compile(cenv);
- cenv.builder.CreateRet(retVal); // Finish function
- verifyFunction(*f); // Validate generated code
- cenv.fpm.run(*f); // Optimize function
- return f;
- } catch (SyntaxError e) {
- f->eraseFromParent(); // Error reading body, remove function
- throw e;
- }
-}
-
Value*
ASTPrimitive::compile(CEnv& cenv)
{
size_t np = 0;
- vector<Value*> params(code.size() - 1);
- vector<AST*>::const_iterator a = code.begin();
- for (++a; a != code.end(); ++a)
+ vector<Value*> params(tup.size() - 1);
+ vector<AST*>::const_iterator a = tup.begin();
+ for (++a; a != tup.end(); ++a)
params[np++] = (*a)->compile(cenv);
switch (params.size()) {
@@ -689,14 +711,10 @@ ASTPrimitive::compile(CEnv& cenv)
int
main()
{
- Module* module = new Module("interactive");
- ExecutionEngine* engine = ExecutionEngine::create(module);
- CEnv cenv(module, engine->getTargetData());
-
PEnv penv;
+ penv.reg(penv.sym("fn"), PEnv::Parser(parseFn, 0));
penv.reg(penv.sym("if"), PEnv::Parser(parseIf, 0));
penv.reg(penv.sym("def"), PEnv::Parser(parseDef, 0));
- penv.reg(penv.sym("fn"), PEnv::Parser(parseFn, 0));
penv.reg(penv.sym("+"), PEnv::Parser(parsePrim, Instruction::Add));
penv.reg(penv.sym("-"), PEnv::Parser(parsePrim, Instruction::Sub));
penv.reg(penv.sym("*"), PEnv::Parser(parsePrim, Instruction::Mul));
@@ -705,8 +723,13 @@ main()
penv.reg(penv.sym("&"), PEnv::Parser(parsePrim, Instruction::And));
penv.reg(penv.sym("|"), PEnv::Parser(parsePrim, Instruction::Or));
penv.reg(penv.sym("^"), PEnv::Parser(parsePrim, Instruction::Xor));
- cenv.def(penv.sym("true"), new ASTLiteral<bool>(true));
- cenv.def(penv.sym("false"), new ASTLiteral<bool>(false));
+
+ Module* module = new Module("repl");
+ ExecutionEngine* engine = ExecutionEngine::create(module);
+ CEnv cenv(penv, module, engine->getTargetData());
+
+ cenv.code.def(penv.sym("true"), new ASTLiteral<bool>(true));
+ cenv.code.def(penv.sym("false"), new ASTLiteral<bool>(false));
while (1) {
std::cout << "(=>) ";
@@ -716,21 +739,47 @@ main()
break;
try {
- AST* ast = parseExpression(penv, exp);
- lambdaLift(cenv, ast);
- ASTPrototype* proto = new ASTPrototype();
- ASTFunction* func = new ASTFunction(proto, ast);
- Function* code = func->compile(cenv, cenv.gensym("_repl"));
- void* fp = engine->getPointerToFunction(code);
- code->dump();
- double (*f)() = (double (*)())fp;
- std::cout << f() << " :: ";
- func->body->type(cenv)->print(std::cout);
- std::cout << endl;
+ AST* body = parseExpression(penv, exp);
+ ASTTuple* prot = new ASTTuple();
+ AType* bodyT = body->type(cenv);
+
+ if (!bodyT) throw TypeError("REPL call to untyped body");
+ if (bodyT->var) throw TypeError("REPL call to variable typed body");
+
+ body->lift(cenv);
+
+ if (bodyT->ctype) {
+ // Create anonymous function to insert code into.
+ Function* f = compileFunction(cenv, cenv.gensym("_repl"), *prot, bodyT->ctype);
+ BasicBlock* bb = BasicBlock::Create("entry", f);
+ cenv.builder.SetInsertPoint(bb);
+
+ try {
+ Value* retVal = body->compile(cenv);
+ cenv.builder.CreateRet(retVal); // Finish function
+ verifyFunction(*f); // Validate generated code
+ cenv.fpm.run(*f); // Optimize function
+ } catch (SyntaxError e) {
+ f->eraseFromParent(); // Error reading body, remove function
+ throw e;
+ }
+
+ void* fp = engine->getPointerToFunction(f);
+ double (*cfunc)() = (double (*)())fp;
+ std::cout << cfunc();
+
+ } else {
+ Value* val = body->compile(cenv);
+ std::cout << val;
+ }
+ std::cout << " :: " << body->type(cenv)->str(cenv) << endl;
+
} catch (SyntaxError e) {
std::cerr << "Syntax error: " << e.what() << endl;
} catch (TypeError e) {
std::cerr << "Type error: " << e.what() << endl;
+ } catch (CompileError e) {
+ std::cerr << "Compile error: " << e.what() << endl;
}
}