From ecef8f697c66e15b85beb934d2b617b915a97aab Mon Sep 17 00:00:00 2001
From: David Robillard <d@drobilla.net>
Date: Fri, 6 Mar 2009 22:12:36 +0000
Subject: Cleanup and de-llvm-ify primitive stuff. Fix type inference (only
 treat actual type expressions as type expressions).

git-svn-id: http://svn.drobilla.net/resp/tuplr@64 ad02d1e2-f140-0410-9f75-f8b11f17cedd
---
 llvm.cpp | 121 ++++++++++++++++++---------------------------------------------
 1 file changed, 35 insertions(+), 86 deletions(-)

(limited to 'llvm.cpp')

diff --git a/llvm.cpp b/llvm.cpp
index 2e53b7f..03c49f9 100644
--- a/llvm.cpp
+++ b/llvm.cpp
@@ -46,12 +46,6 @@ struct CEngine {
 	IRBuilder<>      builder;
 };
 
-struct CArg {
-	CArg(int o=0, int a=0) : op(o), arg(a) {}
-	int op;
-	int arg;
-};
-
 
 /***************************************************************************
  * Typing                                                                  *
@@ -72,29 +66,6 @@ AType::ctype()
 	}
 }
 
-#define OP_IS_A(o, t) ((o) >= t ## Begin && (o) < t ## End)
-
-void
-ASTPrimitive::constrain(TEnv& tenv) const
-{
-	FOREACH(const_iterator, p, *this)
-		(*p)->constrain(tenv);
-	if (OP_IS_A(arg->op, Instruction::BinaryOps)) {
-		if (size() <= 2) throw Error((format("`%1%' requires at least 2 arguments")
-				% at(0)->str()).str(), exp.loc);
-		AType* tvar = tenv.type(this);
-		for (size_t i = 1; i < size(); ++i)
-			tenv.constrain(at(i), tvar);
-	} else if (arg->op == Instruction::ICmp) {
-		if (size() != 3) throw Error((format("`%1%' requires exactly 2 arguments")
-			   % at(0)->str()).str(), exp.loc);
-		tenv.constrain(at(1), tenv.type(at(2)));
-		tenv.constrain(this, tenv.named("Bool"));
-	} else {
-		throw Error((format("unknown primitive `%1%'") % at(0)->str()).str(), exp.loc);
-	}
-}
-
 
 /***************************************************************************
  * Code Generation                                                         *
@@ -370,35 +341,43 @@ ASTIf::compile(CEnv& cenv)
 CValue
 ASTPrimitive::compile(CEnv& cenv)
 {
-	Value* a = LLVal(cenv.compile(at(1)));
-	Value* b = LLVal(cenv.compile(at(2)));
-
-	if (OP_IS_A(arg->op, Instruction::BinaryOps)) {
-		const Instruction::BinaryOps bo = (Instruction::BinaryOps)arg->op;
-		if (size() == 2)
-			return cenv.compile(at(1));
-		Value* val = cenv.engine.builder.CreateBinOp(bo, a, b);
+	Value*       a     = LLVal(cenv.compile(at(1)));
+	Value*       b     = LLVal(cenv.compile(at(2)));
+	bool         isInt = cenv.tenv.type(at(1))->str() == "Int";
+	const string n     = dynamic_cast<ASTSymbol*>(at(0))->str();
+
+	// Binary arithmetic operations
+	Instruction::BinaryOps op = (Instruction::BinaryOps)0;
+	if (n == "+") op = Instruction::Add;
+	if (n == "-") op = Instruction::Sub;
+	if (n == "*") op = Instruction::Mul;
+	if (n == "&") op = Instruction::And;
+	if (n == "|") op = Instruction::Or;
+	if (n == "^") op = Instruction::Xor;
+	if (n == "/") op = isInt ? Instruction::SDiv : Instruction::FDiv;
+	if (n == "%") op = isInt ? Instruction::SRem : Instruction::FRem;
+	if (op != 0) {
+		Value* val = cenv.engine.builder.CreateBinOp(op, a, b);
 		for (size_t i = 3; i < size(); ++i)
-			val = cenv.engine.builder.CreateBinOp(bo, val, LLVal(cenv.compile(at(i))));
+			val = cenv.engine.builder.CreateBinOp(op, val, LLVal(cenv.compile(at(i))));
 		return val;
-	} else if (arg->op == Instruction::ICmp) {
-		bool isInt = cenv.tenv.type(at(1))->str() == "Int";
-		if (isInt) {
-			return cenv.engine.builder.CreateICmp((CmpInst::Predicate)arg->arg, a, b);
-		} else {
-			// Translate to floating point operation
-			switch (arg->arg) {
-			case CmpInst::ICMP_EQ:  arg->arg = CmpInst::FCMP_OEQ; break;
-			case CmpInst::ICMP_NE:  arg->arg = CmpInst::FCMP_ONE; break;
-			case CmpInst::ICMP_SGT: arg->arg = CmpInst::FCMP_OGT; break;
-			case CmpInst::ICMP_SGE: arg->arg = CmpInst::FCMP_OGE; break;
-			case CmpInst::ICMP_SLT: arg->arg = CmpInst::FCMP_OLT; break;
-			case CmpInst::ICMP_SLE: arg->arg = CmpInst::FCMP_OLE; break;
-			default: throw Error("Unknown primitive", exp.loc);
-			}
-			return cenv.engine.builder.CreateFCmp((CmpInst::Predicate)arg->arg, a, b);
-		}
 	}
+
+	// Comparison operations
+	CmpInst::Predicate pred = (CmpInst::Predicate)0;
+	if (n == "=")  pred = isInt ? CmpInst::ICMP_EQ  : CmpInst::FCMP_OEQ;
+	if (n == "!=") pred = isInt ? CmpInst::ICMP_NE  : CmpInst::FCMP_ONE;
+	if (n == ">")  pred = isInt ? CmpInst::ICMP_SGT : CmpInst::FCMP_OGT;
+	if (n == ">=") pred = isInt ? CmpInst::ICMP_SGE : CmpInst::FCMP_OGE;
+	if (n == "<")  pred = isInt ? CmpInst::ICMP_SLT : CmpInst::FCMP_OLT;
+	if (n == "<=") pred = isInt ? CmpInst::ICMP_SLE : CmpInst::FCMP_OLE;
+	if (pred != 0) {
+		if (isInt)
+			return cenv.engine.builder.CreateICmp(pred, a, b);
+		else
+			return cenv.engine.builder.CreateFCmp(pred, a, b);
+	}
+	
 	throw Error("Unknown primitive", exp.loc);
 }
 
@@ -496,38 +475,8 @@ ASTCdrCall::compile(CEnv& cenv)
  ***************************************************************************/
 
 void
