aboutsummaryrefslogtreecommitdiffstats
path: root/llvm.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm.cpp')
-rw-r--r--llvm.cpp179
1 files changed, 125 insertions, 54 deletions
diff --git a/llvm.cpp b/llvm.cpp
index baaeb9e..32577f9 100644
--- a/llvm.cpp
+++ b/llvm.cpp
@@ -46,10 +46,11 @@ struct LLVMEngine {
};
static const Type*
-lltype(AType* t)
+lltype(const AType* t)
{
switch (t->kind) {
case AType::VAR:
+ throw Error((format("non-compilable type `%1%'") % t->str()).str(), t->loc);
return NULL;
case AType::PRIM:
if (t->at(0)->str() == "Bool") return Type::Int1Ty;
@@ -117,7 +118,7 @@ CValue
CEnv::compile(AST* obj)
{
CValue* v = vals.ref(obj);
- return (v) ? *v : vals.def(obj, obj->compile(*this));
+ return (v && *v) ? *v : vals.def(obj, obj->compile(*this));
}
void
@@ -138,7 +139,7 @@ CEnv::write(std::ostream& os)
template<> CValue \
ALiteral<CT>::compile(CEnv& cenv) { return (COMPILED); } \
template<> void \
-ALiteral<CT>::constrain(TEnv& tenv) const { tenv.constrain(this, tenv.named(NAME)); }
+ALiteral<CT>::constrain(TEnv& tenv, Constraints& c) { c.constrain(tenv, this, tenv.named(NAME)); }
/// Literal template instantiations
LITERAL(int32_t, "Int", ConstantInt::get(Type::Int32Ty, val, true))
@@ -146,14 +147,15 @@ LITERAL(float, "Float", ConstantFP::get(Type::FloatTy, val))
LITERAL(bool, "Bool", ConstantInt::get(Type::Int1Ty, val, false))
static Function*
-compileFunction(CEnv& cenv, const std::string& name, const Type* retT, const ATuple& prot,
+compileFunction(CEnv& cenv, const std::string& name, const Type* retT, const ATuple& protT,
const vector<string> argNames=vector<string>())
{
Function::LinkageTypes linkage = Function::ExternalLinkage;
vector<const Type*> cprot;
- for (size_t i = 0; i < prot.size(); ++i) {
- AType* at = cenv.tenv.type(prot.at(i));
+ for (size_t i = 0; i < protT.size(); ++i) {
+ AType* at = dynamic_cast<AType*>(protT.at(i));
+ if (!at) throw Error("function parameter type isn't");
if (!lltype(at)) throw Error("function parameter is untyped");
cprot.push_back(lltype(at));
}
@@ -170,11 +172,8 @@ compileFunction(CEnv& cenv, const std::string& name, const Type* retT, const ATu
// Set argument names in generated code
Function::arg_iterator a = f->arg_begin();
if (!argNames.empty())
- for (size_t i = 0; i != prot.size(); ++a, ++i)
+ for (size_t i = 0; i != protT.size(); ++a, ++i)
a->setName(argNames.at(i));
- else
- for (size_t i = 0; i != prot.size(); ++a, ++i)
- a->setName(prot.at(i)->str());
BasicBlock* bb = BasicBlock::Create("entry", f);
llengine(cenv)->builder.SetInsertPoint(bb);
@@ -187,53 +186,110 @@ compileFunction(CEnv& cenv, const std::string& name, const Type* retT, const ATu
* Code Generation *
***************************************************************************/
-void
-ASymbol::lift(CEnv& cenv)
-{
- if (!cenv.code.ref(this))
- throw Error((string("undefined symbol `") + cppstr + "'").c_str(), loc);
-}
-
CValue
ASymbol::compile(CEnv& cenv)
{
- return cenv.compile(*cenv.code.ref(this));
+ return cenv.vals.ref(this);
}
void
AClosure::lift(CEnv& cenv)
{
- AType* type = cenv.tenv.type(this);
+ AType* type = cenv.type(this);
if (!type->concrete() || funcs.find(type))
return;
- cenv.push();
-
// Write function declaration
string name = this->name == "" ? cenv.gensym("_fn") : this->name;
- Function* f = compileFunction(cenv, name, lltype(cenv.tenv.type(at(2))), *prot());
+ ATuple* protT = dynamic_cast<ATuple*>(type->at(1));
+ assert(protT);
+ Function* f = compileFunction(cenv, name,
+ lltype(dynamic_cast<AType*>(type->at(type->size() - 1))),
+ *protT);
+
+ cenv.push();
+ Subst oldSubst = cenv.tsubst;
+ cenv.tsubst = Subst::compose(cenv.tsubst, *subst);
// Bind argument values in CEnv
vector<Value*> args;
const_iterator p = prot()->begin();
- for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p)
+ size_t i = 0;
+ for (Function::arg_iterator a = f->arg_begin(); a != f->arg_end(); ++a, ++p) {
+ cenv.tenv.def(*p, dynamic_cast<AType*>(protT->at(i++)));
cenv.vals.def(dynamic_cast<ASymbol*>(*p), &*a);
+ }
// Write function body
try {
- cenv.precompile(this, f); // Define our value first for recursion
+ // Define value first for recursion
+ cenv.precompile(this, f);
+ funcs.push_back(make_pair(type, f));
+
CValue retVal = cenv.compile(at(2));
llengine(cenv)->builder.CreateRet(LLVal(retVal)); // Finish function
cenv.optimise(LLFunc(f));
- funcs.push_back(make_pair(type, f));
+
} catch (Error& e) {
f->eraseFromParent(); // Error reading body, remove function
+ cenv.pop();
throw e;
}
-
+ cenv.tsubst = oldSubst;
cenv.pop();
}
+void
+AClosure::liftPoly(CEnv& cenv, const vector<AType*>& argsT)
+{
+ if (type->concrete())
+ return;
+
+ throw Error("No polymorphism");
+
+#if 0
+ //Subst tsubst;
+ assert(argsT.size() == prot()->size());
+ for (size_t i = 0; i < argsT.size(); ++i) {
+ cenv.err << " " << argsT.at(i)->str();
+ //tsubst[*cenv.tenv.ref(prot()->at(i))] = argsT.at(i);
+ }
+ cenv.err << endl;
+#endif
+}
+
+CValue
+AClosure::compile(CEnv& cenv)
+{
+ /*
+ cenv.err << "***********************************************" << endl;
+ cenv.err << cenv.type(this) << endl;
+
+ cenv.err << "COMPILING FOR TYPE:";
+ Subst tsubst;
+ assert(cenv.code.front().size() == prot()->size());
+ for (size_t i = 0; i < cenv.code.front().size(); ++i) {
+ cenv.err << " (" << cenv.type(prot()->at(i))->str()
+ << " -> " << cenv.type(cenv.code.front().at(i).second)->str() << ")";
+ tsubst[cenv.tenv.types[prot()->at(i)]] =
+ cenv.type(cenv.code.front().at(i).second);
+ }
+ cenv.err << endl;
+
+ Subst subst = Subst::compose(tsubst, cenv.tsubst);
+ AType* concreteType = subst.apply(type);
+ if (!concreteType->concrete())
+ throw Error("compiled function has non-concrete type", loc);
+
+ cenv.err << "*********** CONCRETE TYPE: " << concreteType->str() << endl;
+ */
+
+ //CValue ret = funcs.find(concreteType);
+ //cenv.err << "VALUE FOR TYPE " << concreteType->str() << " : " << ret << endl;
+ //return ret;
+ return NULL;
+}
+
template<typename T>
T
checked_cast(AST* ast)
@@ -250,27 +306,23 @@ AST*
maybeLookup(CEnv& cenv, AST* ast)
{
ASymbol* s = dynamic_cast<ASymbol*>(ast);
- if (s) {
- AST** val = cenv.code.ref(s);
- if (val) return *val;
- }
+ if (s)
+ return cenv.code.deref(s->addr);
return ast;
}
-CValue
-AClosure::compile(CEnv& cenv)
-{
- return funcs.find(cenv.tenv.type(this));
-}
-
void
ACall::lift(CEnv& cenv)
{
AClosure* c = dynamic_cast<AClosure*>(maybeLookup(cenv, at(0)));
+ vector<AType*> argsT;
+
// Lift arguments
- for (size_t i = 1; i < size(); ++i)
+ for (size_t i = 1; i < size(); ++i) {
at(i)->lift(cenv);
+ argsT.push_back(cenv.type(at(i)));
+ }
if (!c) return; // Primitive
@@ -284,15 +336,30 @@ ACall::lift(CEnv& cenv)
for (size_t i = 1; i < size(); ++i)
cenv.code.def(checked_cast<ASymbol*>(c->prot()->at(i-1)), at(i));
- c->lift(cenv); // Lift called closure
+ c->liftPoly(cenv, argsT); // Lift called closure
cenv.pop(); // Restore environment
}
CValue
ACall::compile(CEnv& cenv)
{
- AST* c = maybeLookup(cenv, at(0));
- Function* f = dynamic_cast<Function*>(LLVal(cenv.compile(c)));
+ AClosure* c = dynamic_cast<AClosure*>(maybeLookup(cenv, at(0)));
+
+ if (!c) return NULL; // Primitive
+
+ if (c->prot()->size() < size() - 1)
+ throw Error((format("too many arguments to function `%1%'") % at(0)->str()).str(), loc);
+ if (c->prot()->size() > size() - 1)
+ throw Error((format("too few arguments to function `%1%'") % at(0)->str()).str(), loc);
+
+ AType* protT = new AType(loc, NULL);
+ for (size_t i = 1; i < size(); ++i)
+ protT->push_back(cenv.type(at(i)));
+
+ AType* polyT = c->type;
+ AType* fnT = new AType(loc, cenv.penv.sym("Fn"), protT, polyT->at(2), 0);
+
+ Function* f = (Function*)c->funcs.find(fnT);
if (!f) throw Error("callee failed to compile", loc);
vector<Value*> params(size() - 1);
@@ -305,7 +372,7 @@ ACall::compile(CEnv& cenv)
void
ADefinition::lift(CEnv& cenv)
{
- if (cenv.code.ref(checked_cast<ASymbol*>(at(1))))
+ if (cenv.code.lookup(checked_cast<ASymbol*>(at(1))))
throw Error(string("`") + at(1)->str() + "' redefined", loc);
cenv.code.def((ASymbol*)at(1), at(2)); // Define first for recursion
at(2)->lift(cenv);
@@ -353,7 +420,7 @@ AIf::compile(CEnv& cenv)
// Emit merge block (Phi node)
parent->getBasicBlockList().push_back(mergeBB);
llengine(cenv)->builder.SetInsertPoint(mergeBB);
- PHINode* pn = llengine(cenv)->builder.CreatePHI(lltype(cenv.tenv.type(this)), "ifval");
+ PHINode* pn = llengine(cenv)->builder.CreatePHI(lltype(cenv.type(this)), "ifval");
FOREACH(Branches::iterator, i, branches)
pn->addIncoming(i->first, i->second);
@@ -366,7 +433,7 @@ APrimitive::compile(CEnv& cenv)
{
Value* a = LLVal(cenv.compile(at(1)));
Value* b = LLVal(cenv.compile(at(2)));
- bool isInt = cenv.tenv.type(at(1))->str() == "Int";
+ bool isInt = cenv.type(at(1))->str() == "Int";
const string n = dynamic_cast<ASymbol*>(at(0))->str();
// Binary arithmetic operations
@@ -407,9 +474,9 @@ APrimitive::compile(CEnv& cenv)
AType*
AConsCall::functionType(CEnv& cenv)
{
- ATuple* protTypes = new ATuple(loc, cenv.tenv.type(at(1)), cenv.tenv.type(at(2)), 0);
+ ATuple* protTypes = new ATuple(loc, cenv.type(at(1)), cenv.type(at(2)), 0);
AType* cellType = new AType(loc,
- cenv.penv.sym("Pair"), cenv.tenv.type(at(1)), cenv.tenv.type(at(2)), 0);
+ cenv.penv.sym("Pair"), cenv.type(at(1)), cenv.type(at(2)), 0);
return new AType(at(0)->loc, cenv.penv.sym("Fn"), protTypes, cellType, 0);
}
@@ -427,7 +494,7 @@ AConsCall::lift(CEnv& cenv)
vector<const Type*> types;
size_t sz = 0;
for (size_t i = 1; i < size(); ++i) {
- const Type* t = lltype(cenv.tenv.type(at(i)));
+ const Type* t = lltype(cenv.type(at(i)));
types.push_back(t);
sz += t->getPrimitiveSizeInBits();
}
@@ -520,15 +587,16 @@ eval(CEnv& cenv, const string& name, istream& is)
list< pair<SExp, AST*> > exprs;
Cursor cursor(name);
try {
+ Constraints c;
while (true) {
SExp exp = readExpression(cursor, is);
if (exp.type == SExp::LIST && exp.list.empty())
break;
result = cenv.penv.parse(exp); // Parse input
- result->constrain(cenv.tenv); // Constrain types
- cenv.tenv.solve(); // Solve and apply type constraints
- resultType = cenv.tenv.type(result);
+ result->constrain(cenv.tenv, c); // Constrain types
+ cenv.tsubst = TEnv::unify(c); // Solve type constraints
+ resultType = cenv.type(result);
result->lift(cenv); // Lift functions
exprs.push_back(make_pair(exp, result));
}
@@ -562,24 +630,27 @@ eval(CEnv& cenv, const string& name, istream& is)
int
repl(CEnv& cenv)
{
+ Constraints c;
while (1) {
cenv.out << "() ";
cenv.out.flush();
Cursor cursor("(stdin)");
try {
+
SExp exp = readExpression(cursor, std::cin);
if (exp.type == SExp::LIST && exp.list.empty())
break;
AST* body = cenv.penv.parse(exp); // Parse input
- body->constrain(cenv.tenv); // Constrain types
- cenv.tenv.solve(); // Solve and apply type constraints
+ body->constrain(cenv.tenv, c); // Constrain types
+
+ cenv.tsubst = Subst::compose(cenv.tsubst, TEnv::unify(c)); // Solve type constraints
- AType* bodyT = cenv.tenv.type(body);
+ AType* bodyT = cenv.type(body);
if (!bodyT) throw Error("call to untyped body", cursor);
body->lift(cenv);
-
+
if (lltype(bodyT)) {
// Create anonymous function to insert code into
Function* f = compileFunction(cenv, cenv.gensym("_repl"), lltype(bodyT), ATuple());
@@ -595,7 +666,7 @@ repl(CEnv& cenv)
} else {
cenv.out << "; " << cenv.compile(body);
}
- cenv.out << " : " << cenv.tenv.type(body) << endl;
+ cenv.out << " : " << cenv.type(body) << endl;
} catch (Error& e) {
cenv.err << e.what() << endl;
}