/*
  Copyright 2011-2018 David Robillard <http://drobilla.net>

  Permission to use, copy, modify, and/or distribute this software for any
  purpose with or without fee is hereby granted, provided that the above
  copyright notice and this permission notice appear in all copies.

  THIS SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
  WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
  MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
  ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
  WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
  ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
  OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/

#include "model.h"

#include "iter.h"
#include "log.h"
#include "node.h"
#include "nodes.h"
#include "statement.h"
#include "world.h"

#include "zix/btree.h"

#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdlib.h>

// Index order used as the fallback when a search does not name a graph
#define DEFAULT_ORDER SPO
// Index order used as the fallback when a search names a graph
#define DEFAULT_GRAPH_ORDER GSPO

/**
   Compare two statements field by field, never looking at the graph.

   `user_data` is the index ordering (an array of field indices).  Only its
   first TUP_LEN - 1 entries are visited, and none of them may be SERD_GRAPH.
   Each pair of nodes is handed to serd_node_compare() with wildcard matching
   requested (third argument true).

   @return The first non-zero field comparison, or 0 if every visited field
   compares equal.
*/
static int
serd_triple_compare(const void* x_ptr, const void* y_ptr, void* user_data)
{
	const SerdStatement* const lhs   = (const SerdStatement*)x_ptr;
	const SerdStatement* const rhs   = (const SerdStatement*)y_ptr;
	const int* const           order = (const int*)user_data;

	// Stop at the first field that differs
	int result = 0;
	for (int i = 0; !result && i < TUP_LEN - 1; ++i) {
		const int field = order[i];
		assert(field != SERD_GRAPH);

		result = serd_node_compare(lhs->nodes[field], rhs->nodes[field], true);
	}

	return result;
}

/**
   Compare two statements with the graph as the most significant field.

   The graph nodes are compared first, without wildcard matching.  If they
   are equal, the remaining fields are compared in the order given by
   `user_data` (an array of TUP_LEN field indices), with wildcard matching,
   skipping the SERD_GRAPH entry.

   @return The first non-zero comparison, or 0 if all fields compare equal.
*/
static int
serd_quad_compare(const void* x_ptr, const void* y_ptr, void* user_data)
{
	const SerdStatement* const lhs   = (const SerdStatement*)x_ptr;
	const SerdStatement* const rhs   = (const SerdStatement*)y_ptr;
	const int* const           order = (const int*)user_data;

	// Graph first, matched exactly
	int result = serd_node_compare(
		lhs->nodes[SERD_GRAPH], rhs->nodes[SERD_GRAPH], false);

	// Then the triple fields in index order, with wildcard matching
	for (int i = 0; !result && i < TUP_LEN; ++i) {
		const int field = order[i];
		if (field == SERD_GRAPH) {
			continue; // Already handled above
		}

		result = serd_node_compare(lhs->nodes[field], rhs->nodes[field], true);
	}

	return result;
}

/**
   Return true iff `model` has an index for `order`.

   If `graphs` is true, `order` is first rewritten to the corresponding
   graph order (offset by GSPO) and `n_prefix` is incremented by one.  Both
   updates happen on every call with `graphs` set, whether or not the index
   turns out to exist.
*/
static inline bool
serd_model_has_index(SerdModel* model,
                     SerdOrder* order,
                     int*       n_prefix,
                     bool       graphs)
{
	if (graphs) {
		++*n_prefix;
		*order = (SerdOrder)(*order + GSPO);
	}

	return model->indices[*order] != NULL;
}

