aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/constrain.cpp1
-rw-r--r--src/resp.hpp114
2 files changed, 75 insertions, 40 deletions
diff --git a/src/constrain.cpp b/src/constrain.cpp
index 0383f20..fd4e9f4 100644
--- a/src/constrain.cpp
+++ b/src/constrain.cpp
@@ -356,6 +356,7 @@ resp_constrain(TEnv& tenv, Constraints& c, const AST* ast) throw(Error)
{
switch (ast->tag()) {
case T_UNKNOWN:
+ case T_TYPE:
break;
case T_BOOL:
c.constrain(tenv, ast, tenv.named("Bool"));
diff --git a/src/resp.hpp b/src/resp.hpp
index f2f27f5..a092fb9 100644
--- a/src/resp.hpp
+++ b/src/resp.hpp
@@ -141,7 +141,8 @@ enum Tag {
T_LEXEME = 1<<5,
T_STRING = 1<<6,
T_SYMBOL = 1<<7,
- T_TUPLE = 1<<8
+ T_TUPLE = 1<<8,
+ T_TYPE = 1<<9
};
/// Garbage collector
@@ -209,7 +210,7 @@ typedef list<AST*> Code;
struct AST : public Object {
AST(Tag t, Cursor c=Cursor()) : loc(c) { this->tag(t); }
virtual ~AST() {}
- virtual bool operator==(const AST& o) const = 0;
+ bool operator==(const AST& o) const;
string str() const { ostringstream ss; ss << this; return ss.str(); }
template<typename T> T to() { return dynamic_cast<T>(this); }
template<typename T> T const to() const { return dynamic_cast<T const>(this); }
@@ -238,30 +239,23 @@ static T* tup(Cursor c, AST* ast, ...)
template<typename T>
struct ALiteral : public AST {
ALiteral(Tag tag, T v, Cursor c) : AST(tag, c), val(v) {}
- bool operator==(const AST& rhs) const {
- const ALiteral<T>* r = rhs.to<const ALiteral<T>*>();
- return (r && (val == r->val));
- }
const T val;
};
/// Lexeme (any atom in the CST, e.g. "a", "3.4", ""hello"", etc.)
struct ALexeme : public AST {
ALexeme(Cursor c, const string& s) : AST(T_LEXEME, c), cppstr(s) {}
- bool operator==(const AST& rhs) const { return this == &rhs; }
const string cppstr;
};
/// String, e.g. ""a""
struct AString : public AST {
AString(Cursor c, const string& s) : AST(T_STRING, c), cppstr(s) {}
- bool operator==(const AST& rhs) const { return this == &rhs; }
const string cppstr;
};
/// Symbol, e.g. "a"
struct ASymbol : public AST {
- bool operator==(const AST& rhs) const { return this == &rhs; }
const string cppstr;
private:
friend class PEnv;
@@ -413,16 +407,6 @@ struct ATuple : public AST {
AST*& list_ref(unsigned index) { return *iter_at(index); }
const AST* list_ref(unsigned index) const { return *iter_at(index); }
- bool operator==(const AST& rhs) const {
- const ATuple* rt = rhs.to<const ATuple*>();
- if (!rt || rt->tup_len() != tup_len()) return false;
- const_iterator l = begin();
- FOREACHP(const_iterator, r, rt)
- if (!(*(*l++) == *(*r)))
- return false;
- return true;
- }
-
const ATuple* prot() const { return list_ref(1)->as<const ATuple*>(); }
ATuple* prot() { return list_ref(1)->as<ATuple*>(); }
void set_prot(ATuple* prot) { *iter_at(1) = prot; }
@@ -452,12 +436,12 @@ list_contains(const ATuple* head, const AST* child) {
/// Type Expression, e.g. "Int", "(Fn (Int Int) Float)"
struct AType : public ATuple {
enum Kind { VAR, NAME, PRIM, EXPR, DOTS };
- AType(ASymbol* s, Kind k) : ATuple(s, NULL, s->loc), kind(k), id(0) {}
- AType(Cursor c, unsigned i) : ATuple(c), kind(VAR), id(i) {}
- AType(Cursor c, Kind k=EXPR) : ATuple(c), kind(k), id(0) {}
- AType(Cursor c, AST* ast, va_list args) : ATuple(c, ast, args), kind(EXPR), id(0) {}
- AType(AST* first, AST* rest, Cursor c) : ATuple(first, rest, c), kind(EXPR), id(0) {}
- AType(const AType& copy) : ATuple(copy), kind(copy.kind), id(copy.id) {}
+ AType(ASymbol* s, Kind k) : ATuple(s, NULL, s->loc), kind(k), id(0) { tag(T_TYPE); }
+ AType(Cursor c, unsigned i) : ATuple(c), kind(VAR), id(i) { tag(T_TYPE); }
+ AType(Cursor c, Kind k=EXPR) : ATuple(c), kind(k), id(0) { tag(T_TYPE); }
+ AType(Cursor c, AST* ast, va_list args) : ATuple(c, ast, args), kind(EXPR), id(0) { tag(T_TYPE); }
+ AType(AST* first, AST* rest, Cursor c) : ATuple(first, rest, c), kind(EXPR), id(0) { tag(T_TYPE); }
+ AType(const AType& copy) : ATuple(copy), kind(copy.kind), id(copy.id) { tag(T_TYPE); }
bool concrete() const {
switch (kind) {
@@ -474,21 +458,6 @@ struct AType : public ATuple {
}
return true;
}
- bool operator==(const AST& rhs) const {
- const AType* rt = rhs.to<const AType*>();
- if (!rt || kind != rt->kind) {
- assert(str() != rt->str());
- return false;
- } else
- switch (kind) {
- case VAR: return id == rt->id;
- case NAME: return head()->str() == rt->head()->str();
- case PRIM: return head()->str() == rt->head()->str();
- case EXPR: return ATuple::operator==(rhs);
- case DOTS: return true;
- }
- return false; // never reached
- }
Kind kind;
unsigned id;
};
@@ -530,6 +499,71 @@ struct List {
typedef List<AType, AType> TList;
+inline bool
+list_equals(const ATuple* lhs, const ATuple* rhs)
+{
+ if (!rhs || rhs->tup_len() != lhs->tup_len()) return false;
+ ATuple::const_iterator l = lhs->begin();
+ FOREACHP(ATuple::const_iterator, r, rhs)
+ if (!(*(*l++) == *(*r)))
+ return false;
+ return true;
+}
+
+template<typename T>
+inline bool
+literal_equals(const ALiteral<T>* lhs, const ALiteral<T>* rhs)
+{
+ return lhs && rhs && lhs->val == rhs->val;
+}
+
+inline bool
+AST::operator==(const AST& rhs) const
+{
+ const Tag tag = this->tag();
+ if (tag != rhs.tag())
+ return false;
+
+ switch (tag) {
+ case T_BOOL:
+ return literal_equals(this->as<const ALiteral<bool>*>(), rhs.as<const ALiteral<bool>*>());
+ case T_FLOAT:
+ return literal_equals(this->as<const ALiteral<float>*>(), rhs.as<const ALiteral<float>*>());
+ case T_INT32:
+ return literal_equals(this->as<const ALiteral<int32_t>*>(), rhs.as<const ALiteral<int32_t>*>());
+ case T_TUPLE:
+ {
+ const ATuple* me = this->as<const ATuple*>();
+ const ATuple* rt = rhs.to<const ATuple*>();
+ return list_equals(me, rt);
+ }
+ case T_TYPE:
+ {
+ const AType* me = this->as<const AType*>();
+ const AType* rt = rhs.to<const AType*>();
+ if (!rt || me->kind != rt->kind) {
+ assert(str() != rt->str());
+ return false;
+ } else
+ switch (me->kind) {
+ case AType::VAR: return me->id == rt->id;
+ case AType::NAME: return me->head()->str() == rt->head()->str();
+ case AType::PRIM: return me->head()->str() == rt->head()->str();
+ case AType::EXPR: return list_equals(me, rt);
+ case AType::DOTS: return true;
+ }
+ return false; // never reached
+ }
+
+ case T_UNKNOWN:
+ case T_LEXEME:
+ case T_STRING:
+ case T_SYMBOL:
+ return this == &rhs;
+ }
+ return false;
+}
+
/***************************************************************************
* Parser: S-Expressions (SExp) -> AST Nodes (AST) *
***************************************************************************/