aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDavid Robillard <d@drobilla.net>2009-01-24 20:46:08 +0000
committerDavid Robillard <d@drobilla.net>2009-01-24 20:46:08 +0000
commit9f6217e6d6b4ab24ddaa10b52c3b1ce6af2a8746 (patch)
treef75b45be2ce0e8de8aa9387db69d9d19c3339ff6
parent0a88fdb91fcf631ecae7f8b9a55970822536d9ef (diff)
downloadresp-9f6217e6d6b4ab24ddaa10b52c3b1ce6af2a8746.tar.gz
resp-9f6217e6d6b4ab24ddaa10b52c3b1ce6af2a8746.tar.bz2
resp-9f6217e6d6b4ab24ddaa10b52c3b1ce6af2a8746.zip
Half decent extensible parser design.
git-svn-id: http://svn.drobilla.net/resp/llvm-lisp@10 ad02d1e2-f140-0410-9f75-f8b11f17cedd
-rwxr-xr-xbuild.sh2
-rw-r--r--ll.cpp431
2 files changed, 246 insertions, 187 deletions
diff --git a/build.sh b/build.sh
index d983233..bfa963b 100755
--- a/build.sh
+++ b/build.sh
@@ -1,5 +1,5 @@
#!/bin/sh
CXXFLAGS="-O0 -g -Wall -Wextra -Wno-unused-parameter"
-g++ $CXXFLAGS ll.cpp `llvm-config --cppflags --ldflags --libs core jit native` -lm -O3 -o ll
+g++ $CXXFLAGS ll.cpp `llvm-config --cppflags --ldflags --libs core jit native` -lm -o ll
diff --git a/ll.cpp b/ll.cpp
index 1e48295..0e87a2c 100644
--- a/ll.cpp
+++ b/ll.cpp
@@ -48,7 +48,7 @@ struct Error : public std::exception {
/***************************************************************************
- * S-Expression Lexer - Read text and output nested lists of strings *
+ * S-Expression Lexer :: text -> S-Expressions (SExp) *
***************************************************************************/
struct SyntaxError : public Error { SyntaxError(const char* m) : Error(m) {} };
@@ -117,6 +117,7 @@ readExpression(std::istream& in)
case 1: return stk.top(); break;
default: throw SyntaxError("Missing ')'");
}
+ return SExp();
}
@@ -133,17 +134,30 @@ struct CEnv; ///< Compile Time Environment
struct AST {
virtual ~AST() {}
virtual const Type* type(CEnv& cenv) const = 0;
- virtual Value* compile(CEnv& cenv) = 0;
+ virtual Value* compile(CEnv& cenv) = 0;
};
-/// Numeric literal, e.g. "1.0"
-struct ASTNumber : public AST {
- ASTNumber(double v) : val(v) {}
- virtual const Type* type(CEnv& cenv) const { return Type::DoubleTy; }
- virtual Value* compile(CEnv& cenv);
- const double val;
+/// Literal
+template<typename VT>
+struct ASTLiteral : public AST {
+ ASTLiteral(VT v) : val(v) {}
+ const Type* type(CEnv& cenv) const;
+ Value* compile(CEnv& cenv);
+ const VT val;
};
+#define LITERAL(CT, VT, COMPILED) \
+template<> const Type* \
+ASTLiteral<CT>::type(CEnv& cenv) const { return VT; } \
+ \
+template<> Value* \
+ASTLiteral<CT>::compile(CEnv& cenv) { return (COMPILED); }
+
+/// Literal template specialisations
+LITERAL(int32_t, Type::Int32Ty, ConstantInt::get(type(cenv), val, true));
+LITERAL(float, Type::FloatTy, ConstantFP::get(type(cenv), val));
+LITERAL(bool, Type::Int1Ty, ConstantInt::get(type(cenv), val, false));
+
/// Symbol, e.g. "a"
struct ASTSymbol : public AST {
ASTSymbol(const string& n) : name(n) {}
@@ -154,60 +168,60 @@ struct ASTSymbol : public AST {
/// Function call/application, e.g. "(func arg1 arg2)"
struct ASTCall : public AST {
- ASTCall(AST* f, vector<AST*>& a) : func(f), args(a) {}
+ ASTCall(const vector<AST*>& c) : code(c) {}
virtual const Type* type(CEnv& cenv) const {
+ AST* func = code[0];
const FunctionType* ftype = dynamic_cast<const FunctionType*>(func->type(cenv));
if (!ftype) throw TypeError(string("Call to non-function type :: ")
.append(func->type(cenv)->getDescription()).c_str());
return ftype->getReturnType();
}
virtual Value* compile(CEnv& cenv);
- AST* const func;
- const vector<AST*> args;
+ const vector<AST*> code;
};
/// Definition special form, e.g. "(def x 2)" or "(def (next y) (+ y 1))"
struct ASTDefinition : public ASTCall {
- ASTDefinition(vector<AST*> a) : ASTCall(NULL, a) {}
- virtual const Type* type(CEnv& cenv) const { return args[1]->type(cenv); }
+ ASTDefinition(const vector<AST*>& c) : ASTCall(c) {}
+ virtual const Type* type(CEnv& cenv) const { return code[2]->type(cenv); }
virtual Value* compile(CEnv& cenv);
};
/// Conditional special form, e.g. "(if cond thenexp elseexp)"
struct ASTIf : public ASTCall {
- ASTIf(vector<AST*>& a) : ASTCall(NULL, a) {}
+ ASTIf(const vector<AST*>& c) : ASTCall(c) {}
virtual const Type* type(CEnv& cenv) const {
- const Type* ctype = args[0]->type(cenv);
- const Type* ttype = args[1]->type(cenv);
- const Type* etype = args[2]->type(cenv);
- if (ctype != Type::DoubleTy) throw TypeError("If condition is not a number");
- if (ttype != etype) throw TypeError("If branches have different types");
- return ttype;
+ const Type* cT = code[1]->type(cenv);
+ const Type* tT = code[2]->type(cenv);
+ const Type* eT = code[3]->type(cenv);
+ if (cT != Type::Int1Ty) throw TypeError("If condition is not a boolean");
+ if (tT != eT) throw TypeError("If branches have different types");
+ return tT;
}
virtual Value* compile(CEnv& cenv);
};
/// Primitive (builtin arithmetic function), e.g. "(+ 2 3)"
struct ASTPrimitive : public ASTCall {
- ASTPrimitive(char n, vector<AST*>& a) : ASTCall(NULL, a), name(n) {}
- virtual const Type* type(CEnv& cenv) const { return Type::DoubleTy; }
+ ASTPrimitive(const vector<AST*>& c, Instruction::BinaryOps o) : ASTCall(c), op(o) {}
+ virtual const Type* type(CEnv& cenv) const { return Type::FloatTy; }
virtual Value* compile(CEnv& cenv);
- const char name;
+ Instruction::BinaryOps op;
};
/// Function prototype (actual LLVM IR function prototype)
struct ASTPrototype {
- ASTPrototype(vector<ASTSymbol*> p=vector<ASTSymbol*>()) : params(p) {}
+ ASTPrototype(vector<AST*> p=vector<AST*>()) : params(p) {}
vector<const Type*> argsType(CEnv& cenv) {
vector<const Type*> types;
- FOREACH(vector<ASTSymbol*>::const_iterator, p, params)
+ FOREACH(vector<AST*>::const_iterator, p, params)
types.push_back((*p)->type(cenv));
return types;
}
virtual const Type* type(CEnv& cenv) const { return NULL; }
Function* compile(CEnv& cenv, FunctionType* type, const string& name);
string name;
- const vector<ASTSymbol*> params;
+ vector<AST*> params;
};
/// Closure (first-class function with captured lexical bindings)
@@ -218,13 +232,13 @@ struct ASTClosure : public AST {
}
virtual Value* compile(CEnv& cenv);
virtual void lift(CEnv& cenv);
- ASTPrototype* const prot;
- AST* const body;
+ ASTPrototype* const prot;
+ AST* const body;
+ vector<const ASTSymbol*> bindings;
private:
Function* func;
};
-
/// Function definition (actual LLVM IR function)
struct ASTFunction {
ASTFunction(ASTPrototype* p, AST* b) : prot(p), body(b) {}
@@ -236,124 +250,97 @@ struct ASTFunction {
/***************************************************************************
- * Parser - Transform S-Expressions into AST nodes *
+ * Parser - S-Expressions (SExp) -> AST Nodes (AST) *
***************************************************************************/
-/// Parse Time Environment
-/*struct PEnv : public list< map<string, ASTSymbol*> > {
- PEnv() : list< map<string, ASTSymbol*> >(1) {}
- ASTSymbol* lookup(const string& name) const {
- FOREACH(const_iterator, i, *this) {
- map<string, ASTSymbol*>::const_iterator s = i->find(name);
- if (s != i->end())
- return s->second;
- }
- return NULL;
+typedef unsigned UD; // User Data passed to registered parse functions
+
+// Parse Time Environment (just a symbol table)
+struct PEnv : private map<const string, ASTSymbol*> {
+ typedef AST* (*PF)(PEnv&, const list<SExp>&, UD); // Parse Function
+ struct Parser { Parser(PF f, UD d) : pf(f), ud(d) {} PF pf; UD ud; };
+ map<const ASTSymbol*, Parser> parsers;
+ void reg(const ASTSymbol* s, const Parser& p) {
+ parsers.insert(make_pair(s, p));
}
-};*/
-struct PEnv : private map<string, ASTSymbol*> {
- ASTSymbol* sym(const string& name) {
- const_iterator i = find(name);
- if (i != end()) {
- return i->second;
- } else {
- ASTSymbol* sym = new ASTSymbol(name);
- insert(make_pair(name, sym));
- return sym;
- }
+ const Parser* parser(const ASTSymbol* s) const {
+ map<const ASTSymbol*, Parser>::const_iterator i = parsers.find(s);
+ return (i != parsers.end()) ? &i->second : NULL;
+ }
+ ASTSymbol* sym(const string& s) {
+ const const_iterator i = find(s);
+ return ((i != end())
+ ? i->second
+ : insert(make_pair(s, new ASTSymbol(s))).first->second);
}
};
-static AST* parseExpression(PEnv& penv, const SExp& exp);
-
-/// numberexpr ::= number
+/// The fundamental parser method
static AST*
-parseNumber(PEnv& penv, const SExp& exp)
+parseExpression(PEnv& penv, const SExp& exp)
{
- assert(exp.type == SExp::ATOM);
- return new ASTNumber(strtod(exp.atom.c_str(), NULL));
-}
+ if (exp.type == SExp::LIST) {
+ // Parse head of list
+ if (exp.list.empty()) throw SyntaxError("Call to empty list");
+ vector<AST*> code(exp.list.size());
+ code[0] = parseExpression(penv, exp.list.front());
+
+ // Dispatch to parse function if possible
+ ASTSymbol* sym = dynamic_cast<ASTSymbol*>(code[0]);
+ const PEnv::Parser* handler = sym ? penv.parser(sym) : NULL;
+ if (handler)
+ return handler->pf(penv, exp.list, handler->ud);
+
+ // Parse as a regular call
+ list<SExp>::const_iterator i = exp.list.begin(); ++i;
+ for (size_t n = 1; i != exp.list.end(); ++i)
+ code[n++] = parseExpression(penv, *i);
+ return new ASTCall(code);
-/// identifierexpr ::= identifier
-static ASTSymbol*
-parseSymbol(PEnv& penv, const SExp& exp)
-{
- if (exp.type != SExp::ATOM) throw SyntaxError("Expected symbol");
+ } else if (isdigit(exp.atom[0])) {
+ if (exp.atom.find('.') == string::npos)
+ return new ASTLiteral<int32_t>(strtol(exp.atom.c_str(), NULL, 10));
+ else
+ return new ASTLiteral<float>(strtod(exp.atom.c_str(), NULL));
+ }
return penv.sym(exp.atom);
}
-/// prototypeexpr ::= ([arg*])
-static ASTPrototype*
-parsePrototype(PEnv& penv, const SExp& exp)
-{
- vector<ASTSymbol*> params;
- FOREACH(list<SExp>::const_iterator, i, exp.list)
- if (i->type == SExp::ATOM)
- params.push_back(penv.sym(i->atom));
- else
- throw SyntaxError("Expected symbol");
-
- return new ASTPrototype(params);
-}
+// Special forms
-/// callexpr ::= (expression [...])
-static AST*
-parseCall(PEnv& penv, const SExp& exp)
+static vector<AST*>
+pmap(PEnv& penv, const list<SExp>& l)
{
- const list<SExp>& l = exp.list;
- if (l.empty())
- return NULL;
-
list<SExp>::const_iterator i = l.begin();
+ vector<AST*> code(l.size());
+ for (size_t n = 0; i != l.end(); ++i)
+ code[n++] = parseExpression(penv, *i);
+ return code;
+}
- if (i->type == SExp::LIST)
- throw SyntaxError("Call to list");
+static AST*
+parseIf(PEnv& penv, const list<SExp>& c, UD)
+ { return new ASTIf(pmap(penv, c)); }
- const string& name = i++->atom;
+static AST*
+parseDef(PEnv& penv, const list<SExp>& c, UD)
+ { return new ASTDefinition(pmap(penv, c)); }
- if (name == "fn") { // Lambda special form
- ASTPrototype* proto = parsePrototype(penv, *i++);
- AST* body = parseExpression(penv, *i++);
- return new ASTClosure(proto, body);
-
- //} else if (name == "foreign") { // Foreign special form
- // return parsePrototype(penv, *i++);
- }
-
- vector<AST*> args(l.size() - 1);
- for (size_t n = 0; i != l.end(); ++i, ++n)
- args[n] = parseExpression(penv, *i);
-
- if (name == "def") { // Define special form
- return new ASTDefinition(args);
+static AST*
+parsePrim(PEnv& penv, const list<SExp>& c, UD data)
+ { return new ASTPrimitive(pmap(penv, c), (Instruction::BinaryOps)data); }
- } else if (name == "if") { // If special form
- return new ASTIf(args);
-
- } else { // Generic application
- if (name.length() == 1) {
- switch (name[0]) {
- case '+': case '-': case '*': case '/':
- case '%': case '&': case '|': case '^':
- return new ASTPrimitive(name[0], args);
- }
- }
- return new ASTCall(parseExpression(penv, name), args);
- }
-}
+static ASTPrototype*
+parsePrototype(PEnv& penv, const SExp& e, UD)
+ { return new ASTPrototype(pmap(penv, e.list)); }
static AST*
-parseExpression(PEnv& penv, const SExp& exp)
+parseFn(PEnv& penv, const list<SExp>& c, UD)
{
- if (exp.type == SExp::LIST) {
- return parseCall(penv, exp);
- } else if (isalpha(exp.atom[0])) {
- return parseSymbol(penv, exp);
- } else if (isdigit(exp.atom[0])) {
- return parseNumber(penv, exp);
- } else {
- throw SyntaxError("Illegal atom");
- }
+ list<SExp>::const_iterator a = c.begin(); ++a;
+ return new ASTClosure(
+ parsePrototype(penv, *a++, 0),
+ parseExpression(penv, *a++));
}
@@ -361,6 +348,33 @@ parseExpression(PEnv& penv, const SExp& exp)
* Code Generation *
***************************************************************************/
+/// Generic Recursive Environment (stack of key:value dictionaries)
+template<typename K, typename V>
+struct Env : public list< map<K,V> > {
+ Env() : list< map<K, V> >(1) {}
+ void push() { this->push_front(map<K,V>()); }
+ void push(const map<K,V>& frame) { this->push_front(frame); }
+ map<K,V>& pop() {
+ map<K,V>& front = this->front();
+ this->pop_front();
+ return front;
+ }
+ void def(const K& k, const V& v) {
+ if (this->front().find(k) != this->front().end())
+ std::cerr << "WARNING: Redefinition: " << k << endl;
+ this->front()[k] = v;
+ }
+ V* ref(const K& name) {
+ typename Env::iterator i = this->begin();
+ for (; i != this->end(); ++i) {
+ typename map<K,V>::iterator s = i->find(name);
+ if (s != i->end())
+ return &s->second;
+ }
+ return 0;
+ }
+};
+
/// Compile-time environment
struct CEnv {
CEnv(Module* m, const TargetData* target)
@@ -381,15 +395,22 @@ struct CEnv {
string gensym(const char* base="_") {
ostringstream s; s << base << symID++; return s.str();
}
- typedef map<const ASTSymbol*, Value*> Vals;
- typedef map<const ASTSymbol*, const Type*> Types;
+ void def(const ASTSymbol* sym, AST* expr) {
+ types.def(sym, expr->type(*this));
+ code.def(sym, expr);
+ }
+ typedef Env<const ASTSymbol*, const Type*> Types;
+ typedef Env<const ASTSymbol*, AST*> Code;
+ typedef Env<const ASTSymbol*, Value*> Vals;
+
IRBuilder<> builder;
Module* module;
ExistingModuleProvider provider;
FunctionPassManager fpm;
size_t symID;
- Vals vals;
Types types;
+ Code code;
+ Vals vals;
};
static void
@@ -399,62 +420,71 @@ lambdaLift(CEnv& env, AST* ast)
lambdaLift(env, closure->body);
closure->lift(env);
} else if (ASTCall* call = dynamic_cast<ASTCall*>(ast)) {
- FOREACH(vector<AST*>::const_iterator, a, call->args)
+ FOREACH(vector<AST*>::const_iterator, a, call->code)
lambdaLift(env, *a);
}
}
-Value*
-ASTNumber::compile(CEnv& cenv)
-{
- return ConstantFP::get(APFloat(val));
-}
-
const Type*
ASTSymbol::type(CEnv& cenv) const
{
- CEnv::Types::const_iterator t = cenv.types.find(this);
- if (t != cenv.types.end()) {
- return t->second;
+ const Type** t = cenv.types.ref(this);
+ if (t) {
+ return *t;
} else {
//std::cerr << "WARNING: Untyped symbol: " << name << endl;
- return Type::DoubleTy;
+ return Type::FloatTy;
}
}
Value*
ASTSymbol::compile(CEnv& cenv)
{
- CEnv::Vals::const_iterator v = cenv.vals.find(this);
- if (v == cenv.vals.end())
- throw SyntaxError((string("Undefined symbol '") + name + "'").c_str());
- return v->second;
+ Value*const* v = cenv.vals.ref(this);
+ if (v)
+ return *v;
+
+ AST*const* c = cenv.code.ref(this);
+ if (c) {
+ Value* v = (*c)->compile(cenv);
+ cenv.vals.def(this, v);
+ return v;
+ }
+
+ throw SyntaxError((string("Undefined symbol '") + name + "'").c_str());
}
Value*
ASTDefinition::compile(CEnv& cenv)
{
- if (args.size() != 2) throw SyntaxError("\"def\" takes exactly 2 arguments");
- ASTSymbol* sym = dynamic_cast<ASTSymbol*>(args[0]);
+ if (code.size() != 3) throw SyntaxError("\"def\" takes exactly 2 arguments");
+ const ASTSymbol* sym = dynamic_cast<const ASTSymbol*>(code[1]);
if (!sym) throw SyntaxError("Definition name is not a symbol");
- CEnv::Vals::const_iterator v = cenv.vals.find(sym);
- if (v != cenv.vals.end()) throw SyntaxError("Symbol redefinition");
- Value* val = args[1]->compile(cenv);
- cenv.types[sym] = args[1]->type(cenv);
- cenv.vals[sym] = val;
+ Value* val = code[2]->compile(cenv);
+ cenv.types.def(sym, code[2]->type(cenv));
+ cenv.code.def(sym, code[2]);
+ cenv.vals.def(sym, val);
return val;
}
Value*
ASTCall::compile(CEnv& cenv)
{
+ AST* func = code[0];
+ AST** closure = cenv.code.ref((ASTSymbol*)func);
+ assert(closure);
+ ASTClosure* c = dynamic_cast<ASTClosure*>(*closure);
+ assert(c);
Function* f = dynamic_cast<Function*>(func->compile(cenv));
if (!f) throw SyntaxError("Call to non-function");
vector<Value*> params;
- for (size_t i = 0; i != args.size(); ++i)
- params.push_back(args[i]->compile(cenv));
+ for (size_t i = 1; i < code.size(); ++i)
+ params.push_back(code[i]->compile(cenv));
+
+ for (size_t i = 0; i < c->bindings.size(); ++i)
+ std::cout << "BINDING: " << c->bindings[i]->name << endl;
return cenv.builder.CreateCall(f, params.begin(), params.end(), "calltmp");
}
@@ -462,25 +492,20 @@ ASTCall::compile(CEnv& cenv)
Value*
ASTIf::compile(CEnv& cenv)
{
- Value* condV = args[0]->compile(cenv);
-
- // Convert condition to a bool by comparing equal to 0.0.
- condV = cenv.builder.CreateFCmpONE(
- condV, ConstantFP::get(APFloat(0.0)), "ifcond");
-
+ Value* condV = code[1]->compile(cenv);
Function* parent = cenv.builder.GetInsertBlock()->getParent();
// Create blocks for the then and else cases.
// Insert the 'then' block at the end of the function.
- BasicBlock* thenBB = BasicBlock::Create("then", parent);
- BasicBlock* elseBB = BasicBlock::Create("else");
+ BasicBlock* thenBB = BasicBlock::Create("then", parent);
+ BasicBlock* elseBB = BasicBlock::Create("else");
BasicBlock* mergeBB = BasicBlock::Create("ifcont");
cenv.builder.CreateCondBr(condV, thenBB, elseBB);
// Emit then value.
cenv.builder.SetInsertPoint(thenBB);
- Value* thenV = args[1]->compile(cenv);
+ Value* thenV = code[2]->compile(cenv);
cenv.builder.CreateBr(mergeBB);
// compile of 'Then' can change the current block, update thenBB
@@ -489,7 +514,7 @@ ASTIf::compile(CEnv& cenv)
// Emit else block.
parent->getBasicBlockList().push_back(elseBB);
cenv.builder.SetInsertPoint(elseBB);
- Value* elseV = args[2]->compile(cenv);
+ Value* elseV = code[3]->compile(cenv);
cenv.builder.CreateBr(mergeBB);
// compile of 'Else' can change the current block, update elseBB
@@ -509,6 +534,34 @@ void
ASTClosure::lift(CEnv& cenv)
{
assert(!func);
+
+ //set<const ASTSymbol*> unbound;
+ const ASTCall* call = dynamic_cast<const ASTCall*>(body);
+ if (call) {
+ std::cout << "LIFT CALL BODY\n";
+ }
+
+#if 0
+ Env<const ASTSymbol*, const AST*> paramsEnv;
+ for (vector<const ASTSymbol*>::const_iterator p = prot->params.begin();
+ p != prot->params.end(); ++p) {
+ //paramsEnv[*p] = NULL;
+ std::cout << "PARAM: " << (*p)->name << endl;
+ }
+#endif
+
+ cenv.code.push();
+
+ ASTSymbol* sym = dynamic_cast<ASTSymbol*>(body);
+ if (sym) {
+ AST** obj = cenv.code.ref(sym);
+ if (!obj) {
+ std::cout << "UNDEFINED SYMBOL BODY\n";
+ prot->params.push_back(sym);
+ bindings.push_back(sym);
+ }
+ }
+
// Write function declaration
Function* f = prot->compile(cenv,
FunctionType::get(body->type(cenv), prot->argsType(cenv), false),
@@ -518,9 +571,9 @@ ASTClosure::lift(CEnv& cenv)
// Bind argument values in CEnv
vector<Value*> args;
- vector<ASTSymbol*>::const_iterator p = prot->params.begin();
+ vector<AST*>::const_iterator p = prot->params.begin();
for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p)
- cenv.vals[*p] = &*a;
+ cenv.vals.def(dynamic_cast<ASTSymbol*>(*p), &*a);
// Write function body
try {
@@ -533,6 +586,8 @@ ASTClosure::lift(CEnv& cenv)
f->eraseFromParent(); // Error reading body, remove function
throw e;
}
+
+ cenv.code.pop();
}
Value*
@@ -565,11 +620,12 @@ ASTPrototype::compile(CEnv& cenv, FunctionType* FT, const std::string& n)
throw SyntaxError("Function redefined with mismatched arguments");
}
+ // Set argument names in generated code
Function::arg_iterator a = f->arg_begin();
for (size_t i = 0; i != params.size(); ++a, ++i) {
assert(params[i]);
- a->setName(params[i]->name); // Set name in generated code
- //cenv.vals[params[i]] = a; // Add to environment
+ ASTSymbol* sym = dynamic_cast<ASTSymbol*>(params[i]);
+ a->setName(sym ? sym->name : cenv.gensym("_a"));
}
return f;
@@ -605,22 +661,11 @@ ASTFunction::compile(CEnv& cenv, const string& name)
Value*
ASTPrimitive::compile(CEnv& cenv)
{
- Instruction::BinaryOps op;
- switch (name) {
- case '+': op = Instruction::Add; break;
- case '-': op = Instruction::Sub; break;
- case '*': op = Instruction::Mul; break;
- case '/': op = Instruction::FDiv; break;
- case '%': op = Instruction::FRem; break;
- case '&': op = Instruction::And; break;
- case '|': op = Instruction::Or; break;
- case '^': op = Instruction::Xor; break;
- default: throw SyntaxError("Unknown primitive");
- }
-
- vector<Value*> params;
- FOREACH(vector<AST*>::const_iterator, a, args)
- params.push_back((*a)->compile(cenv));
+ size_t np = 0;
+ vector<Value*> params(code.size() - 1);
+ vector<AST*>::const_iterator a = code.begin();
+ for (++a; a != code.end(); ++a)
+ params[np++] = (*a)->compile(cenv);
switch (params.size()) {
case 0:
@@ -644,9 +689,24 @@ ASTPrimitive::compile(CEnv& cenv)
static void
repl(CEnv& cenv, ExecutionEngine* engine)
{
+ // Set up our language
PEnv penv;
+ penv.reg(penv.sym("if"), PEnv::Parser(parseIf, 0));
+ penv.reg(penv.sym("def"), PEnv::Parser(parseDef, 0));
+ penv.reg(penv.sym("fn"), PEnv::Parser(parseFn, 0));
+ penv.reg(penv.sym("+"), PEnv::Parser(parsePrim, Instruction::Add));
+ penv.reg(penv.sym("-"), PEnv::Parser(parsePrim, Instruction::Sub));
+ penv.reg(penv.sym("*"), PEnv::Parser(parsePrim, Instruction::Mul));
+ penv.reg(penv.sym("/"), PEnv::Parser(parsePrim, Instruction::FDiv));
+ penv.reg(penv.sym("%"), PEnv::Parser(parsePrim, Instruction::FRem));
+ penv.reg(penv.sym("&"), PEnv::Parser(parsePrim, Instruction::And));
+ penv.reg(penv.sym("|"), PEnv::Parser(parsePrim, Instruction::Or));
+ penv.reg(penv.sym("^"), PEnv::Parser(parsePrim, Instruction::Xor));
+ cenv.def(penv.sym("true"), new ASTLiteral<bool>(true));
+ cenv.def(penv.sym("false"), new ASTLiteral<bool>(false));
+
while (1) {
- std::cout << "> ";
+ std::cout << "(=>) ";
std::cout.flush();
SExp exp = readExpression(std::cin);
if (exp.type == SExp::LIST && exp.list.empty())
@@ -654,12 +714,12 @@ repl(CEnv& cenv, ExecutionEngine* engine)
try {
AST* ast = parseExpression(penv, exp);
- //const Type* type = ast->type(cenv);
lambdaLift(cenv, ast);
ASTPrototype* proto = new ASTPrototype();
ASTFunction* func = new ASTFunction(proto, ast);
- Function* code = func->compile(cenv, cenv.gensym("repl"));
+ Function* code = func->compile(cenv, cenv.gensym("_repl"));
void* fp = engine->getPointerToFunction(code);
+ code->dump();
double (*f)() = (double (*)())fp;
std::cout << f() << " :: ";
func->body->type(cenv)->print(std::cout);
@@ -690,4 +750,3 @@ main()
module->dump();
return 0;
}
-