From 57951dddc871bb8afd681f8205db29fb653b3a58 Mon Sep 17 00:00:00 2001 From: David Robillard Date: Sun, 25 Jan 2009 22:09:08 +0000 Subject: Floating point and comparison primitives. git-svn-id: http://svn.drobilla.net/resp/llvm-lisp@14 ad02d1e2-f140-0410-9f75-f8b11f17cedd --- ll.cpp | 133 +++++++++++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 88 insertions(+), 45 deletions(-) diff --git a/ll.cpp b/ll.cpp index 0b81628..a9e80c5 100644 --- a/ll.cpp +++ b/ll.cpp @@ -21,13 +21,14 @@ #include #include #include +#include #include #include #include -#include #include "llvm/Analysis/Verifier.h" #include "llvm/DerivedTypes.h" #include "llvm/ExecutionEngine/ExecutionEngine.h" +#include "llvm/Instructions.h" #include "llvm/Module.h" #include "llvm/ModuleProvider.h" #include "llvm/PassManager.h" @@ -192,6 +193,9 @@ struct CEnv { Vals vals; }; +/// LLVM Operation +struct Op { Op(int o=0, int a=0) : op(o), arg(a) {} int op; int arg; }; + /*************************************************************************** @@ -280,7 +284,7 @@ LITERAL(int32_t, Type::Int32Ty, "Int", ConstantInt::get(Type::Int32Ty, val, tr LITERAL(float, Type::FloatTy, "Float", ConstantFP::get(Type::FloatTy, val)); LITERAL(bool, Type::Int1Ty, "Bool", ConstantInt::get(Type::Int1Ty, val, false)); -typedef unsigned UD; // User Data passed to registered parse functions +typedef Op UD; // User Data argument for parse functions // Parse Time Environment (symbol table) struct PEnv : private map { @@ -364,13 +368,12 @@ struct ASTIf : public ASTCall { /// Primitive (builtin arithmetic function), e.g. "(+ 2 3)" struct ASTPrimitive : public ASTCall { - ASTPrimitive(const vector& c, Instruction::BinaryOps o) : ASTCall(c), op(o) {} - AType* type(CEnv& cenv) { - if (tup.size() <= 1) throw SyntaxError("Primitive call with no arguments"); - return tup[1]->type(cenv); // FIXME: Ensure argument types are equivalent - } + ASTPrimitive(const vector& c, unsigned o, unsigned a=0) + : ASTCall(c), op(o), arg(a) {} + AType* type(CEnv& cenv); Value* compile(CEnv& cenv); - Instruction::BinaryOps op; + unsigned op; + unsigned arg; }; AType* @@ -440,7 +443,7 @@ parseDef(PEnv& penv, const list& c, UD) static AST* parsePrim(PEnv& penv, const list& c, UD data) - { return new ASTPrimitive(pmap(penv, c), (Instruction::BinaryOps)data); } + { return new ASTPrimitive(pmap(penv, c), data.op, data.arg); } static ASTTuple* parsePrototype(PEnv& penv, const SExp& e, UD) @@ -451,7 +454,7 @@ parseFn(PEnv& penv, const list& c, UD) { list::const_iterator a = c.begin(); ++a; return new ASTClosure( - parsePrototype(penv, *a++, 0), + parsePrototype(penv, *a++, UD()), parseExpression(penv, *a++)); } @@ -680,30 +683,67 @@ ASTClosure::compile(CEnv& cenv) return func; } +#define OP_IS_A(o, t) ((o) >= t ## Begin && (o) < t ## End) + +AType* +ASTPrimitive::type(CEnv& cenv) +{ + if (OP_IS_A(op, Instruction::BinaryOps)) { + if (tup.size() <= 1) throw SyntaxError("Primitive call with 0 args"); + return tup[1]->type(cenv); + } else if (op == Instruction::ICmp) { + if (tup.size() != 3) throw SyntaxError("Comparison call with != 2 args"); + return new AType("Bool", Type::Int1Ty); + } else { + throw CompileError("Unknown primitive"); + } +} + Value* ASTPrimitive::compile(CEnv& cenv) { - size_t np = 0; vector params(tup.size() - 1); - vector::const_iterator a = tup.begin(); - for (++a; a != tup.end(); ++a) - params[np++] = (*a)->compile(cenv); - - switch (params.size()) { - case 0: - throw SyntaxError("Primitive expects at least 1 argument"); - case 1: - return params[0]; - default: - Value* val = cenv.builder.CreateBinOp(op, params[0], params[1]); - for (size_t i = 2; i < params.size(); ++i) - val = cenv.builder.CreateBinOp(op, val, params[i]); - return val; + for (size_t i = 1; i < tup.size(); ++i) + params[i-1] = tup[i]->compile(cenv); + + Value* a = tup[1]->compile(cenv); + Value* b = tup[2]->compile(cenv); + + if (OP_IS_A(op, Instruction::BinaryOps)) { + const Instruction::BinaryOps bo = (Instruction::BinaryOps)op; + switch (params.size()) { + case 0: + throw SyntaxError("Primitive expects at least 1 argument"); + case 1: + return params[0]; + default: + Value* val = cenv.builder.CreateBinOp(bo, a, b); + for (size_t i = 2; i < params.size(); ++i) + val = cenv.builder.CreateBinOp(bo, val, params[i]); + return val; + } + } else if (op == Instruction::ICmp) { + bool isInt = tup[1]->type(cenv)->str(cenv) == "(Int)"; + if (isInt) { + return cenv.builder.CreateICmp((CmpInst::Predicate)arg, a, b); + } else { + // Translate to floating point operation + switch (arg) { + case CmpInst::ICMP_EQ: arg = CmpInst::FCMP_OEQ; break; + case CmpInst::ICMP_NE: arg = CmpInst::FCMP_ONE; break; + case CmpInst::ICMP_SGT: arg = CmpInst::FCMP_OGT; break; + case CmpInst::ICMP_SGE: arg = CmpInst::FCMP_OGE; break; + case CmpInst::ICMP_SLT: arg = CmpInst::FCMP_OLT; break; + case CmpInst::ICMP_SLE: arg = CmpInst::FCMP_OLE; break; + default: throw CompileError("Unknown primitive"); + } + return cenv.builder.CreateFCmp((CmpInst::Predicate)arg, a, b); + } } + throw CompileError("Unknown primitive"); } - /*************************************************************************** * REPL * ***************************************************************************/ @@ -711,18 +751,25 @@ ASTPrimitive::compile(CEnv& cenv) int main() { +#define PRIM(O, A) PEnv::Parser(parsePrim, Op(Instruction:: O, A)) PEnv penv; - penv.reg(penv.sym("fn"), PEnv::Parser(parseFn, 0)); - penv.reg(penv.sym("if"), PEnv::Parser(parseIf, 0)); - penv.reg(penv.sym("def"), PEnv::Parser(parseDef, 0)); - penv.reg(penv.sym("+"), PEnv::Parser(parsePrim, Instruction::Add)); - penv.reg(penv.sym("-"), PEnv::Parser(parsePrim, Instruction::Sub)); - penv.reg(penv.sym("*"), PEnv::Parser(parsePrim, Instruction::Mul)); - penv.reg(penv.sym("/"), PEnv::Parser(parsePrim, Instruction::FDiv)); - penv.reg(penv.sym("%"), PEnv::Parser(parsePrim, Instruction::FRem)); - penv.reg(penv.sym("&"), PEnv::Parser(parsePrim, Instruction::And)); - penv.reg(penv.sym("|"), PEnv::Parser(parsePrim, Instruction::Or)); - penv.reg(penv.sym("^"), PEnv::Parser(parsePrim, Instruction::Xor)); + penv.reg(penv.sym("fn"), PEnv::Parser(parseFn, Op())); + penv.reg(penv.sym("if"), PEnv::Parser(parseIf, Op())); + penv.reg(penv.sym("def"), PEnv::Parser(parseDef, Op())); + penv.reg(penv.sym("+"), PRIM(Add, 0)); + penv.reg(penv.sym("-"), PRIM(Sub, 0)); + penv.reg(penv.sym("*"), PRIM(Mul, 0)); + penv.reg(penv.sym("/"), PRIM(FDiv, 0)); + penv.reg(penv.sym("%"), PRIM(FRem, 0)); + penv.reg(penv.sym("&"), PRIM(And, 0)); + penv.reg(penv.sym("|"), PRIM(Or, 0)); + penv.reg(penv.sym("^"), PRIM(Xor, 0)); + penv.reg(penv.sym("="), PRIM(ICmp, CmpInst::ICMP_EQ)); + penv.reg(penv.sym("!="), PRIM(ICmp, CmpInst::ICMP_NE)); + penv.reg(penv.sym(">"), PRIM(ICmp, CmpInst::ICMP_SGT)); + penv.reg(penv.sym(">="), PRIM(ICmp, CmpInst::ICMP_SGE)); + penv.reg(penv.sym("<"), PRIM(ICmp, CmpInst::ICMP_SLT)); + penv.reg(penv.sym("<="), PRIM(ICmp, CmpInst::ICMP_SLE)); Module* module = new Module("repl"); ExecutionEngine* engine = ExecutionEngine::create(module); @@ -778,14 +825,10 @@ main() Value* val = body->compile(cenv); std::cout << val; } - std::cout << " :: " << body->type(cenv)->str(cenv) << endl; - - } catch (SyntaxError e) { - std::cerr << "Syntax error: " << e.what() << endl; - } catch (TypeError e) { - std::cerr << "Type error: " << e.what() << endl; - } catch (CompileError e) { - std::cerr << "Compile error: " << e.what() << endl; + std::cout << " : " << body->type(cenv)->str(cenv) << endl; + + } catch (Error e) { + std::cerr << "Error: " << e.what() << endl; } } -- cgit v1.2.1