diff options
author | David Robillard <d@drobilla.net> | 2009-03-06 22:12:36 +0000 |
---|---|---|
committer | David Robillard <d@drobilla.net> | 2009-03-06 22:12:36 +0000 |
commit | ecef8f697c66e15b85beb934d2b617b915a97aab (patch) | |
tree | 576af3f3063f07c08ecbbd1cfdb9d8034cc5bc2b | |
parent | 382d3051052fd20ab55f40beb7664bfb3f0379a1 (diff) | |
download | resp-ecef8f697c66e15b85beb934d2b617b915a97aab.tar.gz resp-ecef8f697c66e15b85beb934d2b617b915a97aab.tar.bz2 resp-ecef8f697c66e15b85beb934d2b617b915a97aab.zip |
Cleanup and de-llvm-ify primitive stuff.
Fix type inference (only treat actual type expressions as type expressions).
git-svn-id: http://svn.drobilla.net/resp/tuplr@64 ad02d1e2-f140-0410-9f75-f8b11f17cedd
-rw-r--r-- | llvm.cpp | 121 | ||||
-rw-r--r-- | tuplr.cpp | 43 | ||||
-rw-r--r-- | tuplr.hpp | 27 | ||||
-rw-r--r-- | typing.cpp | 50 |
4 files changed, 139 insertions, 102 deletions
@@ -46,12 +46,6 @@ struct CEngine { IRBuilder<> builder; }; -struct CArg { - CArg(int o=0, int a=0) : op(o), arg(a) {} - int op; - int arg; -}; - /*************************************************************************** * Typing * @@ -72,29 +66,6 @@ AType::ctype() } } -#define OP_IS_A(o, t) ((o) >= t ## Begin && (o) < t ## End) - -void -ASTPrimitive::constrain(TEnv& tenv) const -{ - FOREACH(const_iterator, p, *this) - (*p)->constrain(tenv); - if (OP_IS_A(arg->op, Instruction::BinaryOps)) { - if (size() <= 2) throw Error((format("`%1%' requires at least 2 arguments") - % at(0)->str()).str(), exp.loc); - AType* tvar = tenv.type(this); - for (size_t i = 1; i < size(); ++i) - tenv.constrain(at(i), tvar); - } else if (arg->op == Instruction::ICmp) { - if (size() != 3) throw Error((format("`%1%' requires exactly 2 arguments") - % at(0)->str()).str(), exp.loc); - tenv.constrain(at(1), tenv.type(at(2))); - tenv.constrain(this, tenv.named("Bool")); - } else { - throw Error((format("unknown primitive `%1%'") % at(0)->str()).str(), exp.loc); - } -} - /*************************************************************************** * Code Generation * @@ -370,35 +341,43 @@ ASTIf::compile(CEnv& cenv) CValue ASTPrimitive::compile(CEnv& cenv) { - Value* a = LLVal(cenv.compile(at(1))); - Value* b = LLVal(cenv.compile(at(2))); - - if (OP_IS_A(arg->op, Instruction::BinaryOps)) { - const Instruction::BinaryOps bo = (Instruction::BinaryOps)arg->op; - if (size() == 2) - return cenv.compile(at(1)); - Value* val = cenv.engine.builder.CreateBinOp(bo, a, b); + Value* a = LLVal(cenv.compile(at(1))); + Value* b = LLVal(cenv.compile(at(2))); + bool isInt = cenv.tenv.type(at(1))->str() == "Int"; + const string n = dynamic_cast<ASTSymbol*>(at(0))->str(); + + // Binary arithmetic operations + Instruction::BinaryOps op = (Instruction::BinaryOps)0; + if (n == "+") op = Instruction::Add; + if (n == "-") op = Instruction::Sub; + if (n == "*") op = Instruction::Mul; + if (n == "&") op = Instruction::And; + if (n == "|") op = Instruction::Or; + if (n == "^") op = Instruction::Xor; + if (n == "/") op = isInt ? Instruction::SDiv : Instruction::FDiv; + if (n == "%") op = isInt ? Instruction::SRem : Instruction::FRem; + if (op != 0) { + Value* val = cenv.engine.builder.CreateBinOp(op, a, b); for (size_t i = 3; i < size(); ++i) - val = cenv.engine.builder.CreateBinOp(bo, val, LLVal(cenv.compile(at(i)))); + val = cenv.engine.builder.CreateBinOp(op, val, LLVal(cenv.compile(at(i)))); return val; - } else if (arg->op == Instruction::ICmp) { - bool isInt = cenv.tenv.type(at(1))->str() == "Int"; - if (isInt) { - return cenv.engine.builder.CreateICmp((CmpInst::Predicate)arg->arg, a, b); - } else { - // Translate to floating point operation - switch (arg->arg) { - case CmpInst::ICMP_EQ: arg->arg = CmpInst::FCMP_OEQ; break; - case CmpInst::ICMP_NE: arg->arg = CmpInst::FCMP_ONE; break; - case CmpInst::ICMP_SGT: arg->arg = CmpInst::FCMP_OGT; break; - case CmpInst::ICMP_SGE: arg->arg = CmpInst::FCMP_OGE; break; - case CmpInst::ICMP_SLT: arg->arg = CmpInst::FCMP_OLT; break; - case CmpInst::ICMP_SLE: arg->arg = CmpInst::FCMP_OLE; break; - default: throw Error("Unknown primitive", exp.loc); - } - return cenv.engine.builder.CreateFCmp((CmpInst::Predicate)arg->arg, a, b); - } } + + // Comparison operations + CmpInst::Predicate pred = (CmpInst::Predicate)0; + if (n == "=") pred = isInt ? CmpInst::ICMP_EQ : CmpInst::FCMP_OEQ; + if (n == "!=") pred = isInt ? CmpInst::ICMP_NE : CmpInst::FCMP_ONE; + if (n == ">") pred = isInt ? CmpInst::ICMP_SGT : CmpInst::FCMP_OGT; + if (n == ">=") pred = isInt ? CmpInst::ICMP_SGE : CmpInst::FCMP_OGE; + if (n == "<") pred = isInt ? CmpInst::ICMP_SLT : CmpInst::FCMP_OLT; + if (n == "<=") pred = isInt ? CmpInst::ICMP_SLE : CmpInst::FCMP_OLE; + if (pred != 0) { + if (isInt) + return cenv.engine.builder.CreateICmp(pred, a, b); + else + return cenv.engine.builder.CreateFCmp(pred, a, b); + } + throw Error("Unknown primitive", exp.loc); } @@ -496,38 +475,8 @@ ASTCdrCall::compile(CEnv& cenv) ***************************************************************************/ void -initLang(PEnv& penv, TEnv& tenv) +initTypes(PEnv& penv, TEnv& tenv) { - penv.reg(true, "fn", PEnv::Handler(parseFn)); - penv.reg(true, "if", PEnv::Handler(parseCall<ASTIf>)); - penv.reg(true, "def", PEnv::Handler(parseCall<ASTDefinition>)); - penv.reg(true, "cons", PEnv::Handler(parseCall<ASTConsCall>)); - penv.reg(true, "car", PEnv::Handler(parseCall<ASTCarCall>)); - penv.reg(true, "cdr", PEnv::Handler(parseCall<ASTCdrCall>)); - - bool trueVal = true; - bool falseVal = false; - penv.reg(false, "true", PEnv::Handler(parseLiteral<bool>, (CArg*)&trueVal)); - penv.reg(false, "false", PEnv::Handler(parseLiteral<bool>, (CArg*)&falseVal)); - - map<string, CArg>* prims = new map<string, CArg>(); - prims->insert(make_pair("+", CArg(Instruction::Add))); - prims->insert(make_pair("-", CArg(Instruction::Sub))); - prims->insert(make_pair("*", CArg(Instruction::Mul))); - prims->insert(make_pair("/", CArg(Instruction::FDiv))); - prims->insert(make_pair("%", CArg(Instruction::FRem))); - prims->insert(make_pair("&", CArg(Instruction::And))); - prims->insert(make_pair("|", CArg(Instruction::Or))); - prims->insert(make_pair("^", CArg(Instruction::Xor))); - prims->insert(make_pair("=", CArg(Instruction::ICmp, CmpInst::ICMP_EQ))); - prims->insert(make_pair("!=", CArg(Instruction::ICmp, CmpInst::ICMP_NE))); - prims->insert(make_pair(">", CArg(Instruction::ICmp, CmpInst::ICMP_SGT))); - prims->insert(make_pair(">=", CArg(Instruction::ICmp, CmpInst::ICMP_SGE))); - prims->insert(make_pair("<", CArg(Instruction::ICmp, CmpInst::ICMP_SLT))); - prims->insert(make_pair("<=", CArg(Instruction::ICmp, CmpInst::ICMP_SLE))); - for (map<string,CArg>::iterator p = prims->begin(); p != prims->end(); ++p) - penv.reg(true, p->first, PEnv::Handler(parseCall<ASTPrimitive>, &p->second)); - tenv.def(penv.sym("Bool"), new AType(penv.sym("Bool"), Type::Int1Ty)); tenv.def(penv.sym("Int"), new AType(penv.sym("Int"), Type::Int32Ty)); tenv.def(penv.sym("Float"), new AType(penv.sym("Float"), Type::FloatTy)); @@ -31,6 +31,7 @@ Funcs ASTConsCall::funcs; std::ostream& err = std::cerr; std::ostream& out = std::cout; + /*************************************************************************** * S-Expression Lexer :: text -> S-Expressions (SExp) * ***************************************************************************/ @@ -93,6 +94,45 @@ readExpression(Cursor& cur, std::istream& in) /*************************************************************************** + * Standard Definitions * + ***************************************************************************/ + +void +initLang(PEnv& penv, TEnv& tenv) +{ + // Literals + static bool trueVal = true; + static bool falseVal = false; + penv.reg(false, "#t", PEnv::Handler(parseLiteral<bool>, &trueVal)); + penv.reg(false, "#f", PEnv::Handler(parseLiteral<bool>, &falseVal)); + + // Special forms + penv.reg(true, "fn", PEnv::Handler(parseFn)); + penv.reg(true, "if", PEnv::Handler(parseCall<ASTIf>)); + penv.reg(true, "def", PEnv::Handler(parseCall<ASTDefinition>)); + penv.reg(true, "cons", PEnv::Handler(parseCall<ASTConsCall>)); + penv.reg(true, "car", PEnv::Handler(parseCall<ASTCarCall>)); + penv.reg(true, "cdr", PEnv::Handler(parseCall<ASTCdrCall>)); + + // Numeric primitives + penv.reg(true, "+", PEnv::Handler(parseCall<ASTPrimitive>)); + penv.reg(true, "-", PEnv::Handler(parseCall<ASTPrimitive>)); + penv.reg(true, "*", PEnv::Handler(parseCall<ASTPrimitive>)); + penv.reg(true, "/", PEnv::Handler(parseCall<ASTPrimitive>)); + penv.reg(true, "%", PEnv::Handler(parseCall<ASTPrimitive>)); + penv.reg(true, "&", PEnv::Handler(parseCall<ASTPrimitive>)); + penv.reg(true, "|", PEnv::Handler(parseCall<ASTPrimitive>)); + penv.reg(true, "^", PEnv::Handler(parseCall<ASTPrimitive>)); + penv.reg(true, "=", PEnv::Handler(parseCall<ASTPrimitive>)); + penv.reg(true, "!=", PEnv::Handler(parseCall<ASTPrimitive>)); + penv.reg(true, ">", PEnv::Handler(parseCall<ASTPrimitive>)); + penv.reg(true, ">=", PEnv::Handler(parseCall<ASTPrimitive>)); + penv.reg(true, "<", PEnv::Handler(parseCall<ASTPrimitive>)); + penv.reg(true, "<=", PEnv::Handler(parseCall<ASTPrimitive>)); +} + + +/*************************************************************************** * EVAL/REPL/MAIN * ***************************************************************************/ @@ -115,6 +155,7 @@ main(int argc, char** argv) { PEnv penv; TEnv tenv(penv); + initTypes(penv, tenv); initLang(penv, tenv); CEnv* cenv = newCenv(penv, tenv); @@ -154,7 +195,7 @@ main(int argc, char** argv) is.close(); } - if (files.empty() || args.find("-r") != args.end()) + if (args.find("-r") != args.end() || (files.empty() && args.find("-e") == args.end())) ret = repl(*cenv); a = args.find("-o"); @@ -31,7 +31,6 @@ typedef const void* CType; ///< Compiled type (opaque) typedef void* CFunction; ///< Compiled function (opaque) struct CEngine; ///< Backend data (opaque) -struct CArg; ///< Parser function argument (opaque) #define FOREACH(IT, i, c) for (IT i = (c).begin(); i != (c).end(); ++i) @@ -247,7 +246,7 @@ struct ASTCall : public ASTTuple { /// Definition special form, e.g. "(def x 2)" struct ASTDefinition : public ASTCall { - ASTDefinition(const SExp& e, const ASTTuple& t, CArg* ca=0) : ASTCall(e, t) {} + ASTDefinition(const SExp& e, const ASTTuple& t) : ASTCall(e, t) {} void constrain(TEnv& tenv) const; void lift(CEnv& cenv); CValue compile(CEnv& cenv); @@ -255,22 +254,21 @@ struct ASTDefinition : public ASTCall { /// Conditional special form, e.g. "(if cond thenexp elseexp)" struct ASTIf : public ASTCall { - ASTIf(const SExp& e, const ASTTuple& t, CArg* ca=0) : ASTCall(e, t) {} + ASTIf(const SExp& e, const ASTTuple& t) : ASTCall(e, t) {} void constrain(TEnv& tenv) const; CValue compile(CEnv& cenv); }; /// Primitive (builtin arithmetic function), e.g. "(+ 2 3)" struct ASTPrimitive : public ASTCall { - ASTPrimitive(const SExp& e, const ASTTuple& t, CArg* ca=0) : ASTCall(e, t), arg(ca) {} + ASTPrimitive(const SExp& e, const ASTTuple& t) : ASTCall(e, t) {} void constrain(TEnv& tenv) const; CValue compile(CEnv& cenv); - CArg* arg; }; /// Cons special form, e.g. "(cons 1 2)" struct ASTConsCall : public ASTCall { - ASTConsCall(const SExp& e, const ASTTuple& t, CArg* ca=0) : ASTCall(e, t) {} + ASTConsCall(const SExp& e, const ASTTuple& t) : ASTCall(e, t) {} AType* functionType(CEnv& cenv); void constrain(TEnv& tenv) const; void lift(CEnv& cenv); @@ -280,14 +278,14 @@ struct ASTConsCall : public ASTCall { /// Car special form, e.g. "(car p)" struct ASTCarCall : public ASTCall { - ASTCarCall(const SExp& e, const ASTTuple& t, CArg* ca=0) : ASTCall(e, t) {} + ASTCarCall(const SExp& e, const ASTTuple& t) : ASTCall(e, t) {} void constrain(TEnv& tenv) const; CValue compile(CEnv& cenv); }; /// Cdr special form, e.g. "(cdr p)" struct ASTCdrCall : public ASTCall { - ASTCdrCall(const SExp& e, const ASTTuple& t, CArg* ca=0) : ASTCall(e, t) {} + ASTCdrCall(const SExp& e, const ASTTuple& t) : ASTCall(e, t) {} void constrain(TEnv& tenv) const; CValue compile(CEnv& cenv); }; @@ -299,8 +297,8 @@ struct ASTCdrCall : public ASTCall { // Parse Time Environment (symbol table) struct PEnv : private map<const string, ASTSymbol*> { - typedef AST* (*PF)(PEnv&, const SExp&, CArg*); // Parse Function - struct Handler { Handler(PF f, CArg* a=0) : func(f), arg(a) {} PF func; CArg* arg; }; + typedef AST* (*PF)(PEnv&, const SExp&, void*); // Parse Function + struct Handler { Handler(PF f, void* a=0) : func(f), arg(a) {} PF func; void* arg; }; map<const string, Handler> aHandlers; ///< Atom parse functions map<const string, Handler> lHandlers; ///< List parse functions void reg(bool list, const string& s, const Handler& h) { @@ -358,20 +356,20 @@ parseExpression(PEnv& penv, const SExp& exp) template<typename C> inline AST* -parseCall(PEnv& penv, const SExp& exp, CArg* arg) +parseCall(PEnv& penv, const SExp& exp, void* arg) { - return new C(exp, pmap(penv, exp.list), arg); + return new C(exp, pmap(penv, exp.list)); } template<typename T> inline AST* -parseLiteral(PEnv& penv, const SExp& exp, CArg* arg) +parseLiteral(PEnv& penv, const SExp& exp, void* arg) { return new ASTLiteral<T>(*reinterpret_cast<T*>(arg)); } inline AST* -parseFn(PEnv& penv, const SExp& exp, CArg* arg) +parseFn(PEnv& penv, const SExp& exp, void* arg) { SExp::List::const_iterator a = exp.list.begin(); ++a; return new ASTClosure( @@ -479,6 +477,7 @@ private: * EVAL/REPL/MAIN * ***************************************************************************/ +void initTypes(PEnv& penv, TEnv& tenv); void initLang(PEnv& penv, TEnv& tenv); CEnv* newCenv(PEnv& penv, TEnv& tenv); int eval(CEnv& cenv, const string& name, istream& is); @@ -83,6 +83,54 @@ ASTIf::constrain(TEnv& tenv) const } void +ASTPrimitive::constrain(TEnv& tenv) const +{ + const string n = dynamic_cast<ASTSymbol*>(at(0))->str(); + enum { ARITHMETIC, BINARY, BITWISE, COMPARISON } type; + if (n == "+" || n == "-" || n == "*" || n == "/") + type = ARITHMETIC; + else if (n == "%") + type = BINARY; + else if (n == "&" || n == "|" || n == "^") + type = BITWISE; + else if (n == "=" || n == "!=" || n == ">" || n == ">=" || n == "<" || n == "<=") + type = COMPARISON; + else + throw Error((format("unknown primitive `%1%'") % n).str(), exp.loc); + + FOREACH(const_iterator, p, *this) + (*p)->constrain(tenv); + + switch (type) { + case ARITHMETIC: + if (size() < 3) + throw Error((format("`%1%' requires at least 2 arguments") % n).str(), exp.loc); + for (size_t i = 1; i < size(); ++i) + tenv.constrain(this, tenv.type(at(i))); + break; + case BINARY: + if (size() != 3) + throw Error((format("`%1%' requires exactly 2 arguments") % n).str(), exp.loc); + tenv.constrain(this, tenv.type(at(1))); + tenv.constrain(this, tenv.type(at(2))); + break; + case BITWISE: + if (size() != 3) + throw Error((format("`%1%' requires exactly 2 arguments") % n).str(), exp.loc); + tenv.constrain(this, tenv.named("Bool")); + tenv.constrain(at(1), tenv.named("Bool")); + tenv.constrain(at(2), tenv.named("Bool")); + break; + case COMPARISON: + if (size() != 3) + throw Error((format("`%1%' requires exactly 2 arguments") % n).str(), exp.loc); + tenv.constrain(this, tenv.named("Bool")); + tenv.constrain(at(1), tenv.type(at(2))); + break; + } +} + +void ASTConsCall::constrain(TEnv& tenv) const { AType* t = new AType(ASTTuple(tenv.penv.sym("Pair"), 0)); @@ -174,7 +222,7 @@ TEnv::unify(const Constraints& constraints) // TAPL 22.4 } else if (t->var() && !s->contains(t)) { substConstraints(cp, t, s); return compose(unify(cp), TSubst(t, s)); - } else if (s->size() == t->size()) { + } else if (s->kind == AType::EXPR && s->kind == t->kind && s->size() == t->size()) { for (size_t i = 0; i < s->size(); ++i) { AType* si = dynamic_cast<AType*>(s->at(i)); AType* ti = dynamic_cast<AType*>(t->at(i)); |