/**
   Return the best available index for a pattern.

   Candidate orderings are tried in three tiers: orderings that give an exact
   range (RANGE), orderings that give a range which still needs filtering
   (FILTER_RANGE), and finally the default order for the search.

   @param model Model whose indices are consulted
   @param pat Pattern in standard (S P O G) order
   @param mode Set to the (best) iteration mode for iterating over results
   @param n_prefix Set to the length of the range prefix
   (for `mode` == RANGE and `mode` == FILTER_RANGE)
   @return The index order to search
*/
static inline SerdOrder
serd_model_best_index(SerdModel*     model,
                      const SerdQuad pat,
                      SearchMode*    mode,
                      int*           n_prefix)
{
	const bool graph_search = (pat[SERD_GRAPH] != 0);

	// One hex digit per field (S, P, O): 1 if the field is set in the pattern
	const unsigned sig = ((pat[0] ? 1 : 0) * 0x100 +
	                      (pat[1] ? 1 : 0) * 0x010 +
	                      (pat[2] ? 1 : 0) * 0x001);

	SerdOrder good[2] = { (SerdOrder)-1, (SerdOrder)-1 };

#define PAT_CASE(sig, m, g0, g1, np)                                           \
	case sig:                                                                  \
		*mode     = m;                                                         \
		good[0]   = g0;                                                        \
		good[1]   = g1;                                                        \
		*n_prefix = np;                                                        \
		break

	// Good orderings that don't require filtering
	*mode     = RANGE;
	*n_prefix = 0;
	switch (sig) {
	case 0x000:
		assert(graph_search);
		*mode     = RANGE;
		*n_prefix = 1;
		return DEFAULT_GRAPH_ORDER;
	case 0x111:
		*mode = SINGLE;
		return graph_search ? DEFAULT_GRAPH_ORDER : DEFAULT_ORDER;

		PAT_CASE(0x001, RANGE, OPS, OSP, 1);
		PAT_CASE(0x010, RANGE, POS, PSO, 1);
		PAT_CASE(0x011, RANGE, OPS, POS, 2);
		PAT_CASE(0x100, RANGE, SPO, SOP, 1);
		PAT_CASE(0x101, RANGE, SOP, OSP, 2);
		PAT_CASE(0x110, RANGE, SPO, PSO, 2);
	default: break;
	}

	if (*mode == RANGE) {
		/* serd_model_has_index() bumps *n_prefix on every graph lookup, so
		   restore the triple prefix length before each attempt.  Otherwise
		   falling through to the second candidate would count the graph
		   field twice. */
		const int triple_prefix = *n_prefix;
		for (int i = 0; i < 2; ++i) {
			*n_prefix = triple_prefix;
			if (serd_model_has_index(
				    model, &good[i], n_prefix, graph_search)) {
				return good[i];
			}
		}
	}

	// Not so good orderings that require filtering, but can
	// still be constrained to a range
	switch (sig) {
		PAT_CASE(0x011, FILTER_RANGE, OSP, PSO, 1);
		PAT_CASE(0x101, FILTER_RANGE, SPO, OPS, 1);
		// SPO is always present, so 0x110 is never reached here
	default: break;
	}

	if (*mode == FILTER_RANGE) {
		// Same prefix reset as above, for the same reason
		const int triple_prefix = *n_prefix;
		for (int i = 0; i < 2; ++i) {
			*n_prefix = triple_prefix;
			if (serd_model_has_index(
				    model, &good[i], n_prefix, graph_search)) {
				return good[i];
			}
		}
	}

	// Nothing suitable: fall back to the default order for this search
	if (graph_search) {
		*mode     = FILTER_RANGE;
		*n_prefix = 1;
		return DEFAULT_GRAPH_ORDER;
	} else {
		*mode = FILTER_ALL;
		return DEFAULT_ORDER;
	}
}

/**
   Create a new model.

   The SPO index is always created, regardless of `indices`.  For each
   requested triple ordering, the matching graph ordering is also created if
   `flags` contains SERD_MODEL_GRAPHS.

   @param world World stored in the model (later used to intern nodes)
   @param indices Bit mask of orderings to index (bit i selects ordering i)
   @param flags Model flags
   @return A new model, or NULL if the model could not be allocated
*/
SerdModel*
serd_model_new(SerdWorld* world, unsigned indices, SerdModelFlags flags)
{
	SerdModel* model = (SerdModel*)calloc(1, sizeof(struct SerdModelImpl));
	if (!model) {
		return NULL;
	}

	model->world = world;

	indices |= SERD_SPO;

	for (unsigned i = 0; i < (NUM_ORDERS / 2); ++i) {
		const int* const ordering   = orderings[i];
		const int* const g_ordering = orderings[i + (NUM_ORDERS / 2)];

		// Unsigned shift so the mask test never involves a signed value
		if (indices & (1u << i)) {
			model->indices[i] =
				zix_btree_new((ZixComparator)serd_triple_compare,
				              (const void*)ordering,
				              NULL);
			if (flags & SERD_MODEL_GRAPHS) {
				model->indices[i + (NUM_ORDERS / 2)] =
					zix_btree_new((ZixComparator)serd_quad_compare,
					              (const void*)g_ordering,
					              NULL);
			}
		}
	}

	return model;
}

