aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/c.cpp11
-rw-r--r--src/compile.cpp34
-rw-r--r--src/constrain.cpp79
-rw-r--r--src/cps.cpp12
-rw-r--r--src/llvm.cpp74
-rw-r--r--src/parse.cpp26
-rw-r--r--src/pprint.cpp9
-rw-r--r--src/repl.cpp3
-rw-r--r--src/resp.hpp53
-rw-r--r--src/unify.cpp4
-rwxr-xr-xtest.sh1
-rw-r--r--test/match.resp14
12 files changed, 267 insertions, 53 deletions
diff --git a/src/c.cpp b/src/c.cpp
index 6bd3a2d..e8f7c2a 100644
--- a/src/c.cpp
+++ b/src/c.cpp
@@ -154,12 +154,13 @@ struct CEngine : public Engine {
return varname;
}
- CVal compileTup(CEnv& cenv, const AType* type, const vector<CVal>& fields);
+ CVal compileTup(CEnv& cenv, const AType* type, CVal rtti, const vector<CVal>& fields);
CVal compileDot(CEnv& cenv, CVal tup, int32_t index);
CVal compileLiteral(CEnv& cenv, const AST* lit);
CVal compileString(CEnv& cenv, const char* str);
CVal compilePrimitive(CEnv& cenv, const APrimitive* prim);
CVal compileIf(CEnv& cenv, const AIf* aif);
+ CVal compileMatch(CEnv& cenv, const AMatch* match);
CVal compileGlobal(CEnv& cenv, const AType* type, const string& sym, CVal val);
CVal getGlobal(CEnv& cenv, const string& sym, CVal val);
@@ -186,7 +187,7 @@ resp_new_c_engine()
***************************************************************************/
CVal
-CEngine::compileTup(CEnv& cenv, const AType* type, const vector<CVal>& fields)
+CEngine::compileTup(CEnv& cenv, const AType* type, CVal rtti, const vector<CVal>& fields)
{
return NULL;
}
@@ -261,6 +262,12 @@ CEngine::compileIf(CEnv& cenv, const AIf* aif)
}
CVal
+CEngine::compileMatch(CEnv& cenv, const AMatch* match)
+{
+ return NULL;
+}
+
+CVal
CEngine::compilePrimitive(CEnv& cenv, const APrimitive* prim)
{
APrimitive::const_iterator i = prim->begin();
diff --git a/src/compile.cpp b/src/compile.cpp
index f954b8e..e2d306a 100644
--- a/src/compile.cpp
+++ b/src/compile.cpp
@@ -130,21 +130,33 @@ AIf::compile(CEnv& cenv) const throw()
CVal
ACons::compile(CEnv& cenv) const throw()
{
- const AType* type = cenv.type(this);
- vector<CVal> fields;
- for (const_iterator i = begin() + 1; i != end(); ++i)
- fields.push_back((*i)->compile(cenv));
- return cenv.engine()->compileTup(cenv, type, fields);
+ return ATuple::compile(cenv);
}
CVal
ATuple::compile(CEnv& cenv) const throw()
{
- const AType* type = cenv.type(this);
+ AType* type = tup<AType>(loc, const_cast<ASymbol*>(head()->as<const ASymbol*>()), 0);
vector<CVal> fields;
- for (const_iterator i = begin(); i != end(); ++i)
+ for (const_iterator i = begin() + 1; i != end(); ++i) {
+ type->push_back(const_cast<AType*>(cenv.type(*i)));
fields.push_back((*i)->compile(cenv));
- return cenv.engine()->compileTup(cenv, type, fields);
+ }
+ return cenv.engine()->compileTup(cenv, type, type->compile(cenv), fields);
+}
+
+CVal
+AType::compile(CEnv& cenv) const throw()
+{
+ const ASymbol* sym = head()->as<const ASymbol*>();
+ CVal* existing = cenv.vals.ref(sym);
+ if (existing) {
+ return *existing;
+ } else {
+ CVal compiled = cenv.engine()->compileString(cenv, (string("__T_") + head()->str()).c_str());
+ cenv.vals.def(sym, compiled);
+ return compiled;
+ }
}
CVal
@@ -162,3 +174,9 @@ APrimitive::compile(CEnv& cenv) const throw()
{
return cenv.engine()->compilePrimitive(cenv, this);
}
+
+CVal
+AMatch::compile(CEnv& cenv) const throw()
+{
+ return cenv.engine()->compileMatch(cenv, this);
+}
diff --git a/src/constrain.cpp b/src/constrain.cpp
index e5a8dc4..ef8e3bf 100644
--- a/src/constrain.cpp
+++ b/src/constrain.cpp
@@ -185,15 +185,84 @@ AIf::constrain(TEnv& tenv, Constraints& c) const throw(Error)
}
void
+AMatch::constrain(TEnv& tenv, Constraints& c) const throw(Error)
+{
+ THROW_IF(size() < 5, loc, "`match' requires at least 4 arguments");
+ const AST* matchee = (*(begin() + 1));
+ const AType* retT = tenv.var();
+ const AType* matcheeT = NULL;// = tup<AType>(loc, tenv.U, 0);
+ matchee->constrain(tenv, c);
+ for (const_iterator i = begin() + 2; i != end();) {
+ const AST* exp = *i++;
+ const ATuple* pattern = exp->to<const ATuple*>();
+ THROW_IF(!pattern, exp->loc, "pattern expression expected");
+ const ASymbol* name = (*pattern->begin())->to<const ASymbol*>();
+ THROW_IF(!name, (*pattern->begin())->loc, "pattern does not start with a symbol");
+
+ const AType* consT = *tenv.ref(name);
+
+ if (!matcheeT) {
+ const AType* headT = consT->head()->as<const AType*>();
+ matcheeT = tup<AType>(loc, const_cast<AType*>(headT), 0);
+ }
+
+ THROW_IF(i == end(), pattern->loc, "missing pattern body");
+ const AST* body = *i++;
+ body->constrain(tenv, c);
+ c.constrain(tenv, body, retT);
+ }
+ c.constrain(tenv, this, retT);
+ c.constrain(tenv, matchee, matcheeT);
+}
+
+void
+ADefType::constrain(TEnv& tenv, Constraints& c) const throw(Error)
+{
+ THROW_IF(size() < 3, loc, "`def-type' requires at least 2 arguments");
+ const_iterator i = begin() + 1;
+ const ATuple* prot = (*i)->to<const ATuple*>();
+ THROW_IF(!prot, (*i)->loc, "first argument of `def-type' is not a tuple");
+ const ASymbol* sym = (*prot->begin())->as<const ASymbol*>();
+ THROW_IF(!sym, (*prot->begin())->loc, "type name is not a symbol");
+ THROW_IF(tenv.ref(sym), loc, "type redefinition");
+ AType* type = tup<AType>(loc, tenv.U, 0);
+ for (const_iterator i = begin() + 2; i != end(); ++i) {
+ const ATuple* exp = (*i)->as<const ATuple*>();
+ const ASymbol* tag = (*exp->begin())->as<const ASymbol*>();
+ AType* consT = new AType(exp->loc, AType::EXPR);
+ consT->push_back(new AType(const_cast<ASymbol*>(sym), AType::NAME));
+ for (ATuple::const_iterator i = exp->begin(); i != exp->end(); ++i) {
+ const ASymbol* sym = (*i)->to<const ASymbol*>();
+ THROW_IF(!sym, (*i)->loc, "type expression element is not a symbol");
+ consT->push_back(new AType(const_cast<ASymbol*>(sym), AType::NAME));
+ }
+ type->push_back(consT);
+ tenv.def(tag, consT);
+ }
+ tenv.def(sym, type);
+}
+
+void
ACons::constrain(TEnv& tenv, Constraints& c) const throw(Error)
{
- ASymbol* sym = (*begin())->as<ASymbol*>();
- AType* type = tup<AType>(loc, new AType(sym), 0);
- for (const_iterator i = begin() + 1; i != end(); ++i) {
+ const ASymbol* sym = (*begin())->as<const ASymbol*>();
+ const AType* type = NULL;
+
+ for (const_iterator i = begin() + 1; i != end(); ++i)
(*i)->constrain(tenv, c);
- type->push_back(const_cast<AType*>(tenv.var(*i)));
- }
+ if (sym->cppstr == "Tup") {
+ AType* tupT = tup<AType>(loc, tenv.Tup, 0);
+ for (const_iterator i = begin() + 1; i != end(); ++i) {
+ tupT->push_back(const_cast<AType*>(tenv.var(*i)));
+ }
+ type = tupT;
+ } else {
+ const AType** consTRef = tenv.ref(sym);
+ THROW_IF(!consTRef, loc, (format("call to undefined constructor `%1%'") % sym->cppstr).str());
+ const AType* consT = *consTRef;
+ type = tup<AType>(loc, const_cast<AType*>(consT->head()->as<const AType*>()), 0);
+ }
c.constrain(tenv, this, type);
}
diff --git a/src/cps.cpp b/src/cps.cpp
index 831f53f..aed7c33 100644
--- a/src/cps.cpp
+++ b/src/cps.cpp
@@ -120,6 +120,18 @@ ADef::cps(TEnv& tenv, AST* cont) const
return tup<ADef>(loc, tenv.penv.sym("def"), sym(), *++i, 0);
}
+AST*
+ADefType::cps(TEnv& tenv, AST* cont) const
+{
+ return new ADefType(*this);
+}
+
+AST*
+AMatch::cps(TEnv& tenv, AST* cont) const
+{
+ return new AMatch(*this);
+}
+
/** (cps (if c t ... e)) => */
AST*
AIf::cps(TEnv& tenv, AST* cont) const
diff --git a/src/llvm.cpp b/src/llvm.cpp
index 8b4deaa..24846cc 100644
--- a/src/llvm.cpp
+++ b/src/llvm.cpp
@@ -115,6 +115,7 @@ struct LLVMEngine : public Engine {
return PointerType::get(FunctionType::get(llType(retT), cprot, false), 0);
} else if (t->kind == AType::EXPR && isupper(t->head()->str()[0])) {
vector<const Type*> ctypes;
+ ctypes.push_back(PointerType::get(Type::getInt8Ty(context), NULL)); // RTTI
for (AType::const_iterator i = t->begin() + 1; i != t->end(); ++i) {
const Type* lt = llType((*i)->to<const AType*>());
if (!lt)
@@ -191,12 +192,13 @@ struct LLVMEngine : public Engine {
return builder.CreateCall(llFunc(f), llArgs.begin(), llArgs.end());
}
- CVal compileTup(CEnv& cenv, const AType* type, const vector<CVal>& fields);
+ CVal compileTup(CEnv& cenv, const AType* type, CVal rtti, const vector<CVal>& fields);
CVal compileDot(CEnv& cenv, CVal tup, int32_t index);
CVal compileLiteral(CEnv& cenv, const AST* lit);
CVal compileString(CEnv& cenv, const char* str);
CVal compilePrimitive(CEnv& cenv, const APrimitive* prim);
CVal compileIf(CEnv& cenv, const AIf* aif);
+ CVal compileMatch(CEnv& cenv, const AMatch* match);
CVal compileGlobal(CEnv& cenv, const AType* type, const string& sym, CVal val);
CVal getGlobal(CEnv& cenv, const string& sym, CVal val);
@@ -271,10 +273,10 @@ bitsToBytes(size_t bits)
}
CVal
-LLVMEngine::compileTup(CEnv& cenv, const AType* type, const vector<CVal>& fields)
+LLVMEngine::compileTup(CEnv& cenv, const AType* type, CVal rtti, const vector<CVal>& fields)
{
// Find size of memory required
- size_t s = 0;
+ size_t s = engine->getTargetData()->getTypeSizeInBits(PointerType::get(Type::getInt8Ty(context), NULL));
assert(type->begin() != type->end());
for (AType::const_iterator i = type->begin() + 1; i != type->end(); ++i)
s += engine->getTargetData()->getTypeSizeInBits(llType((*i)->as<AType*>()));
@@ -285,7 +287,9 @@ LLVMEngine::compileTup(CEnv& cenv, const AType* type, const vector<CVal>& fields
Value* structPtr = builder.CreateBitCast(mem, llType(type), "tup");
// Set struct fields
- size_t i = 0;
+ if (rtti)
+ builder.CreateStore((Value*)rtti, builder.CreateStructGEP(structPtr, 0, "rtti"));
+ size_t i = 1;
for (vector<CVal>::const_iterator f = fields.begin(); f != fields.end(); ++f, ++i) {
builder.CreateStore(llVal(*f),
builder.CreateStructGEP(structPtr, i, (format("tup%1%") % i).str().c_str()));
@@ -297,7 +301,7 @@ LLVMEngine::compileTup(CEnv& cenv, const AType* type, const vector<CVal>& fields
CVal
LLVMEngine::compileDot(CEnv& cenv, CVal tup, int32_t index)
{
- Value* ptr = builder.CreateStructGEP(llVal(tup), index, "dotPtr");
+ Value* ptr = builder.CreateStructGEP(llVal(tup), index + 1, "dotPtr"); // +1 to skip RTTI
return builder.CreateLoad(ptr, 0, "dotVal");
}
@@ -382,7 +386,6 @@ LLVMEngine::compileIf(CEnv& cenv, const AIf* aif)
}
// Emit final else block
- engine->builder.SetInsertPoint(nextBB);
Value* elseV = llVal(aif->last()->compile(cenv));
engine->builder.CreateBr(mergeBB);
branches.push_back(make_pair(elseV, engine->builder.GetInsertBlock()));
@@ -399,6 +402,61 @@ LLVMEngine::compileIf(CEnv& cenv, const AIf* aif)
}
CVal
+LLVMEngine::compileMatch(CEnv& cenv, const AMatch* match)
+{
+ typedef vector< pair<Value*, BasicBlock*> > Branches;
+ Value* matchee = llVal((*(match->begin() + 1))->compile(cenv));
+ Value* rttiPtr = builder.CreateStructGEP(matchee, 0, "matchRTTIPtr");
+ Value* rtti = builder.CreateLoad(rttiPtr, 0, "matchRTTI");
+
+ LLVMEngine* engine = reinterpret_cast<LLVMEngine*>(cenv.engine());
+ Function* parent = engine->builder.GetInsertBlock()->getParent();
+ BasicBlock* mergeBB = BasicBlock::Create(context, "endmatch");
+ BasicBlock* nextBB = NULL;
+ Branches branches;
+
+ size_t idx = 1;
+ for (AMatch::const_iterator i = match->begin() + 2; i != match->end(); ++idx) {
+ const AST* pat = *i++;
+ const AST* body = *i++;
+ const ASymbol* sym = pat->to<const ATuple*>()->head()->as<const ASymbol*>();
+ const AType* patT = tup<AType>(Cursor(), const_cast<ASymbol*>(sym), 0);
+
+ Value* typeV = llVal(patT->compile(cenv));
+ Value* condV = engine->builder.CreateICmp(CmpInst::ICMP_EQ, rtti, typeV);
+ BasicBlock* thenBB = BasicBlock::Create(context, (format("case%1%") % ((idx+1)/2)).str());
+
+ nextBB = BasicBlock::Create(context, (format("otherwise%1%") % ((idx+1)/2)).str());
+
+ engine->builder.CreateCondBr(condV, thenBB, nextBB);
+
+ // Emit then block for this condition
+ parent->getBasicBlockList().push_back(thenBB);
+ engine->builder.SetInsertPoint(thenBB);
+ Value* thenV = llVal(body->compile(cenv));
+ engine->builder.CreateBr(mergeBB);
+ branches.push_back(make_pair(thenV, engine->builder.GetInsertBlock()));
+
+ parent->getBasicBlockList().push_back(nextBB);
+ engine->builder.SetInsertPoint(nextBB);
+ }
+
+ // Emit final else block (FIXME: n/a, what to do here?)
+ engine->builder.CreateBr(mergeBB);
+ branches.push_back(make_pair(Constant::getNullValue(llType(cenv.type(match))), engine->builder.GetInsertBlock()));
+
+ // Emit merge block (Phi node)
+ parent->getBasicBlockList().push_back(mergeBB);
+ engine->builder.SetInsertPoint(mergeBB);
+ PHINode* pn = engine->builder.CreatePHI(llType(cenv.type(match)), "mergeval");
+
+ FOREACH(Branches::iterator, i, branches)
+ pn->addIncoming(i->first, i->second);
+
+ return pn;
+}
+
+CVal
LLVMEngine::compilePrimitive(CEnv& cenv, const APrimitive* prim)
{
APrimitive::const_iterator i = prim->begin();
@@ -451,7 +509,9 @@ LLVMEngine::compileGlobal(CEnv& cenv, const AType* type, const string& sym, CVal
GlobalVariable* global = new GlobalVariable(*module, llType(type), false,
GlobalValue::ExternalLinkage, Constant::getNullValue(llType(type)), sym);
- engine->builder.CreateStore(llVal(val), global);
+ Value* valPtr = builder.CreateBitCast(llVal(val), llType(type), "globalPtr");
+
+ engine->builder.CreateStore(valPtr, global);
return global;
}
diff --git a/src/parse.cpp b/src/parse.cpp
index 7482039..28279f9 100644
--- a/src/parse.cpp
+++ b/src/parse.cpp
@@ -170,13 +170,13 @@ void
initLang(PEnv& penv, TEnv& tenv)
{
// Types
- tenv.def(penv.sym("Nothing"), new AType(penv.sym("Nothing")));
- tenv.def(penv.sym("Bool"), new AType(penv.sym("Bool")));
- tenv.def(penv.sym("Int"), new AType(penv.sym("Int")));
- tenv.def(penv.sym("Float"), new AType(penv.sym("Float")));
- tenv.def(penv.sym("String"), new AType(penv.sym("String")));
- tenv.def(penv.sym("Lexeme"), new AType(penv.sym("Lexeme")));
- tenv.def(penv.sym("Quote"), new AType(penv.sym("Quote")));
+ tenv.def(penv.sym("Nothing"), new AType(penv.sym("Nothing"), AType::PRIM));
+ tenv.def(penv.sym("Bool"), new AType(penv.sym("Bool"), AType::PRIM));
+ tenv.def(penv.sym("Int"), new AType(penv.sym("Int"), AType::PRIM));
+ tenv.def(penv.sym("Float"), new AType(penv.sym("Float"), AType::PRIM));
+ tenv.def(penv.sym("String"), new AType(penv.sym("String"), AType::PRIM));
+ tenv.def(penv.sym("Lexeme"), new AType(penv.sym("Lexeme"), AType::PRIM));
+ tenv.def(penv.sym("Quote"), new AType(penv.sym("Quote"), AType::PRIM));
// Literals
static bool trueVal = true;
@@ -188,11 +188,13 @@ initLang(PEnv& penv, TEnv& tenv)
penv.defmac("def", macDef);
// Special forms
- penv.reg(true, "fn", PEnv::Handler(parseFn));
- penv.reg(true, "quote", PEnv::Handler(parseQuote));
- penv.reg(true, "if", PEnv::Handler(parseCall<AIf>));
- penv.reg(true, ".", PEnv::Handler(parseCall<ADot>));
- penv.reg(true, "def", PEnv::Handler(parseCall<ADef>));
+ penv.reg(true, "fn", PEnv::Handler(parseFn));
+ penv.reg(true, "quote", PEnv::Handler(parseQuote));
+ penv.reg(true, "if", PEnv::Handler(parseCall<AIf>));
+ penv.reg(true, ".", PEnv::Handler(parseCall<ADot>));
+ penv.reg(true, "def", PEnv::Handler(parseCall<ADef>));
+ penv.reg(true, "def-type", PEnv::Handler(parseCall<ADefType>));
+ penv.reg(true, "match", PEnv::Handler(parseCall<AMatch>));
// Numeric primitives
penv.reg(true, "+", PEnv::Handler(parseCall<APrimitive>));
diff --git a/src/pprint.cpp b/src/pprint.cpp
index 76796d7..28df1cc 100644
--- a/src/pprint.cpp
+++ b/src/pprint.cpp
@@ -51,10 +51,11 @@ operator<<(ostream& out, const AST* ast)
const AType* type = ast->to<const AType*>();
if (type) {
switch (type->kind) {
- case AType::VAR: return out << "?" << type->id;
- case AType::PRIM: return out << type->head();
- case AType::DOTS: return out << "...";
- case AType::EXPR: break; // will catch Tuple case below
+ case AType::VAR: return out << "?" << type->id;
+ case AType::NAME: return out << type->head();
+ case AType::PRIM: return out << type->head();
+ case AType::DOTS: return out << "...";
+ case AType::EXPR: break; // will catch Tuple case below
}
}
diff --git a/src/repl.cpp b/src/repl.cpp
index 5850e66..0264c18 100644
--- a/src/repl.cpp
+++ b/src/repl.cpp
@@ -159,6 +159,9 @@ eval(CEnv& cenv, const string& name, istream& is, bool execute)
// Finish compilation
cenv.engine()->finishFunction(cenv, f, val);
+ if (cenv.args.find("-d") != cenv.args.end())
+ cenv.engine()->writeModule(cenv, cenv.out);
+
// Call and print ast
if (cenv.args.find("-S") == cenv.args.end())
callPrintCollect(cenv, f, ast, type, execute);
diff --git a/src/resp.hpp b/src/resp.hpp
index 6e4dc77..3145080 100644
--- a/src/resp.hpp
+++ b/src/resp.hpp
@@ -338,18 +338,19 @@ private:
/// Type Expression, e.g. "Int", "(Fn (Int Int) Float)"
struct AType : public ATuple {
- enum Kind { VAR, PRIM, EXPR, DOTS };
- AType(ASymbol* s) : ATuple(s->loc), kind(PRIM), id(0) { push_back(s); }
+ enum Kind { VAR, NAME, PRIM, EXPR, DOTS };
+ AType(ASymbol* s, Kind k) : ATuple(s->loc), kind(k), id(0) { push_back(s); }
AType(Cursor c, unsigned i) : ATuple(c), kind(VAR), id(i) {}
AType(Cursor c, Kind k=EXPR) : ATuple(c), kind(k), id(0) {}
AType(Cursor c, AST* ast, va_list args) : ATuple(c, ast, args), kind(EXPR), id(0) {}
AType(const AType& copy) : ATuple(copy), kind(copy.kind), id(copy.id) { }
- CVal compile(CEnv& env) const throw() { return NULL; }
+ CVal compile(CEnv& cenv) const throw();
const ATuple* prot() const { assert(kind == EXPR); return (*(begin() + 1))->to<const ATuple*>(); }
ATuple* prot() { assert(kind == EXPR); return (*(begin() + 1))->to<ATuple*>(); }
bool concrete() const {
switch (kind) {
case VAR: return false;
+ case NAME: return false;
case PRIM: return head()->str() != "Nothing";
case EXPR:
FOREACHP(const_iterator, t, this) {
@@ -363,14 +364,16 @@ struct AType : public ATuple {
}
bool operator==(const AST& rhs) const {
const AType* rt = rhs.to<const AType*>();
- if (!rt || kind != rt->kind)
+ if (!rt || kind != rt->kind) {
+ assert(str() != rt->str());
return false;
- else
+ } else
switch (kind) {
- case VAR: return id == rt->id;
- case PRIM: return head()->str() == rt->head()->str();
- case EXPR: return ATuple::operator==(rhs);
- case DOTS: return true;
+ case VAR: return id == rt->id;
+ case NAME: return head()->str() == rt->head()->str();
+ case PRIM: return head()->str() == rt->head()->str();
+ case EXPR: return ATuple::operator==(rhs);
+ case DOTS: return true;
}
return false; // never reached
}
@@ -427,6 +430,27 @@ struct ADef : public ACall {
CVal compile(CEnv& env) const throw();
};
+struct ADefType : public ACall {
+ ADefType(const ATuple* exp) : ACall(exp) {}
+ ADefType(Cursor c, AST* ast, va_list args) : ACall(c, ast, args) {}
+ const ASymbol* sym() const { return (*(begin() + 1))->as<const ASymbol*>(); }
+ void constrain(TEnv& tenv, Constraints& c) const throw(Error);
+ AST* cps(TEnv& tenv, AST* cont) const;
+ AST* lift(CEnv& cenv, Code& code) throw() { return this; }
+ AST* depoly(CEnv& cenv, Code& code) throw() { return this; }
+ CVal compile(CEnv& env) const throw() { return NULL; }
+};
+
+struct AMatch : public ACall {
+ AMatch(const ATuple* exp) : ACall(exp) {}
+ AMatch(Cursor c, AST* ast, va_list args) : ACall(c, ast, args) {}
+ void constrain(TEnv& tenv, Constraints& c) const throw(Error);
+ AST* cps(TEnv& tenv, AST* cont) const;
+ AST* lift(CEnv& cenv, Code& code) throw() { return this; }
+ AST* depoly(CEnv& cenv, Code& code) throw() { return this; }
+ CVal compile(CEnv& env) const throw();
+};
+
/// Conditional special form, e.g. "(if cond thenexp elseexp)"
struct AIf : public ACall {
AIf(const ATuple* exp) : ACall(exp) {}
@@ -597,7 +621,7 @@ struct Constraints : public list<Constraint> {
inline ostream& operator<<(ostream& out, const Constraints& c) {
for (Constraints::const_iterator i = c.begin(); i != c.end(); ++i)
- out << i->first << " : " << i->second << endl;
+ out << i->first << " <= " << i->second << endl;
return out;
}
@@ -606,8 +630,9 @@ struct TEnv : public Env<const ASymbol*, const AType*> {
TEnv(PEnv& p)
: penv(p)
, varID(1)
- , Fn(new AType(penv.sym("Fn")))
- , Tup(new AType(penv.sym("Tup")))
+ , Fn(new AType(penv.sym("Fn"), AType::PRIM))
+ , Tup(new AType(penv.sym("Tup"), AType::NAME))
+ , U(new AType(penv.sym("U"), AType::PRIM))
{
Object::pool.addRoot(Fn);
}
@@ -641,6 +666,7 @@ struct TEnv : public Env<const ASymbol*, const AType*> {
AType* Fn;
AType* Tup;
+ AType* U;
};
Subst unify(const Constraints& c);
@@ -666,13 +692,14 @@ struct Engine {
virtual void finishFunction(CEnv& cenv, CFunc f, CVal ret) = 0;
virtual void eraseFunction(CEnv& cenv, CFunc f) = 0;
- virtual CVal compileTup(CEnv& cenv, const AType* t, ValVec& f) = 0;
+ virtual CVal compileTup(CEnv& cenv, const AType* t, CVal rtti, ValVec& f) = 0;
virtual CVal compileDot(CEnv& cenv, CVal tup, int32_t index) = 0;
virtual CVal compileLiteral(CEnv& cenv, const AST* lit) = 0;
virtual CVal compileString(CEnv& cenv, const char* str) = 0;
virtual CVal compileCall(CEnv& cenv, CFunc f, const AType* fT, ValVec& args) = 0;
virtual CVal compilePrimitive(CEnv& cenv, const APrimitive* prim) = 0;
virtual CVal compileIf(CEnv& cenv, const AIf* aif) = 0;
+ virtual CVal compileMatch(CEnv& cenv, const AMatch* match) = 0;
virtual CVal compileGlobal(CEnv& cenv, const AType* t, const string& sym, CVal val) = 0;
virtual CVal getGlobal(CEnv& cenv, const string& sym, CVal val) = 0;
virtual void writeModule(CEnv& cenv, std::ostream& os) = 0;
diff --git a/src/unify.cpp b/src/unify.cpp
index 1d3af81..a4ea035 100644
--- a/src/unify.cpp
+++ b/src/unify.cpp
@@ -147,8 +147,8 @@ unify(const Constraints& constraints)
} else if (t->kind == AType::VAR && !s->contains(t)) {
return Subst::compose(unify(cp.replace(t, s)), Subst(t, s));
} else if (s->kind == AType::EXPR && t->kind == AType::EXPR) {
- AType::const_iterator si = s->begin() + 1;
- AType::const_iterator ti = t->begin() + 1;
+ AType::const_iterator si = s->begin();
+ AType::const_iterator ti = t->begin();
for (; si != s->end() && ti != t->end(); ++si, ++ti) {
AType* st = (*si)->as<AType*>();
AType* tt = (*ti)->as<AType*>();
diff --git a/test.sh b/test.sh
index 9c67dd7..53ce693 100755
--- a/test.sh
+++ b/test.sh
@@ -23,5 +23,6 @@ run './test/nest.resp' '8 : Int'
run './test/tup.resp' '5 : Int'
run './test/string.resp' '"Hello, world!" : String'
run './test/quote.resp' '(quote hello) : Quote'
+run './test/match.resp' '"Hello, rectangle!" : String'
#run './test/poly.resp' '#t : Bool'
diff --git a/test/match.resp b/test/match.resp
new file mode 100644
index 0000000..b50257a
--- /dev/null
+++ b/test/match.resp
@@ -0,0 +1,14 @@
+(def-type (Shape)
+ (Circle Float)
+ (Square Float)
+ (Rectangle Float Float))
+
+(def c1 (Circle 1.0))
+(def c2 (Circle 2.0))
+(def s1 (Square 1.0))
+(def r1 (Rectangle 1.0 1.0))
+
+(match r1
+ (Circle r) "Hello, circle!"
+ (Square w) "Hello, square!"
+ (Rectangle w h) "Hello, rectangle!")