aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDavid Robillard <d@drobilla.net>2009-03-06 22:12:36 +0000
committerDavid Robillard <d@drobilla.net>2009-03-06 22:12:36 +0000
commitecef8f697c66e15b85beb934d2b617b915a97aab (patch)
tree576af3f3063f07c08ecbbd1cfdb9d8034cc5bc2b
parent382d3051052fd20ab55f40beb7664bfb3f0379a1 (diff)
downloadresp-ecef8f697c66e15b85beb934d2b617b915a97aab.tar.gz
resp-ecef8f697c66e15b85beb934d2b617b915a97aab.tar.bz2
resp-ecef8f697c66e15b85beb934d2b617b915a97aab.zip
Cleanup and de-llvm-ify primitive stuff.
Fix type inference (only treat actual type expressions as type expressions). git-svn-id: http://svn.drobilla.net/resp/tuplr@64 ad02d1e2-f140-0410-9f75-f8b11f17cedd
-rw-r--r--llvm.cpp121
-rw-r--r--tuplr.cpp43
-rw-r--r--tuplr.hpp27
-rw-r--r--typing.cpp50
4 files changed, 139 insertions, 102 deletions
diff --git a/llvm.cpp b/llvm.cpp
index 2e53b7f..03c49f9 100644
--- a/llvm.cpp
+++ b/llvm.cpp
@@ -46,12 +46,6 @@ struct CEngine {
IRBuilder<> builder;
};
-struct CArg {
- CArg(int o=0, int a=0) : op(o), arg(a) {}
- int op;
- int arg;
-};
-
/***************************************************************************
* Typing *
@@ -72,29 +66,6 @@ AType::ctype()
}
}
-#define OP_IS_A(o, t) ((o) >= t ## Begin && (o) < t ## End)
-
-void
-ASTPrimitive::constrain(TEnv& tenv) const
-{
- FOREACH(const_iterator, p, *this)
- (*p)->constrain(tenv);
- if (OP_IS_A(arg->op, Instruction::BinaryOps)) {
- if (size() <= 2) throw Error((format("`%1%' requires at least 2 arguments")
- % at(0)->str()).str(), exp.loc);
- AType* tvar = tenv.type(this);
- for (size_t i = 1; i < size(); ++i)
- tenv.constrain(at(i), tvar);
- } else if (arg->op == Instruction::ICmp) {
- if (size() != 3) throw Error((format("`%1%' requires exactly 2 arguments")
- % at(0)->str()).str(), exp.loc);
- tenv.constrain(at(1), tenv.type(at(2)));
- tenv.constrain(this, tenv.named("Bool"));
- } else {
- throw Error((format("unknown primitive `%1%'") % at(0)->str()).str(), exp.loc);
- }
-}
-
/***************************************************************************
* Code Generation *
@@ -370,35 +341,43 @@ ASTIf::compile(CEnv& cenv)
CValue
ASTPrimitive::compile(CEnv& cenv)
{
- Value* a = LLVal(cenv.compile(at(1)));
- Value* b = LLVal(cenv.compile(at(2)));
-
- if (OP_IS_A(arg->op, Instruction::BinaryOps)) {
- const Instruction::BinaryOps bo = (Instruction::BinaryOps)arg->op;
- if (size() == 2)
- return cenv.compile(at(1));
- Value* val = cenv.engine.builder.CreateBinOp(bo, a, b);
+ Value* a = LLVal(cenv.compile(at(1)));
+ Value* b = LLVal(cenv.compile(at(2)));
+ bool isInt = cenv.tenv.type(at(1))->str() == "Int";
+ const string n = dynamic_cast<ASTSymbol*>(at(0))->str();
+
+ // Binary arithmetic operations
+ Instruction::BinaryOps op = (Instruction::BinaryOps)0;
+ if (n == "+") op = Instruction::Add;
+ if (n == "-") op = Instruction::Sub;
+ if (n == "*") op = Instruction::Mul;
+ if (n == "&") op = Instruction::And;
+ if (n == "|") op = Instruction::Or;
+ if (n == "^") op = Instruction::Xor;
+ if (n == "/") op = isInt ? Instruction::SDiv : Instruction::FDiv;
+ if (n == "%") op = isInt ? Instruction::SRem : Instruction::FRem;
+ if (op != 0) {
+ Value* val = cenv.engine.builder.CreateBinOp(op, a, b);
for (size_t i = 3; i < size(); ++i)
- val = cenv.engine.builder.CreateBinOp(bo, val, LLVal(cenv.compile(at(i))));
+ val = cenv.engine.builder.CreateBinOp(op, val, LLVal(cenv.compile(at(i))));
return val;
- } else if (arg->op == Instruction::ICmp) {
- bool isInt = cenv.tenv.type(at(1))->str() == "Int";
- if (isInt) {
- return cenv.engine.builder.CreateICmp((CmpInst::Predicate)arg->arg, a, b);
- } else {
- // Translate to floating point operation
- switch (arg->arg) {
- case CmpInst::ICMP_EQ: arg->arg = CmpInst::FCMP_OEQ; break;
- case CmpInst::ICMP_NE: arg->arg = CmpInst::FCMP_ONE; break;
- case CmpInst::ICMP_SGT: arg->arg = CmpInst::FCMP_OGT; break;
- case CmpInst::ICMP_SGE: arg->arg = CmpInst::FCMP_OGE; break;
- case CmpInst::ICMP_SLT: arg->arg = CmpInst::FCMP_OLT; break;
- case CmpInst::ICMP_SLE: arg->arg = CmpInst::FCMP_OLE; break;
- default: throw Error("Unknown primitive", exp.loc);
- }
- return cenv.engine.builder.CreateFCmp((CmpInst::Predicate)arg->arg, a, b);
- }
}
+
+ // Comparison operations
+ CmpInst::Predicate pred = (CmpInst::Predicate)0;
+ if (n == "=") pred = isInt ? CmpInst::ICMP_EQ : CmpInst::FCMP_OEQ;
+ if (n == "!=") pred = isInt ? CmpInst::ICMP_NE : CmpInst::FCMP_ONE;
+ if (n == ">") pred = isInt ? CmpInst::ICMP_SGT : CmpInst::FCMP_OGT;
+ if (n == ">=") pred = isInt ? CmpInst::ICMP_SGE : CmpInst::FCMP_OGE;
+ if (n == "<") pred = isInt ? CmpInst::ICMP_SLT : CmpInst::FCMP_OLT;
+ if (n == "<=") pred = isInt ? CmpInst::ICMP_SLE : CmpInst::FCMP_OLE;
+ if (pred != 0) {
+ if (isInt)
+ return cenv.engine.builder.CreateICmp(pred, a, b);
+ else
+ return cenv.engine.builder.CreateFCmp(pred, a, b);
+ }
+
throw Error("Unknown primitive", exp.loc);
}
@@ -496,38 +475,8 @@ ASTCdrCall::compile(CEnv& cenv)
***************************************************************************/
void
-initLang(PEnv& penv, TEnv& tenv)
+initTypes(PEnv& penv, TEnv& tenv)
{
- penv.reg(true, "fn", PEnv::Handler(parseFn));
- penv.reg(true, "if", PEnv::Handler(parseCall<ASTIf>));
- penv.reg(true, "def", PEnv::Handler(parseCall<ASTDefinition>));
- penv.reg(true, "cons", PEnv::Handler(parseCall<ASTConsCall>));
- penv.reg(true, "car", PEnv::Handler(parseCall<ASTCarCall>));
- penv.reg(true, "cdr", PEnv::Handler(parseCall<ASTCdrCall>));
-
- bool trueVal = true;
- bool falseVal = false;
- penv.reg(false, "true", PEnv::Handler(parseLiteral<bool>, (CArg*)&trueVal));
- penv.reg(false, "false", PEnv::Handler(parseLiteral<bool>, (CArg*)&falseVal));
-
- map<string, CArg>* prims = new map<string, CArg>();
- prims->insert(make_pair("+", CArg(Instruction::Add)));
- prims->insert(make_pair("-", CArg(Instruction::Sub)));
- prims->insert(make_pair("*", CArg(Instruction::Mul)));
- prims->insert(make_pair("/", CArg(Instruction::FDiv)));
- prims->insert(make_pair("%", CArg(Instruction::FRem)));
- prims->insert(make_pair("&", CArg(Instruction::And)));
- prims->insert(make_pair("|", CArg(Instruction::Or)));
- prims->insert(make_pair("^", CArg(Instruction::Xor)));
- prims->insert(make_pair("=", CArg(Instruction::ICmp, CmpInst::ICMP_EQ)));
- prims->insert(make_pair("!=", CArg(Instruction::ICmp, CmpInst::ICMP_NE)));
- prims->insert(make_pair(">", CArg(Instruction::ICmp, CmpInst::ICMP_SGT)));
- prims->insert(make_pair(">=", CArg(Instruction::ICmp, CmpInst::ICMP_SGE)));
- prims->insert(make_pair("<", CArg(Instruction::ICmp, CmpInst::ICMP_SLT)));
- prims->insert(make_pair("<=", CArg(Instruction::ICmp, CmpInst::ICMP_SLE)));
- for (map<string,CArg>::iterator p = prims->begin(); p != prims->end(); ++p)
- penv.reg(true, p->first, PEnv::Handler(parseCall<ASTPrimitive>, &p->second));
-
tenv.def(penv.sym("Bool"), new AType(penv.sym("Bool"), Type::Int1Ty));
tenv.def(penv.sym("Int"), new AType(penv.sym("Int"), Type::Int32Ty));
tenv.def(penv.sym("Float"), new AType(penv.sym("Float"), Type::FloatTy));
diff --git a/tuplr.cpp b/tuplr.cpp
index 8e18d5f..580416c 100644
--- a/tuplr.cpp
+++ b/tuplr.cpp
@@ -31,6 +31,7 @@ Funcs ASTConsCall::funcs;
std::ostream& err = std::cerr;
std::ostream& out = std::cout;
+
/***************************************************************************
* S-Expression Lexer :: text -> S-Expressions (SExp) *
***************************************************************************/
@@ -93,6 +94,45 @@ readExpression(Cursor& cur, std::istream& in)
/***************************************************************************
+ * Standard Definitions *
+ ***************************************************************************/
+
+void
+initLang(PEnv& penv, TEnv& tenv)
+{
+ // Literals
+ static bool trueVal = true;
+ static bool falseVal = false;
+ penv.reg(false, "#t", PEnv::Handler(parseLiteral<bool>, &trueVal));
+ penv.reg(false, "#f", PEnv::Handler(parseLiteral<bool>, &falseVal));
+
+ // Special forms
+ penv.reg(true, "fn", PEnv::Handler(parseFn));
+ penv.reg(true, "if", PEnv::Handler(parseCall<ASTIf>));
+ penv.reg(true, "def", PEnv::Handler(parseCall<ASTDefinition>));
+ penv.reg(true, "cons", PEnv::Handler(parseCall<ASTConsCall>));
+ penv.reg(true, "car", PEnv::Handler(parseCall<ASTCarCall>));
+ penv.reg(true, "cdr", PEnv::Handler(parseCall<ASTCdrCall>));
+
+ // Numeric primitives
+ penv.reg(true, "+", PEnv::Handler(parseCall<ASTPrimitive>));
+ penv.reg(true, "-", PEnv::Handler(parseCall<ASTPrimitive>));
+ penv.reg(true, "*", PEnv::Handler(parseCall<ASTPrimitive>));
+ penv.reg(true, "/", PEnv::Handler(parseCall<ASTPrimitive>));
+ penv.reg(true, "%", PEnv::Handler(parseCall<ASTPrimitive>));
+ penv.reg(true, "&", PEnv::Handler(parseCall<ASTPrimitive>));
+ penv.reg(true, "|", PEnv::Handler(parseCall<ASTPrimitive>));
+ penv.reg(true, "^", PEnv::Handler(parseCall<ASTPrimitive>));
+ penv.reg(true, "=", PEnv::Handler(parseCall<ASTPrimitive>));
+ penv.reg(true, "!=", PEnv::Handler(parseCall<ASTPrimitive>));
+ penv.reg(true, ">", PEnv::Handler(parseCall<ASTPrimitive>));
+ penv.reg(true, ">=", PEnv::Handler(parseCall<ASTPrimitive>));
+ penv.reg(true, "<", PEnv::Handler(parseCall<ASTPrimitive>));
+ penv.reg(true, "<=", PEnv::Handler(parseCall<ASTPrimitive>));
+}
+
+
+/***************************************************************************
* EVAL/REPL/MAIN *
***************************************************************************/
@@ -115,6 +155,7 @@ main(int argc, char** argv)
{
PEnv penv;
TEnv tenv(penv);
+ initTypes(penv, tenv);
initLang(penv, tenv);
CEnv* cenv = newCenv(penv, tenv);
@@ -154,7 +195,7 @@ main(int argc, char** argv)
is.close();
}
- if (files.empty() || args.find("-r") != args.end())
+ if (args.find("-r") != args.end() || (files.empty() && args.find("-e") == args.end()))
ret = repl(*cenv);
a = args.find("-o");
diff --git a/tuplr.hpp b/tuplr.hpp
index 7c28112..0bfb5c1 100644
--- a/tuplr.hpp
+++ b/tuplr.hpp
@@ -31,7 +31,6 @@ typedef const void* CType; ///< Compiled type (opaque)
typedef void* CFunction; ///< Compiled function (opaque)
struct CEngine; ///< Backend data (opaque)
-struct CArg; ///< Parser function argument (opaque)
#define FOREACH(IT, i, c) for (IT i = (c).begin(); i != (c).end(); ++i)
@@ -247,7 +246,7 @@ struct ASTCall : public ASTTuple {
/// Definition special form, e.g. "(def x 2)"
struct ASTDefinition : public ASTCall {
- ASTDefinition(const SExp& e, const ASTTuple& t, CArg* ca=0) : ASTCall(e, t) {}
+ ASTDefinition(const SExp& e, const ASTTuple& t) : ASTCall(e, t) {}
void constrain(TEnv& tenv) const;
void lift(CEnv& cenv);
CValue compile(CEnv& cenv);
@@ -255,22 +254,21 @@ struct ASTDefinition : public ASTCall {
/// Conditional special form, e.g. "(if cond thenexp elseexp)"
struct ASTIf : public ASTCall {
- ASTIf(const SExp& e, const ASTTuple& t, CArg* ca=0) : ASTCall(e, t) {}
+ ASTIf(const SExp& e, const ASTTuple& t) : ASTCall(e, t) {}
void constrain(TEnv& tenv) const;
CValue compile(CEnv& cenv);
};
/// Primitive (builtin arithmetic function), e.g. "(+ 2 3)"
struct ASTPrimitive : public ASTCall {
- ASTPrimitive(const SExp& e, const ASTTuple& t, CArg* ca=0) : ASTCall(e, t), arg(ca) {}
+ ASTPrimitive(const SExp& e, const ASTTuple& t) : ASTCall(e, t) {}
void constrain(TEnv& tenv) const;
CValue compile(CEnv& cenv);
- CArg* arg;
};
/// Cons special form, e.g. "(cons 1 2)"
struct ASTConsCall : public ASTCall {
- ASTConsCall(const SExp& e, const ASTTuple& t, CArg* ca=0) : ASTCall(e, t) {}
+ ASTConsCall(const SExp& e, const ASTTuple& t) : ASTCall(e, t) {}
AType* functionType(CEnv& cenv);
void constrain(TEnv& tenv) const;
void lift(CEnv& cenv);
@@ -280,14 +278,14 @@ struct ASTConsCall : public ASTCall {
/// Car special form, e.g. "(car p)"
struct ASTCarCall : public ASTCall {
- ASTCarCall(const SExp& e, const ASTTuple& t, CArg* ca=0) : ASTCall(e, t) {}
+ ASTCarCall(const SExp& e, const ASTTuple& t) : ASTCall(e, t) {}
void constrain(TEnv& tenv) const;
CValue compile(CEnv& cenv);
};
/// Cdr special form, e.g. "(cdr p)"
struct ASTCdrCall : public ASTCall {
- ASTCdrCall(const SExp& e, const ASTTuple& t, CArg* ca=0) : ASTCall(e, t) {}
+ ASTCdrCall(const SExp& e, const ASTTuple& t) : ASTCall(e, t) {}
void constrain(TEnv& tenv) const;
CValue compile(CEnv& cenv);
};
@@ -299,8 +297,8 @@ struct ASTCdrCall : public ASTCall {
// Parse Time Environment (symbol table)
struct PEnv : private map<const string, ASTSymbol*> {
- typedef AST* (*PF)(PEnv&, const SExp&, CArg*); // Parse Function
- struct Handler { Handler(PF f, CArg* a=0) : func(f), arg(a) {} PF func; CArg* arg; };
+ typedef AST* (*PF)(PEnv&, const SExp&, void*); // Parse Function
+ struct Handler { Handler(PF f, void* a=0) : func(f), arg(a) {} PF func; void* arg; };
map<const string, Handler> aHandlers; ///< Atom parse functions
map<const string, Handler> lHandlers; ///< List parse functions
void reg(bool list, const string& s, const Handler& h) {
@@ -358,20 +356,20 @@ parseExpression(PEnv& penv, const SExp& exp)
template<typename C>
inline AST*
-parseCall(PEnv& penv, const SExp& exp, CArg* arg)
+parseCall(PEnv& penv, const SExp& exp, void* arg)
{
- return new C(exp, pmap(penv, exp.list), arg);
+ return new C(exp, pmap(penv, exp.list));
}
template<typename T>
inline AST*
-parseLiteral(PEnv& penv, const SExp& exp, CArg* arg)
+parseLiteral(PEnv& penv, const SExp& exp, void* arg)
{
return new ASTLiteral<T>(*reinterpret_cast<T*>(arg));
}
inline AST*
-parseFn(PEnv& penv, const SExp& exp, CArg* arg)
+parseFn(PEnv& penv, const SExp& exp, void* arg)
{
SExp::List::const_iterator a = exp.list.begin(); ++a;
return new ASTClosure(
@@ -479,6 +477,7 @@ private:
* EVAL/REPL/MAIN *
***************************************************************************/
+void initTypes(PEnv& penv, TEnv& tenv);
void initLang(PEnv& penv, TEnv& tenv);
CEnv* newCenv(PEnv& penv, TEnv& tenv);
int eval(CEnv& cenv, const string& name, istream& is);
diff --git a/typing.cpp b/typing.cpp
index 9dc94f6..071e582 100644
--- a/typing.cpp
+++ b/typing.cpp
@@ -83,6 +83,54 @@ ASTIf::constrain(TEnv& tenv) const
}
void
+ASTPrimitive::constrain(TEnv& tenv) const
+{
+ const string n = dynamic_cast<ASTSymbol*>(at(0))->str();
+ enum { ARITHMETIC, BINARY, BITWISE, COMPARISON } type;
+ if (n == "+" || n == "-" || n == "*" || n == "/")
+ type = ARITHMETIC;
+ else if (n == "%")
+ type = BINARY;
+ else if (n == "&" || n == "|" || n == "^")
+ type = BITWISE;
+ else if (n == "=" || n == "!=" || n == ">" || n == ">=" || n == "<" || n == "<=")
+ type = COMPARISON;
+ else
+ throw Error((format("unknown primitive `%1%'") % n).str(), exp.loc);
+
+ FOREACH(const_iterator, p, *this)
+ (*p)->constrain(tenv);
+
+ switch (type) {
+ case ARITHMETIC:
+ if (size() < 3)
+ throw Error((format("`%1%' requires at least 2 arguments") % n).str(), exp.loc);
+ for (size_t i = 1; i < size(); ++i)
+ tenv.constrain(this, tenv.type(at(i)));
+ break;
+ case BINARY:
+ if (size() != 3)
+ throw Error((format("`%1%' requires exactly 2 arguments") % n).str(), exp.loc);
+ tenv.constrain(this, tenv.type(at(1)));
+ tenv.constrain(this, tenv.type(at(2)));
+ break;
+ case BITWISE:
+ if (size() != 3)
+ throw Error((format("`%1%' requires exactly 2 arguments") % n).str(), exp.loc);
+ tenv.constrain(this, tenv.named("Bool"));
+ tenv.constrain(at(1), tenv.named("Bool"));
+ tenv.constrain(at(2), tenv.named("Bool"));
+ break;
+ case COMPARISON:
+ if (size() != 3)
+ throw Error((format("`%1%' requires exactly 2 arguments") % n).str(), exp.loc);
+ tenv.constrain(this, tenv.named("Bool"));
+ tenv.constrain(at(1), tenv.type(at(2)));
+ break;
+ }
+}
+
+void
ASTConsCall::constrain(TEnv& tenv) const
{
AType* t = new AType(ASTTuple(tenv.penv.sym("Pair"), 0));
@@ -174,7 +222,7 @@ TEnv::unify(const Constraints& constraints) // TAPL 22.4
} else if (t->var() && !s->contains(t)) {
substConstraints(cp, t, s);
return compose(unify(cp), TSubst(t, s));
- } else if (s->size() == t->size()) {
+ } else if (s->kind == AType::EXPR && s->kind == t->kind && s->size() == t->size()) {
for (size_t i = 0; i < s->size(); ++i) {
AType* si = dynamic_cast<AType*>(s->at(i));
AType* ti = dynamic_cast<AType*>(t->at(i));