diff options
Diffstat (limited to 'src/llvm.cpp')
-rw-r--r-- | src/llvm.cpp | 104 |
1 files changed, 86 insertions, 18 deletions
diff --git a/src/llvm.cpp b/src/llvm.cpp index 3f48dec..cc4bf47 100644 --- a/src/llvm.cpp +++ b/src/llvm.cpp @@ -58,18 +58,21 @@ struct LLVMEngine : public Engine { CFunc startFn(CEnv& cenv, const string& name, const ATuple* args, const ATuple* type); void pushFnArgs(CEnv& cenv, const ATuple* prot, const ATuple* type, CFunc f); - void finishFn(CEnv& cenv, CFunc f, CVal ret); + void finishFn(CEnv& cenv, CFunc f, CVal ret, const AST* retT); void eraseFn(CEnv& cenv, CFunc f); - CVal compileCall(CEnv& cenv, CFunc f, const ATuple* funcT, const vector<CVal>& args); - CVal compileCons(CEnv& cenv, const ATuple* type, CVal rtti, const vector<CVal>& fields); - CVal compileDot(CEnv& cenv, CVal tup, int32_t index); - CVal compileGlobalSet(CEnv& cenv, const string& s, CVal v, const AST* t); - CVal compileGlobalGet(CEnv& cenv, const string& s, CVal v); - CVal compileIf(CEnv& cenv, const AST* cond, const AST* then, const AST* aelse); - CVal compileLiteral(CEnv& cenv, const AST* lit); - CVal compilePrimitive(CEnv& cenv, const ATuple* prim); - CVal compileString(CEnv& cenv, const char* str); + CVal compileCall(CEnv& cenv, CFunc f, const ATuple* funcT, const vector<CVal>& args); + CVal compileCons(CEnv& cenv, const ATuple* type, CVal rtti, const vector<CVal>& fields); + CVal compileDot(CEnv& cenv, CVal tup, int32_t index); + CVal compileGlobalSet(CEnv& cenv, const string& s, CVal v, const AST* t); + CVal compileGlobalGet(CEnv& cenv, const string& s, CVal v); + CVal compileIf(CEnv& cenv, const AST* cond, const AST* then, const AST* aelse); + CVal compileLiteral(CEnv& cenv, const AST* lit); + CVal compilePrimitive(CEnv& cenv, const ATuple* prim); + CVal compileString(CEnv& cenv, const char* str); + CType compileType(CEnv& cenv, const char* name, const AST* exp); + + CType objectType(CEnv& cenv); void writeModule(CEnv& cenv, std::ostream& os); @@ -91,6 +94,10 @@ private: IRBuilder<> builder; Function* alloc; FunctionPassManager* opt; + CType objectT; + + typedef std::map<const std::string, CType> CTypes; + CTypes compiledTypes; unsigned labelIndex; }; @@ -114,9 +121,17 @@ LLVMEngine::LLVMEngine() // Declare host provided allocation primitive std::vector<const Type*> argsT(1, Type::getInt32Ty(context)); // unsigned size - FunctionType* funcT = FunctionType::get(PointerType::get(Type::getInt8Ty(context), 0), argsT, false); - alloc = Function::Create(funcT, Function::ExternalLinkage, - "resp_gc_allocate", module); + FunctionType* funcT = FunctionType::get( + PointerType::get(Type::getInt8Ty(context), 0), argsT, false); + alloc = Function::Create( + funcT, Function::ExternalLinkage, "resp_gc_allocate", module); + + // Build Object type (tag only, binary compatible with any constructed thing) + vector<const Type*> ctypes; + ctypes.push_back(PointerType::get(Type::getInt8Ty(context), NULL)); // RTTI + StructType* cObjectT = StructType::get(context, ctypes, false); + module->addTypeName("Object", cObjectT); + objectT = cObjectT; } LLVMEngine::~LLVMEngine() @@ -141,6 +156,7 @@ LLVMEngine::llType(const AST* t) if (sym == "Float") return Type::getFloatTy(context); if (sym == "String") return PointerType::get(Type::getInt8Ty(context), NULL); if (sym == "Symbol") return PointerType::get(Type::getInt8Ty(context), NULL); + if (sym == "Expr") return PointerType::get(Type::getInt8Ty(context), NULL); } else if (is_form(t, "Fn")) { ATuple::const_iterator i = t->as_tuple()->begin(); const ATuple* protT = (*++i)->to_tuple(); @@ -209,9 +225,16 @@ LLVMEngine::compileCons(CEnv& cenv, const ATuple* type, CVal rtti, const vector< if (rtti) builder.CreateStore((Value*)rtti, builder.CreateStructGEP(structPtr, 0, "rtti")); size_t i = 1; - for (vector<CVal>::const_iterator f = fields.begin(); f != fields.end(); ++f, ++i) { - builder.CreateStore(llVal(*f), - builder.CreateStructGEP(structPtr, i, (format("tup%1%") % i).str().c_str())); + ATuple::const_iterator t = type->iter_at(1); + for (vector<CVal>::const_iterator f = fields.begin(); f != fields.end(); ++f, ++i, ++t) { + Value* val = llVal(*f); + Value* field = builder.CreateStructGEP( + structPtr, i, (format("tup%1%") % i).str().c_str()); + + if ((*t)->to_tuple()) + val = builder.CreateBitCast(val, llType(*t), "objPtr"); + + builder.CreateStore(val, field); } return structPtr; @@ -245,6 +268,47 @@ LLVMEngine::compileString(CEnv& cenv, const char* str) return builder.CreateGlobalStringPtr(str); } +CType +LLVMEngine::compileType(CEnv& cenv, const char* name, const AST* expr) +{ + CTypes::const_iterator i = compiledTypes.find(name); + if (i != compiledTypes.end()) + return i->second; + + const ATuple* const tup = expr->as_tuple(); + vector<const Type*> ctypes; + ctypes.push_back(PointerType::get(Type::getInt8Ty(context), NULL)); // RTTI + for (ATuple::const_iterator i = tup->iter_at(1); i != tup->end(); ++i) { + const ATuple* tup = (*i)->to_tuple(); + const Type* lt = (tup) + ? (const Type*)compileType(cenv, tup->fst()->as_symbol()->sym(), *i) + : llType(*i); + if (!lt) + return NULL; + ctypes.push_back(lt); + } + + Type* structT = StructType::get(context, ctypes, false); + + // Tell LLVM opaqueT and structT are the same (for recursive types) + //PATypeHolder opaqueT = OpaqueType::get(context); + //((OpaqueType*)opaqueT.get())->refineAbstractTypeTo(structT); + //structT = cast<StructType>(opaqueT.get()); // updated potentially invalidated structT + + Type* ret = PointerType::get(structT, 0); + module->addTypeName(name, structT); + + compiledTypes.insert(make_pair(name, ret)); + + return ret; +} + +CType +LLVMEngine::objectType(CEnv& cenv) +{ + return objectT; +} + CFunc LLVMEngine::startFn( CEnv& cenv, const std::string& name, const ATuple* args, const ATuple* type) @@ -304,9 +368,13 @@ LLVMEngine::pushFnArgs(CEnv& cenv, const ATuple* prot, const ATuple* type, CFunc } void -LLVMEngine::finishFn(CEnv& cenv, CFunc f, CVal ret) +LLVMEngine::finishFn(CEnv& cenv, CFunc f, CVal ret, const AST* retT) { - builder.CreateRet(llVal(ret)); + if (retT->str() == "Nothing") + builder.CreateRetVoid(); + else + builder.CreateRet(builder.CreateBitCast(llVal(ret), llType(retT), "ret")); + if (verifyFunction(*static_cast<Function*>(f), llvm::PrintMessageAction)) { module->dump(); throw Error(Cursor(), "Broken module"); |