aboutsummaryrefslogtreecommitdiffstats
path: root/tuplr_llvm.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tuplr_llvm.cpp')
-rw-r--r--tuplr_llvm.cpp164
1 files changed, 94 insertions, 70 deletions
diff --git a/tuplr_llvm.cpp b/tuplr_llvm.cpp
index a807f75..4ed3ba5 100644
--- a/tuplr_llvm.cpp
+++ b/tuplr_llvm.cpp
@@ -15,10 +15,30 @@
* along with Tuplr. If not, see <http://www.gnu.org/licenses/>.
*/
-#include <sstream>
#include <fstream>
+#include <sstream>
#include "tuplr.hpp"
-#include "tuplr_llvm.hpp"
+#include "llvm/Analysis/Verifier.h"
+#include "llvm/DerivedTypes.h"
+#include "llvm/ExecutionEngine/ExecutionEngine.h"
+#include "llvm/Instructions.h"
+#include "llvm/Module.h"
+#include "llvm/ModuleProvider.h"
+#include "llvm/PassManager.h"
+#include "llvm/Support/IRBuilder.h"
+#include "llvm/Target/TargetData.h"
+#include "llvm/Transforms/Scalar.h"
+
+llvm::Value* LLVal(CValue v) { return static_cast<llvm::Value*>(v); }
+const llvm::Type* LLType(CType t) { return static_cast<const llvm::Type*>(t); }
+llvm::Function* LLFunc(CFunction f) { return static_cast<llvm::Function*>(f); }
+
+struct CEngine {
+ CEngine();
+ llvm::Module* module;
+ llvm::ExecutionEngine* engine;
+ llvm::IRBuilder<> builder;
+};
using namespace llvm;
using namespace std;
@@ -29,14 +49,14 @@ using boost::format;
* Abstract Syntax Tree *
***************************************************************************/
-const CType*
+CType
AType::type()
{
if (at(0)->str() == "Pair") {
- vector<const CType*> types;
+ vector<const Type*> types;
for (size_t i = 1; i < size(); ++i) {
assert(dynamic_cast<AType*>(at(i)));
- types.push_back(((AType*)at(i))->type());
+ types.push_back(LLType(((AType*)at(i))->type()));
}
return PointerType::get(StructType::get(types, false), 0);
} else {
@@ -114,22 +134,22 @@ CEnv::~CEnv()
delete _pimpl;
}
-CValue*
+CValue
CEnv::compile(AST* obj)
{
- CValue** v = vals.ref(obj);
+ CValue* v = vals.ref(obj);
return (v) ? *v : vals.def(obj, obj->compile(*this));
}
void
-CEnv::optimise(Function& f)
+CEnv::optimise(CFunction f)
{
- verifyFunction(f);
- _pimpl->opt.run(f);
+ verifyFunction(*static_cast<Function*>(f));
+ _pimpl->opt.run(*static_cast<Function*>(f));
}
#define LITERAL(CT, NAME, COMPILED) \
-template<> CValue* \
+template<> CValue \
ASTLiteral<CT>::compile(CEnv& cenv) { return (COMPILED); } \
template<> void \
ASTLiteral<CT>::constrain(TEnv& tenv) const { tenv.constrain(this, tenv.named(NAME)); }
@@ -140,20 +160,20 @@ LITERAL(float, "Float", ConstantFP::get(Type::FloatTy, val))
LITERAL(bool, "Bool", ConstantInt::get(Type::Int1Ty, val, false))
static Function*
-compileFunction(CEnv& cenv, const std::string& name, const CType* retT, const ASTTuple& prot,
+compileFunction(CEnv& cenv, const std::string& name, CType retT, const ASTTuple& prot,
const vector<string> argNames=vector<string>())
{
Function::LinkageTypes linkage = Function::ExternalLinkage;
- vector<const CType*> cprot;
+ vector<const Type*> cprot;
for (size_t i = 0; i < prot.size(); ++i) {
AType* at = cenv.tenv.type(prot.at(i));
if (!at->type() || at->var()) throw Error("function parameter is untyped");
- cprot.push_back(at->type());
+ cprot.push_back(LLType(at->type()));
}
if (!retT) throw Error("function return is untyped");
- FunctionType* fT = FunctionType::get(retT, cprot, false);
+ FunctionType* fT = FunctionType::get(static_cast<const Type*>(retT), cprot, false);
Function* f = Function::Create(fT, linkage, name, cenv.engine.module);
if (f->getName() != name) {
@@ -176,7 +196,7 @@ compileFunction(CEnv& cenv, const std::string& name, const CType* retT, const AS
return f;
}
-CValue*
+CValue
ASTSymbol::compile(CEnv& cenv)
{
AST** c = cenv.code.ref(this);
@@ -203,7 +223,7 @@ ASTClosure::lift(CEnv& cenv)
Function* f = compileFunction(cenv, name, cenv.tenv.type(at(2))->type(), *prot());
// Bind argument values in CEnv
- vector<CValue*> args;
+ vector<Value*> args;
const_iterator p = prot()->begin();
for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p)
cenv.vals.def(dynamic_cast<ASTSymbol*>(*p), &*a);
@@ -211,9 +231,9 @@ ASTClosure::lift(CEnv& cenv)
// Write function body
try {
cenv.precompile(this, f); // Define our value first for recursion
- CValue* retVal = cenv.compile(at(2));
- cenv.engine.builder.CreateRet(retVal); // Finish function
- cenv.optimise(*f);
+ CValue retVal = cenv.compile(at(2));
+ cenv.engine.builder.CreateRet(LLVal(retVal)); // Finish function
+ cenv.optimise(LLFunc(f));
funcs.insert(type, f);
} catch (Error& e) {
f->eraseFromParent(); // Error reading body, remove function
@@ -223,7 +243,7 @@ ASTClosure::lift(CEnv& cenv)
cenv.pop();
}
-CValue*
+CValue
ASTClosure::compile(CEnv& cenv)
{
return funcs.find(cenv.tenv.type(this));
@@ -258,7 +278,7 @@ ASTCall::lift(CEnv& cenv)
cenv.pop(); // Restore environment
}
-CValue*
+CValue
ASTCall::compile(CEnv& cenv)
{
ASTClosure* c = dynamic_cast<ASTClosure*>(at(0));
@@ -268,12 +288,12 @@ ASTCall::compile(CEnv& cenv)
}
assert(c);
- Function* f = dynamic_cast<Function*>(cenv.compile(c));
+ Function* f = dynamic_cast<Function*>(LLVal(cenv.compile(c)));
if (!f) throw Error("callee failed to compile", exp.loc);
- vector<CValue*> params(size() - 1);
+ vector<Value*> params(size() - 1);
for (size_t i = 1; i < size(); ++i)
- params[i-1] = cenv.compile(at(i));
+ params[i-1] = LLVal(cenv.compile(at(i)));
return cenv.engine.builder.CreateCall(f, params.begin(), params.end(), "calltmp");
}
@@ -287,23 +307,23 @@ ASTDefinition::lift(CEnv& cenv)
at(2)->lift(cenv);
}
-CValue*
+CValue
ASTDefinition::compile(CEnv& cenv)
{
return cenv.compile(at(2));
}
-CValue*
+CValue
ASTIf::compile(CEnv& cenv)
{
- typedef vector< pair<CValue*, BasicBlock*> > Branches;
+ typedef vector< pair<Value*, BasicBlock*> > Branches;
Function* parent = cenv.engine.builder.GetInsertBlock()->getParent();
BasicBlock* mergeBB = BasicBlock::Create("endif");
BasicBlock* nextBB = NULL;
Branches branches;
ostringstream ss;
for (size_t i = 1; i < size() - 1; i += 2) {
- CValue* condV = cenv.compile(at(i));
+ Value* condV = LLVal(cenv.compile(at(i)));
ss.str(""); ss << "then" << ((i + 1) / 2);
BasicBlock* thenBB = BasicBlock::Create(ss.str());
@@ -316,7 +336,7 @@ ASTIf::compile(CEnv& cenv)
// Emit then block for this condition
parent->getBasicBlockList().push_back(thenBB);
cenv.engine.builder.SetInsertPoint(thenBB);
- CValue* thenV = cenv.compile(at(i + 1));
+ Value* thenV = LLVal(cenv.compile(at(i + 1)));
cenv.engine.builder.CreateBr(mergeBB);
branches.push_back(make_pair(thenV, cenv.engine.builder.GetInsertBlock()));
@@ -326,14 +346,14 @@ ASTIf::compile(CEnv& cenv)
// Emit else block
cenv.engine.builder.SetInsertPoint(nextBB);
- CValue* elseV = cenv.compile(at(size() - 1));
+ Value* elseV = LLVal(cenv.compile(at(size() - 1)));
cenv.engine.builder.CreateBr(mergeBB);
branches.push_back(make_pair(elseV, cenv.engine.builder.GetInsertBlock()));
// Emit merge block (Phi node)
parent->getBasicBlockList().push_back(mergeBB);
cenv.engine.builder.SetInsertPoint(mergeBB);
- PHINode* pn = cenv.engine.builder.CreatePHI(cenv.tenv.type(this)->type(), "ifval");
+ PHINode* pn = cenv.engine.builder.CreatePHI(LLType(cenv.tenv.type(this)->type()), "ifval");
for (Branches::iterator i = branches.begin(); i != branches.end(); ++i)
pn->addIncoming(i->first, i->second);
@@ -341,19 +361,19 @@ ASTIf::compile(CEnv& cenv)
return pn;
}
-CValue*
+CValue
ASTPrimitive::compile(CEnv& cenv)
{
- CValue* a = cenv.compile(at(1));
- CValue* b = cenv.compile(at(2));
+ Value* a = LLVal(cenv.compile(at(1)));
+ Value* b = LLVal(cenv.compile(at(2)));
if (OP_IS_A(arg.op, Instruction::BinaryOps)) {
const Instruction::BinaryOps bo = (Instruction::BinaryOps)arg.op;
if (size() == 2)
return cenv.compile(at(1));
- CValue* val = cenv.engine.builder.CreateBinOp(bo, a, b);
+ Value* val = cenv.engine.builder.CreateBinOp(bo, a, b);
for (size_t i = 3; i < size(); ++i)
- val = cenv.engine.builder.CreateBinOp(bo, val, cenv.compile(at(i)));
+ val = cenv.engine.builder.CreateBinOp(bo, val, LLVal(cenv.compile(at(i))));
return val;
} else if (arg.op == Instruction::ICmp) {
bool isInt = cenv.tenv.type(at(1))->str() == "Int";
@@ -396,17 +416,19 @@ ASTConsCall::lift(CEnv& cenv)
ASTTuple* prot = new ASTTuple(at(1), at(2), NULL);
- vector<const CType*> types;
+ vector<const Type*> types;
size_t sz = 0;
for (size_t i = 1; i < size(); ++i) {
- const CType* t = cenv.tenv.type(at(i))->type();
+ const Type* t = LLType(cenv.tenv.type(at(i))->type());
types.push_back(t);
sz += t->getPrimitiveSizeInBits();
}
sz = (sz % 8 == 0) ? sz / 8 : sz / 8 + 1;
+
+ llvm::IRBuilder<>& builder = cenv.engine.builder;
StructType* sT = StructType::get(types, false);
- CType* pT = PointerType::get(sT, 0);
+ Type* pT = PointerType::get(sT, 0);
// Write function declaration
vector<string> argNames;
@@ -414,49 +436,51 @@ ASTConsCall::lift(CEnv& cenv)
argNames.push_back("cdr");
Function* func = compileFunction(cenv, cenv.gensym("cons"), pT, *prot, argNames);
- CValue* mem = cenv.engine.builder.CreateCall(cenv.alloc, ConstantInt::get(Type::Int32Ty, sz), "mem");
- CValue* cell = cenv.engine.builder.CreateBitCast(mem, pT, "cell");
- CValue* s = cenv.engine.builder.CreateGEP(cell, ConstantInt::get(Type::Int32Ty, 0), "pair");
- CValue* carP = cenv.engine.builder.CreateStructGEP(s, 0, "car");
- CValue* cdrP = cenv.engine.builder.CreateStructGEP(s, 1, "cdr");
+ Value* mem = builder.CreateCall(LLVal(cenv.alloc), ConstantInt::get(Type::Int32Ty, sz), "mem");
+ Value* cell = builder.CreateBitCast(mem, pT, "cell");
+ Value* s = builder.CreateGEP(cell, ConstantInt::get(Type::Int32Ty, 0), "pair");
+ Value* carP = builder.CreateStructGEP(s, 0, "car");
+ Value* cdrP = builder.CreateStructGEP(s, 1, "cdr");
+
Function::arg_iterator ai = func->arg_begin();
Value& carArg = *ai++;
Value& cdrArg = *ai++;
- cenv.engine.builder.CreateStore(&carArg, carP);
- cenv.engine.builder.CreateStore(&cdrArg, cdrP);
- cenv.engine.builder.CreateRet(cell);
- cenv.optimise(*func);
-
+ builder.CreateStore(&carArg, carP);
+ builder.CreateStore(&cdrArg, cdrP);
+ builder.CreateRet(cell);
+
+ cenv.optimise(func);
funcs.insert(funcType, func);
}
-CValue*
+CValue
ASTConsCall::compile(CEnv& cenv)
{
- vector<CValue*> params(size() - 1);
+ vector<Value*> params(size() - 1);
for (size_t i = 1; i < size(); ++i)
- params[i-1] = cenv.compile(at(i));
+ params[i-1] = LLVal(cenv.compile(at(i)));
- return cenv.engine.builder.CreateCall(funcs.find(functionType(cenv)), params.begin(), params.end());
+ return cenv.engine.builder.CreateCall(LLFunc(funcs.find(functionType(cenv))),
+ params.begin(), params.end());
}
-CValue*
+CValue
ASTCarCall::compile(CEnv& cenv)
{
AST** arg = cenv.code.ref(at(1));
- CValue* sP = arg ? (*arg)->compile(cenv) : at(1)->compile(cenv);
- CValue* s = cenv.engine.builder.CreateGEP(sP, ConstantInt::get(Type::Int32Ty, 0), "pair");
- CValue* carP = cenv.engine.builder.CreateStructGEP(s, 0, "car");
+ Value* sP = LLVal(arg ? (*arg)->compile(cenv) : at(1)->compile(cenv));
+ Value* s = cenv.engine.builder.CreateGEP(sP, ConstantInt::get(Type::Int32Ty, 0), "pair");
+ Value* carP = cenv.engine.builder.CreateStructGEP(s, 0, "car");
return cenv.engine.builder.CreateLoad(carP);
}
-CValue*
+CValue
ASTCdrCall::compile(CEnv& cenv)
{
AST** arg = cenv.code.ref(at(1));
- CValue* sP = arg ? (*arg)->compile(cenv) : at(1)->compile(cenv);
- CValue* s = cenv.engine.builder.CreateGEP(sP, ConstantInt::get(Type::Int32Ty, 0), "pair");
- CValue* cdrP = cenv.engine.builder.CreateStructGEP(s, 1, "cdr");
+ Value* sP = LLVal(arg ? (*arg)->compile(cenv) : at(1)->compile(cenv));
+ Value* s = cenv.engine.builder.CreateGEP(sP, ConstantInt::get(Type::Int32Ty, 0), "pair");
+ Value* cdrP = cenv.engine.builder.CreateStructGEP(s, 1, "cdr");
return cenv.engine.builder.CreateLoad(cdrP);
}
@@ -503,20 +527,20 @@ eval(CEnv& cenv, const string& name, istream& is)
if (!resultType || resultType->var()) throw Error("body is undefined/untyped", cursor);
- const CType* ctype = resultType->type();
+ CType ctype = resultType->type();
if (!ctype) throw Error("body has no system type", cursor);
// Create function for top-level of program
Function* f = compileFunction(cenv, cenv.gensym("input"), ctype, ASTTuple());
// Compile all expressions into it
- CValue* val = NULL;
+ Value* val = NULL;
for (list< pair<SExp, AST*> >::const_iterator i = exprs.begin(); i != exprs.end(); ++i)
- val = cenv.compile(i->second);
+ val = LLVal(cenv.compile(i->second));
// Finish function
cenv.engine.builder.CreateRet(val);
- cenv.optimise(*f);
+ cenv.optimise(f);
string resultStr = call(resultType, cenv.engine.engine->getPointerToFunction(f));
out << resultStr << " : " << resultType->str() << endl;
@@ -555,16 +579,16 @@ repl(CEnv& cenv)
// Create anonymous function to insert code into
Function* f = compileFunction(cenv, cenv.gensym("_repl"), bodyT->type(), ASTTuple());
try {
- CValue* retVal = cenv.compile(body);
+ Value* retVal = LLVal(cenv.compile(body));
cenv.engine.builder.CreateRet(retVal); // Finish function
- cenv.optimise(*f);
+ cenv.optimise(f);
} catch (Error& e) {
f->eraseFromParent(); // Error reading body, remove function
throw e;
}
out << call(bodyT, cenv.engine.engine->getPointerToFunction(f));
} else {
- CValue* val = cenv.compile(body);
+ CValue val = cenv.compile(body);
out << "; " << val;
}
out << " : " << cenv.tenv.type(body)->str() << endl;
@@ -611,7 +635,7 @@ main(int argc, char** argv)
cenv.tenv.def(penv.sym("Float"), new AType(penv.sym("Float"), Type::FloatTy));
// Host provided allocation primitive prototypes
- std::vector<const CType*> argsT(1, Type::Int32Ty);
+ std::vector<const Type*> argsT(1, Type::Int32Ty);
FunctionType* funcT = FunctionType::get(PointerType::get(Type::Int8Ty, 0), argsT, false);
cenv.alloc = Function::Create(funcT, Function::ExternalLinkage, "malloc", engine.module);