-initLang(PEnv& penv, TEnv& tenv)
+initTypes(PEnv& penv, TEnv& tenv)
 {
-	penv.reg(true, "fn",   PEnv::Handler(parseFn));
-	penv.reg(true, "if",   PEnv::Handler(parseCall<ASTIf>));
-	penv.reg(true, "def",  PEnv::Handler(parseCall<ASTDefinition>));
-	penv.reg(true, "cons", PEnv::Handler(parseCall<ASTConsCall>));
-	penv.reg(true, "car",  PEnv::Handler(parseCall<ASTCarCall>));
-	penv.reg(true, "cdr",  PEnv::Handler(parseCall<ASTCdrCall>));
-
-	bool trueVal  = true;
-	bool falseVal = false;
-	penv.reg(false, "true",  PEnv::Handler(parseLiteral<bool>, (CArg*)&trueVal));
-	penv.reg(false, "false", PEnv::Handler(parseLiteral<bool>, (CArg*)&falseVal));
-
-	map<string, CArg>* prims = new map<string, CArg>();
-	prims->insert(make_pair("+",  CArg(Instruction::Add)));
-	prims->insert(make_pair("-",  CArg(Instruction::Sub)));
-	prims->insert(make_pair("*",  CArg(Instruction::Mul)));
-	prims->insert(make_pair("/",  CArg(Instruction::FDiv)));
-	prims->insert(make_pair("%",  CArg(Instruction::FRem)));
-	prims->insert(make_pair("&",  CArg(Instruction::And)));
-	prims->insert(make_pair("|",  CArg(Instruction::Or)));
-	prims->insert(make_pair("^",  CArg(Instruction::Xor)));
-	prims->insert(make_pair("=",  CArg(Instruction::ICmp, CmpInst::ICMP_EQ)));
-	prims->insert(make_pair("!=", CArg(Instruction::ICmp, CmpInst::ICMP_NE)));
-	prims->insert(make_pair(">",  CArg(Instruction::ICmp, CmpInst::ICMP_SGT)));
-	prims->insert(make_pair(">=", CArg(Instruction::ICmp, CmpInst::ICMP_SGE)));
-	prims->insert(make_pair("<",  CArg(Instruction::ICmp, CmpInst::ICMP_SLT)));
-	prims->insert(make_pair("<=", CArg(Instruction::ICmp, CmpInst::ICMP_SLE)));
-	for (map<string,CArg>::iterator p = prims->begin(); p != prims->end(); ++p)
-		penv.reg(true, p->first, PEnv::Handler(parseCall<ASTPrimitive>, &p->second));
-	
 	tenv.def(penv.sym("Bool"),  new AType(penv.sym("Bool"),  Type::Int1Ty));
 	tenv.def(penv.sym("Int"),   new AType(penv.sym("Int"),   Type::Int32Ty));
 	tenv.def(penv.sym("Float"), new AType(penv.sym("Float"), Type::FloatTy));
-- 
cgit v1.2.1