aboutsummaryrefslogtreecommitdiffstats
path: root/src/constrain.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/constrain.cpp')
-rw-r--r--src/constrain.cpp112
1 files changed, 57 insertions, 55 deletions
diff --git a/src/constrain.cpp b/src/constrain.cpp
index 4e507c9..cb7b700 100644
--- a/src/constrain.cpp
+++ b/src/constrain.cpp
@@ -90,11 +90,11 @@ AFn::constrain(TEnv& tenv, Constraints& c) const throw(Error)
// Add internal definitions to environment frame
for (++i; i != end(); ++i) {
- const AST* exp = *i;
- const ADef* def = exp->to<const ADef*>();
- if (def) {
- const ASymbol* sym = def->list_ref(1)->as<const ASymbol*>();
- THROW_IF(defs.count(sym) != 0, def->loc,
+ const AST* exp = *i;
+ const ACall* call = exp->to<const ACall*>();
+ if (call && is_form(call, "def")) {
+ const ASymbol* sym = call->list_ref(1)->as<const ASymbol*>();
+ THROW_IF(defs.count(sym) != 0, call->loc,
(format("`%1%' defined twice") % sym->str()).str());
defs.insert(sym);
frame.push_back(make_pair(sym, (AType*)NULL));
@@ -296,38 +296,10 @@ constrain_call(TEnv& tenv, Constraints& c, const ACall* call) throw(Error)
c.constrain(tenv, call, retT);
}
-void
-ACall::constrain(TEnv& tenv, Constraints& c) const throw(Error)
-{
- const ASymbol* const sym = head()->to<const ASymbol*>();
- if (!sym) {
- constrain_call(tenv, c, this);
- return;
- }
-
- const std::string form = sym->cppstr;
- if (form == "def")
- constrain_def(tenv, c, this);
- else if (form == "def-type")
- constrain_def_type(tenv, c, this);
- else if (form == "match")
- constrain_match(tenv, c, this);
- else if (form == "if")
- constrain_if(tenv, c, this);
- else if (form == "cons" || isupper(form[0]))
- constrain_cons(tenv, c, this);
- else if (form == ".")
- constrain_dot(tenv, c, this);
- else if (form == "quote")
- constrain_quote(tenv, c, this);
- else
- constrain_call(tenv, c, this);
-}
-
-void
-APrimitive::constrain(TEnv& tenv, Constraints& c) const throw(Error)
+static void
+constrain_primitive(TEnv& tenv, Constraints& c, const ACall* call) throw(Error)
{
- const string n = head()->to<const ASymbol*>()->str();
+ const string n = call->head()->to<const ASymbol*>()->str();
enum { ARITHMETIC, BINARY, LOGICAL, COMPARISON } type;
if (n == "+" || n == "-" || n == "*" || n == "/")
type = ARITHMETIC;
@@ -338,44 +310,74 @@ APrimitive::constrain(TEnv& tenv, Constraints& c) const throw(Error)
else if (n == "=" || n == "!=" || n == ">" || n == ">=" || n == "<" || n == "<=")
type = COMPARISON;
else
- throw Error(loc, (format("unknown primitive `%1%'") % n).str());
+ throw Error(call->loc, (format("unknown primitive `%1%'") % n).str());
- ATuple::const_iterator i = begin();
+ ATuple::const_iterator i = call->begin();
- for (++i; i != end(); ++i)
+ for (++i; i != call->end(); ++i)
(*i)->constrain(tenv, c);
- i = begin();
+ i = call->begin();
const AType* var = NULL;
switch (type) {
case ARITHMETIC:
- if (list_len() < 3)
- throw Error(loc, (format("`%1%' requires at least 2 arguments") % n).str());
- for (++i; i != end(); ++i)
- c.constrain(tenv, *i, tenv.var(this));
+ if (call->list_len() < 3)
+ throw Error(call->loc, (format("`%1%' requires at least 2 arguments") % n).str());
+ for (++i; i != call->end(); ++i)
+ c.constrain(tenv, *i, tenv.var(call));
break;
case BINARY:
- if (list_len() != 3)
- throw Error(loc, (format("`%1%' requires exactly 2 arguments") % n).str());
- c.constrain(tenv, *++i, tenv.var(this));
- c.constrain(tenv, *++i, tenv.var(this));
+ if (call->list_len() != 3)
+ throw Error(call->loc, (format("`%1%' requires exactly 2 arguments") % n).str());
+ c.constrain(tenv, *++i, tenv.var(call));
+ c.constrain(tenv, *++i, tenv.var(call));
break;
case LOGICAL:
- if (list_len() != 3)
- throw Error(loc, (format("`%1%' requires exactly 2 arguments") % n).str());
- c.constrain(tenv, this, tenv.named("Bool"));
+ if (call->list_len() != 3)
+ throw Error(call->loc, (format("`%1%' requires exactly 2 arguments") % n).str());
+ c.constrain(tenv, call, tenv.named("Bool"));
c.constrain(tenv, *++i, tenv.named("Bool"));
c.constrain(tenv, *++i, tenv.named("Bool"));
break;
case COMPARISON:
- if (list_len() != 3)
- throw Error(loc, (format("`%1%' requires exactly 2 arguments") % n).str());
+ if (call->list_len() != 3)
+ throw Error(call->loc, (format("`%1%' requires exactly 2 arguments") % n).str());
var = tenv.var(*++i);
- c.constrain(tenv, this, tenv.named("Bool"));
+ c.constrain(tenv, call, tenv.named("Bool"));
c.constrain(tenv, *++i, var);
break;
default:
- throw Error(loc, (format("unknown primitive `%1%'") % n).str());
+ throw Error(call->loc, (format("unknown primitive `%1%'") % n).str());
+ }
+}
+
+void
+ACall::constrain(TEnv& tenv, Constraints& c) const throw(Error)
+{
+ const ASymbol* const sym = head()->to<const ASymbol*>();
+ if (!sym) {
+ constrain_call(tenv, c, this);
+ return;
}
+
+ const std::string form = sym->cppstr;
+ if (is_primitive(tenv.penv, this))
+ constrain_primitive(tenv, c, this);
+ else if (form == "def")
+ constrain_def(tenv, c, this);
+ else if (form == "def-type")
+ constrain_def_type(tenv, c, this);
+ else if (form == "match")
+ constrain_match(tenv, c, this);
+ else if (form == "if")
+ constrain_if(tenv, c, this);
+ else if (form == "cons" || isupper(form[0]))
+ constrain_cons(tenv, c, this);
+ else if (form == ".")
+ constrain_dot(tenv, c, this);
+ else if (form == "quote")
+ constrain_quote(tenv, c, this);
+ else
+ constrain_call(tenv, c, this);
}