/**
   Release a statement that is no longer stored in the model.

   Every non-NULL node of the statement is dereferenced in the world's node
   set, then the statement itself is freed.  `statement` must not be NULL.
*/
static void
serd_model_drop_statement(SerdModel* model, SerdStatement* statement)
{
	int field = 0;
	while (field < TUP_LEN) {
		if (statement->nodes[field]) {
			serd_nodes_deref(model->world->nodes, statement->nodes[field]);
		}
		++field;
	}

	free(statement);
}

/**
   Free a model along with every statement and index it holds.

   Does nothing if `model` is NULL.  Statements are released with
   serd_model_drop_statement(), so the node references taken when they were
   added are given back to the world's node set.  This reads `model->world`,
   so the world must still be alive when the model is freed.
*/
void
serd_model_free(SerdModel* model)
{
	if (!model) {
		return;
	}

	// Free quads by walking the graph index if there is one, else SPO
	ZixBTree* const index = model->indices[model->indices[DEFAULT_GRAPH_ORDER]
	                                       ? DEFAULT_GRAPH_ORDER
	                                       : DEFAULT_ORDER];
	ZixBTreeIter* t = zix_btree_begin(index);
	for (; !zix_btree_iter_is_end(t); zix_btree_iter_increment(t)) {
		SerdStatement* const statement = (SerdStatement*)zix_btree_get(t);
		if (statement) {
			// Drops the interned node references as well as the memory
			serd_model_drop_statement(model, statement);
		}
	}
	zix_btree_iter_free(t);

	// Free indices
	for (unsigned o = 0; o < NUM_ORDERS; ++o) {
		if (model->indices[o]) {
			zix_btree_free(model->indices[o]);
		}
	}

	free(model);
}

/**
   Return the world stored in `model`.
*/
SerdWorld*
serd_model_get_world(SerdModel* model)
{
	SerdWorld* const world = model->world;

	return world;
}

/**
   Return the number of entries in the model's GSPO index, or in its SPO
   index if there is no GSPO index.
*/
size_t
serd_model_num_quads(const SerdModel* model)
{
	if (model->indices[GSPO]) {
		return zix_btree_size(model->indices[GSPO]);
	}

	return zix_btree_size(model->indices[SPO]);
}

/**
   Return an iterator over the whole model, or NULL if the model is empty.

   Iteration runs over the GSPO index when it exists and over SPO otherwise,
   in ALL mode with an all-wildcard pattern.
*/
SerdIter*
serd_model_begin(const SerdModel* model)
{
	if (serd_model_num_quads(model) == 0) {
		return NULL;
	}

	SerdQuad            wildcard = { 0, 0, 0, 0 };
	const SerdOrder     order    = model->indices[GSPO] ? GSPO : SPO;
	ZixBTreeIter* const first    = zix_btree_begin(model->indices[order]);

	return serd_iter_new(model, first, wildcard, order, ALL, 0);
}

