diff options
Diffstat (limited to 'llvm.cpp')
-rw-r--r-- | llvm.cpp | 179 |
1 files changed, 125 insertions, 54 deletions
@@ -46,10 +46,11 @@ struct LLVMEngine { }; static const Type* -lltype(AType* t) +lltype(const AType* t) { switch (t->kind) { case AType::VAR: + throw Error((format("non-compilable type `%1%'") % t->str()).str(), t->loc); return NULL; case AType::PRIM: if (t->at(0)->str() == "Bool") return Type::Int1Ty; @@ -117,7 +118,7 @@ CValue CEnv::compile(AST* obj) { CValue* v = vals.ref(obj); - return (v) ? *v : vals.def(obj, obj->compile(*this)); + return (v && *v) ? *v : vals.def(obj, obj->compile(*this)); } void @@ -138,7 +139,7 @@ CEnv::write(std::ostream& os) template<> CValue \ ALiteral<CT>::compile(CEnv& cenv) { return (COMPILED); } \ template<> void \ -ALiteral<CT>::constrain(TEnv& tenv) const { tenv.constrain(this, tenv.named(NAME)); } +ALiteral<CT>::constrain(TEnv& tenv, Constraints& c) { c.constrain(tenv, this, tenv.named(NAME)); } /// Literal template instantiations LITERAL(int32_t, "Int", ConstantInt::get(Type::Int32Ty, val, true)) @@ -146,14 +147,15 @@ LITERAL(float, "Float", ConstantFP::get(Type::FloatTy, val)) LITERAL(bool, "Bool", ConstantInt::get(Type::Int1Ty, val, false)) static Function* -compileFunction(CEnv& cenv, const std::string& name, const Type* retT, const ATuple& prot, +compileFunction(CEnv& cenv, const std::string& name, const Type* retT, const ATuple& protT, const vector<string> argNames=vector<string>()) { Function::LinkageTypes linkage = Function::ExternalLinkage; vector<const Type*> cprot; - for (size_t i = 0; i < prot.size(); ++i) { - AType* at = cenv.tenv.type(prot.at(i)); + for (size_t i = 0; i < protT.size(); ++i) { + AType* at = dynamic_cast<AType*>(protT.at(i)); + if (!at) throw Error("function parameter type isn't"); if (!lltype(at)) throw Error("function parameter is untyped"); cprot.push_back(lltype(at)); } @@ -170,11 +172,8 @@ compileFunction(CEnv& cenv, const std::string& name, const Type* retT, const ATu // Set argument names in generated code Function::arg_iterator a = f->arg_begin(); if (!argNames.empty()) - for (size_t i = 0; i != prot.size(); ++a, ++i) + for (size_t i = 0; i != protT.size(); ++a, ++i) a->setName(argNames.at(i)); - else - for (size_t i = 0; i != prot.size(); ++a, ++i) - a->setName(prot.at(i)->str()); BasicBlock* bb = BasicBlock::Create("entry", f); llengine(cenv)->builder.SetInsertPoint(bb); @@ -187,53 +186,110 @@ compileFunction(CEnv& cenv, const std::string& name, const Type* retT, const ATu * Code Generation * ***************************************************************************/ -void -ASymbol::lift(CEnv& cenv) -{ - if (!cenv.code.ref(this)) - throw Error((string("undefined symbol `") + cppstr + "'").c_str(), loc); -} - CValue ASymbol::compile(CEnv& cenv) { - return cenv.compile(*cenv.code.ref(this)); + return cenv.vals.ref(this); } void AClosure::lift(CEnv& cenv) { - AType* type = cenv.tenv.type(this); + AType* type = cenv.type(this); if (!type->concrete() || funcs.find(type)) return; - cenv.push(); - // Write function declaration string name = this->name == "" ? cenv.gensym("_fn") : this->name; - Function* f = compileFunction(cenv, name, lltype(cenv.tenv.type(at(2))), *prot()); + ATuple* protT = dynamic_cast<ATuple*>(type->at(1)); + assert(protT); + Function* f = compileFunction(cenv, name, + lltype(dynamic_cast<AType*>(type->at(type->size() - 1))), + *protT); + + cenv.push(); + Subst oldSubst = cenv.tsubst; + cenv.tsubst = Subst::compose(cenv.tsubst, *subst); // Bind argument values in CEnv vector<Value*> args; const_iterator p = prot()->begin(); - for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p) + size_t i = 0; + for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p) { + cenv.tenv.def(*p, dynamic_cast<AType*>(protT->at(i++))); cenv.vals.def(dynamic_cast<ASymbol*>(*p), &*a); + } // Write function body try { - cenv.precompile(this, f); // Define our value first for recursion + // Define value first for recursion + cenv.precompile(this, f); + funcs.push_back(make_pair(type, f)); + CValue retVal = cenv.compile(at(2)); llengine(cenv)->builder.CreateRet(LLVal(retVal)); // Finish function cenv.optimise(LLFunc(f)); - funcs.push_back(make_pair(type, f)); + } catch (Error& e) { f->eraseFromParent(); // Error reading body, remove function + cenv.pop(); throw e; } - + cenv.tsubst = oldSubst; cenv.pop(); } +void +AClosure::liftPoly(CEnv& cenv, const vector<AType*>& argsT) +{ + if (type->concrete()) + return; + + throw Error("No polymorphism"); + +#if 0 + //Subst tsubst; + assert(argsT.size() == prot()->size()); + for (size_t i = 0; i < argsT.size(); ++i) { + cenv.err << " " << argsT.at(i)->str(); + //tsubst[*cenv.tenv.ref(prot()->at(i))] = argsT.at(i); + } + cenv.err << endl; +#endif +} + +CValue +AClosure::compile(CEnv& cenv) +{ + /* + cenv.err << "***********************************************" << endl; + cenv.err << cenv.type(this) << endl; + + cenv.err << "COMPILING FOR TYPE:"; + Subst tsubst; + assert(cenv.code.front().size() == prot()->size()); + for (size_t i = 0; i < cenv.code.front().size(); ++i) { + cenv.err << " (" << cenv.type(prot()->at(i))->str() + << " -> " << cenv.type(cenv.code.front().at(i).second)->str() << ")"; + tsubst[cenv.tenv.types[prot()->at(i)]] = + cenv.type(cenv.code.front().at(i).second); + } + cenv.err << endl; + + Subst subst = Subst::compose(tsubst, cenv.tsubst); + AType* concreteType = subst.apply(type); + if (!concreteType->concrete()) + throw Error("compiled function has non-concrete type", loc); + + cenv.err << "*********** CONCRETE TYPE: " << concreteType->str() << endl; + */ + + //CValue ret = funcs.find(concreteType); + //cenv.err << "VALUE FOR TYPE " << concreteType->str() << " : " << ret << endl; + //return ret; + return NULL; +} + template<typename T> T checked_cast(AST* ast) @@ -250,27 +306,23 @@ AST* maybeLookup(CEnv& cenv, AST* ast) { ASymbol* s = dynamic_cast<ASymbol*>(ast); - if (s) { - AST** val = cenv.code.ref(s); - if (val) return *val; - } + if (s) + return cenv.code.deref(s->addr); return ast; } -CValue -AClosure::compile(CEnv& cenv) -{ - return funcs.find(cenv.tenv.type(this)); -} - void ACall::lift(CEnv& cenv) { AClosure* c = dynamic_cast<AClosure*>(maybeLookup(cenv, at(0))); + vector<AType*> argsT; + // Lift arguments - for (size_t i = 1; i < size(); ++i) + for (size_t i = 1; i < size(); ++i) { at(i)->lift(cenv); + argsT.push_back(cenv.type(at(i))); + } if (!c) return; // Primitive @@ -284,15 +336,30 @@ ACall::lift(CEnv& cenv) for (size_t i = 1; i < size(); ++i) cenv.code.def(checked_cast<ASymbol*>(c->prot()->at(i-1)), at(i)); - c->lift(cenv); // Lift called closure + c->liftPoly(cenv, argsT); // Lift called closure cenv.pop(); // Restore environment } CValue ACall::compile(CEnv& cenv) { - AST* c = maybeLookup(cenv, at(0)); - Function* f = dynamic_cast<Function*>(LLVal(cenv.compile(c))); + AClosure* c = dynamic_cast<AClosure*>(maybeLookup(cenv, at(0))); + + if (!c) return NULL; // Primitive + + if (c->prot()->size() < size() - 1) + throw Error((format("too many arguments to function `%1%'") % at(0)->str()).str(), loc); + if (c->prot()->size() > size() - 1) + throw Error((format("too few arguments to function `%1%'") % at(0)->str()).str(), loc); + + AType* protT = new AType(loc, NULL); + for (size_t i = 1; i < size(); ++i) + protT->push_back(cenv.type(at(i))); + + AType* polyT = c->type; + AType* fnT = new AType(loc, cenv.penv.sym("Fn"), protT, polyT->at(2), 0); + + Function* f = (Function*)c->funcs.find(fnT); if (!f) throw Error("callee failed to compile", loc); vector<Value*> params(size() - 1); @@ -305,7 +372,7 @@ ACall::compile(CEnv& cenv) void ADefinition::lift(CEnv& cenv) { - if (cenv.code.ref(checked_cast<ASymbol*>(at(1)))) + if (cenv.code.lookup(checked_cast<ASymbol*>(at(1)))) throw Error(string("`") + at(1)->str() + "' redefined", loc); cenv.code.def((ASymbol*)at(1), at(2)); // Define first for recursion at(2)->lift(cenv); @@ -353,7 +420,7 @@ AIf::compile(CEnv& cenv) // Emit merge block (Phi node) parent->getBasicBlockList().push_back(mergeBB); llengine(cenv)->builder.SetInsertPoint(mergeBB); - PHINode* pn = llengine(cenv)->builder.CreatePHI(lltype(cenv.tenv.type(this)), "ifval"); + PHINode* pn = llengine(cenv)->builder.CreatePHI(lltype(cenv.type(this)), "ifval"); FOREACH(Branches::iterator, i, branches) pn->addIncoming(i->first, i->second); @@ -366,7 +433,7 @@ APrimitive::compile(CEnv& cenv) { Value* a = LLVal(cenv.compile(at(1))); Value* b = LLVal(cenv.compile(at(2))); - bool isInt = cenv.tenv.type(at(1))->str() == "Int"; + bool isInt = cenv.type(at(1))->str() == "Int"; const string n = dynamic_cast<ASymbol*>(at(0))->str(); // Binary arithmetic operations @@ -407,9 +474,9 @@ APrimitive::compile(CEnv& cenv) AType* AConsCall::functionType(CEnv& cenv) { - ATuple* protTypes = new ATuple(loc, cenv.tenv.type(at(1)), cenv.tenv.type(at(2)), 0); + ATuple* protTypes = new ATuple(loc, cenv.type(at(1)), cenv.type(at(2)), 0); AType* cellType = new AType(loc, - cenv.penv.sym("Pair"), cenv.tenv.type(at(1)), cenv.tenv.type(at(2)), 0); + cenv.penv.sym("Pair"), cenv.type(at(1)), cenv.type(at(2)), 0); return new AType(at(0)->loc, cenv.penv.sym("Fn"), protTypes, cellType, 0); } @@ -427,7 +494,7 @@ AConsCall::lift(CEnv& cenv) vector<const Type*> types; size_t sz = 0; for (size_t i = 1; i < size(); ++i) { - const Type* t = lltype(cenv.tenv.type(at(i))); + const Type* t = lltype(cenv.type(at(i))); types.push_back(t); sz += t->getPrimitiveSizeInBits(); } @@ -520,15 +587,16 @@ eval(CEnv& cenv, const string& name, istream& is) list< pair<SExp, AST*> > exprs; Cursor cursor(name); try { + Constraints c; while (true) { SExp exp = readExpression(cursor, is); if (exp.type == SExp::LIST && exp.list.empty()) break; result = cenv.penv.parse(exp); // Parse input - result->constrain(cenv.tenv); // Constrain types - cenv.tenv.solve(); // Solve and apply type constraints - resultType = cenv.tenv.type(result); + result->constrain(cenv.tenv, c); // Constrain types + cenv.tsubst = TEnv::unify(c); // Solve type constraints + resultType = cenv.type(result); result->lift(cenv); // Lift functions exprs.push_back(make_pair(exp, result)); } @@ -562,24 +630,27 @@ eval(CEnv& cenv, const string& name, istream& is) int repl(CEnv& cenv) { + Constraints c; while (1) { cenv.out << "() "; cenv.out.flush(); Cursor cursor("(stdin)"); try { + SExp exp = readExpression(cursor, std::cin); if (exp.type == SExp::LIST && exp.list.empty()) break; AST* body = cenv.penv.parse(exp); // Parse input - body->constrain(cenv.tenv); // Constrain types - cenv.tenv.solve(); // Solve and apply type constraints + body->constrain(cenv.tenv, c); // Constrain types + + cenv.tsubst = Subst::compose(cenv.tsubst, TEnv::unify(c)); // Solve type constraints - AType* bodyT = cenv.tenv.type(body); + AType* bodyT = cenv.type(body); if (!bodyT) throw Error("call to untyped body", cursor); body->lift(cenv); - + if (lltype(bodyT)) { // Create anonymous function to insert code into Function* f = compileFunction(cenv, cenv.gensym("_repl"), lltype(bodyT), ATuple()); @@ -595,7 +666,7 @@ repl(CEnv& cenv) } else { cenv.out << "; " << cenv.compile(body); } - cenv.out << " : " << cenv.tenv.type(body) << endl; + cenv.out << " : " << cenv.type(body) << endl; } catch (Error& e) { cenv.err << e.what() << endl; } |