/**
   Search the model for statements matching a pattern.

   NULL fields act as wildcards.  An all-NULL pattern is forwarded to
   serd_model_begin().  Otherwise the best index for the pattern is chosen
   and searched with a lower-bound lookup.

   @return A new iterator starting at the statement found by the lookup, or
   NULL if the lookup ran off the end of the index, yielded no statement, or
   (in RANGE and SINGLE mode) landed on a statement that does not match.
*/
SerdIter*
serd_model_find(SerdModel*      model,
                const SerdNode* s,
                const SerdNode* p,
                const SerdNode* o,
                const SerdNode* g)
{
	const SerdQuad pat = { s, p, o, g };
	if (!pat[0] && !pat[1] && !pat[2] && !pat[3]) {
		return serd_model_begin(model);
	}

	SearchMode      mode;
	int             n_prefix;
	const SerdOrder index_order =
		serd_model_best_index(model, pat, &mode, &n_prefix);

	SERD_FIND_LOG("Find " TUP_FMT "  index=%s  mode=%d  n_prefix=%d\n",
	              TUP_FMT_ARGS(pat),
	              order_names[index_order],
	              mode,
	              n_prefix);

	if (pat[0] && pat[1] && pat[2] && pat[3]) {
		mode = SINGLE; // No duplicate quads (Serd is a set)
	}

	ZixBTree* const db  = model->indices[index_order];
	ZixBTreeIter*   cur = NULL;
	zix_btree_lower_bound(db, pat, &cur);
	if (zix_btree_iter_is_end(cur)) {
		SERD_FIND_LOG("No match found, iterator at end\n");
		zix_btree_iter_free(cur);
		return NULL;
	}

	const SerdStatement* const key = (const SerdStatement*)zix_btree_get(cur);
	if (!key) {
		// Handled separately so the log below never dereferences NULL
		SERD_FIND_LOG("No match found, cursor has no statement\n");
		zix_btree_iter_free(cur);
		return NULL;
	}

	// In the filtering modes a non-matching first statement is acceptable
	if ((mode == RANGE || mode == SINGLE) &&
	    !serd_statement_matches_quad(key, pat)) {
		SERD_FIND_LOG("No match found, cursor at " TUP_FMT "\n",
		              TUP_FMT_ARGS(key->nodes));

		zix_btree_iter_free(cur);
		return NULL;
	}

	return serd_iter_new(model, cur, pat, index_order, mode, n_prefix);
}

/**
   Return the one missing field of the first statement matching a pattern.

   Exactly two of `s`, `p` and `o` must be non-NULL; the third is the field
   being asked for.  `g` is passed through to the search unchanged.

   @return The requested node of the first match, or NULL if the pattern does
   not have exactly two of S, P, O set or nothing matches.
*/
const SerdNode*
serd_model_get(SerdModel*      model,
               const SerdNode* s,
               const SerdNode* p,
               const SerdNode* o,
               const SerdNode* g)
{
	const int n_given = (s ? 1 : 0) + (p ? 1 : 0) + (o ? 1 : 0);
	if (n_given != 2) {
		return NULL;
	}

	SerdIter* const iter = serd_model_find(model, s, p, o, g);
	if (!iter) {
		return NULL;
	}

	const SerdStatement* const match = serd_iter_get(iter);
	serd_iter_free(iter);

	if (!s) {
		return serd_statement_get_subject(match);
	}

	if (!p) {
		return serd_statement_get_predicate(match);
	}

	// Two fields were given and both S and P are set, so O is the missing one
	return serd_statement_get_object(match);
}

/**
   Count the statements matching a pattern by iterating over all of them.

   NOTE(review): serd_model_find() can return NULL, so this relies on
   serd_iter_end() and serd_iter_free() accepting a NULL iterator -- confirm.
*/
uint64_t
serd_model_count(SerdModel*      model,
                 const SerdNode* s,
                 const SerdNode* p,
                 const SerdNode* o,
                 const SerdNode* g)
{
	uint64_t  count = 0;
	SerdIter* iter  = serd_model_find(model, s, p, o, g);

	while (!serd_iter_end(iter)) {
		++count;
		serd_iter_next(iter);
	}

	serd_iter_free(iter);
	return count;
}

/**
   Return true iff serd_model_find() yields an iterator for the pattern.

   The iterator (NULL or not) is passed to serd_iter_free() before returning.
*/
bool
serd_model_ask(SerdModel*      model,
               const SerdNode* s,
               const SerdNode* p,
               const SerdNode* o,
               const SerdNode* g)
{
	SerdIter* const match = serd_model_find(model, s, p, o, g);
	const bool      found = match ? true : false;

	serd_iter_free(match);
	return found;
}

/**
   Add a statement to the model.

   The nodes are interned in the world's node set and a new statement is
   inserted into every index the model has.

   @param model Model to add to
   @param s Subject (must not be NULL)
   @param p Predicate (must not be NULL)
   @param o Object (must not be NULL)
   @param g Graph (may be NULL)
   @return SERD_SUCCESS if the statement was inserted into at least one
   index, SERD_ERR_BAD_ARG if `s`, `p` or `o` is NULL, or SERD_FAILURE if
   nothing was inserted (including when the statement could not be
   allocated).
*/
SerdStatus
serd_model_add(SerdModel*      model,
               const SerdNode* s,
               const SerdNode* p,
               const SerdNode* o,
               const SerdNode* g)
{
	if (!s || !p || !o) {
		return serd_world_errorf(model->world,
		                         SERD_ERR_BAD_ARG,
		                         "attempt to add statement with NULL field\n");
	}

	SerdStatement* statement = (SerdStatement*)malloc(sizeof(SerdStatement));
	if (!statement) {
		// Bail out before interning so no node references are taken
		return serd_world_errorf(model->world,
		                         SERD_FAILURE,
		                         "failed to allocate statement\n");
	}

	statement->nodes[0] = serd_nodes_intern(model->world->nodes, s);
	statement->nodes[1] = serd_nodes_intern(model->world->nodes, p);
	statement->nodes[2] = serd_nodes_intern(model->world->nodes, o);
	statement->nodes[3] = serd_nodes_intern(model->world->nodes, g);
	SERD_WRITE_LOG("Add " TUP_FMT "\n", TUP_FMT_ARGS(statement->nodes));

	bool added = false;
	for (unsigned i = 0; i < NUM_ORDERS; ++i) {
		if (model->indices[i]) {
			if (!zix_btree_insert(model->indices[i], statement)) {
				added = true;
			} else if (i == GSPO) {
				break; // Tuple already indexed
			}
		}
	}

	++model->version;
	if (added) {
		return SERD_SUCCESS;
	}

	// No index took the statement, so give back the node references
	serd_model_drop_statement(model, statement);
	return SERD_FAILURE;
}

/**
   Add a copy of `statement` to the model.

   The statement's subject, predicate, object and graph are forwarded to
   serd_model_add(), whose status is returned.
*/
SerdStatus
serd_model_add_statement(SerdModel* model, const SerdStatement* statement)
{
	const SerdNode* const subject   = serd_statement_get_subject(statement);
	const SerdNode* const predicate = serd_statement_get_predicate(statement);
	const SerdNode* const object    = serd_statement_get_object(statement);
	const SerdNode* const graph     = serd_statement_get_graph(statement);

	return serd_model_add(model, subject, predicate, object, graph);
}

/**
   Remove the statement `iter` points at from every index of the model.

   The iterator's own cursor is handed to the removal from the index it is
   iterating over.  Afterwards its end flag is refreshed from that cursor,
   serd_iter_scan_next() is called on it, and its version is set to the
   model's newly incremented version.

   @return SERD_SUCCESS
*/
SerdStatus
serd_model_erase(SerdModel* model, SerdIter* iter)
{
	const SerdStatement* statement = serd_iter_get(iter);

	SERD_WRITE_LOG("Remove " TUP_FMT "\n", TUP_FMT_ARGS(statement->nodes));

	SerdStatement* removed = NULL;
	for (unsigned i = 0; i < NUM_ORDERS; ++i) {
		if (model->indices[i]) {
			zix_btree_remove(model->indices[i],
			                 statement,
			                 (void**)&removed,
			                 i == iter->order ? &iter->cur : NULL);
		}
	}
	iter->end = zix_btree_iter_is_end(iter->cur);
	serd_iter_scan_next(iter);

	// Only release the statement if some index actually handed it back
	if (removed) {
		serd_model_drop_statement(model, removed);
	}
	iter->version = ++model->version;

	return SERD_SUCCESS;
}