// Copyright 2011-2023 David Robillard <d@drobilla.net>
// SPDX-License-Identifier: ISC

#undef NDEBUG

#include "serd/buffer.h"
#include "serd/env.h"
#include "serd/event.h"
#include "serd/memory.h"
#include "serd/node.h"
#include "serd/output_stream.h"
#include "serd/sink.h"
#include "serd/statement.h"
#include "serd/status.h"
#include "serd/string_view.h"
#include "serd/syntax.h"
#include "serd/world.h"
#include "serd/writer.h"

#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

static void
test_writer_new(void)
{
  SerdWorld*       world  = serd_world_new();
  SerdEnv*         env    = serd_env_new(serd_empty_string());
  SerdBuffer       buffer = {NULL, 0};
  SerdOutputStream output = serd_open_output_buffer(&buffer);

  assert(!serd_writer_new(world, SERD_TURTLE, 0U, env, &output, 0U));

  serd_world_free(world);
  serd_env_free(env);
}

static void
test_write_bad_event(void)
{
  SerdWorld*       world  = serd_world_new();
  SerdEnv*         env    = serd_env_new(serd_empty_string());
  SerdBuffer       buffer = {NULL, 0};
  SerdOutputStream output = serd_open_output_buffer(&buffer);

  SerdWriter* writer =
    serd_writer_new(world, SERD_TURTLE, 0U, env, &output, 1U);
  assert(writer);

  const SerdEvent event = {(SerdEventType)42};
  assert(serd_sink_write_event(serd_writer_sink(writer), &event) ==
         SERD_BAD_ARG);

  assert(!serd_close_output(&output));

  char* const out = (char*)buffer.buf;
  assert(out);
  assert(!strcmp(out, ""));
  serd_free(out);

  serd_writer_free(writer);
  serd_env_free(env);
  serd_world_free(world);
}

// Check that setting a prefix to the URI "rel" fails and produces no output
static void
test_write_bad_prefix(void)
{
  SerdWorld*       world  = serd_world_new();
  SerdEnv*         env    = serd_env_new(serd_empty_string());
  SerdBuffer       buffer = {NULL, 0};
  SerdOutputStream output = serd_open_output_buffer(&buffer);
  SerdWriter*      writer =
    serd_writer_new(world, SERD_TURTLE, 0U, env, &output, 1U);

  assert(writer);

  SerdNode* name = serd_new_string(serd_string("eg"));
  SerdNode* uri  = serd_new_uri(serd_string("rel"));

  // "rel" has no scheme; presumably that is why the writer rejects it
  assert(serd_sink_write_prefix(serd_writer_sink(writer), name, uri) ==
         SERD_BAD_ARG);

  serd_buffer_close(&buffer);

  // Check the buffer exists so a failure is reported here, not inside strcmp()
  char* const out = (char*)buffer.buf;
  assert(out);
  assert(!strcmp(out, ""));
  serd_free(out);

  serd_node_free(uri);
  serd_node_free(name);
  serd_writer_free(writer);
  serd_env_free(env);
  serd_world_free(world);
}

// Check that a literal containing triple quotes is written as a long literal
static void
test_write_long_literal(void)
{
  SerdWorld*       world  = serd_world_new();
  SerdEnv*         env    = serd_env_new(serd_empty_string());
  SerdBuffer       buffer = {NULL, 0};
  SerdOutputStream output = serd_open_output_buffer(&buffer);

  SerdWriter* writer =
    serd_writer_new(world, SERD_TURTLE, 0U, env, &output, 1U);
  assert(writer);

  SerdNode* s = serd_new_uri(serd_string("http://example.org/s"));
  SerdNode* p = serd_new_uri(serd_string("http://example.org/p"));
  SerdNode* o = serd_new_string(serd_string("hello \"\"\"world\"\"\"!"));

  assert(!serd_sink_write(serd_writer_sink(writer), 0, s, p, o, NULL));

  // Tear down the writer and output before inspecting what was written
  serd_node_free(o);
  serd_node_free(p);
  serd_node_free(s);
  serd_writer_free(writer);
  serd_close_output(&output);
  serd_env_free(env);
  serd_buffer_close(&buffer);

  // Check the buffer exists so a failure is reported here, not inside strcmp()
  char* const out = (char*)buffer.buf;
  assert(out);

  // The third quote of each embedded run of three is expected to be escaped
  static const char* const expected =
    "<http://example.org/s>\n"
    "\t<http://example.org/p> \"\"\"hello \"\"\\\"world\"\"\\\"!\"\"\" .\n";

  assert(!strcmp(out, expected));
  serd_free(out);

  serd_world_free(world);
}

// Write function that discards its input and reports every byte as written
static size_t
null_sink(const void* const buf,
          const size_t      size,
          const size_t      nmemb,
          void* const       stream)
{
  // Neither the data nor the stream handle is ever touched
  (void)stream;
  (void)buf;

  const size_t n_bytes = nmemb * size;
  return n_bytes;
}

// Check that a writer left with unterminated anonymous nodes can be finished,
// reused, and freed
static void
test_writer_cleanup(void)
{
  SerdStatus       st    = SERD_SUCCESS;
  SerdWorld*       world = serd_world_new();
  SerdEnv*         env   = serd_env_new(serd_empty_string());
  SerdOutputStream output =
    serd_open_output_stream(null_sink, NULL, NULL, NULL);

  SerdWriter* writer =
    serd_writer_new(world, SERD_TURTLE, 0U, env, &output, 1U);

  // Fail clearly here rather than when the sink is first used
  assert(writer);

  const SerdSink* sink = serd_writer_sink(writer);

  SerdNode* s = serd_new_uri(serd_string("http://example.org/s"));
  SerdNode* p = serd_new_uri(serd_string("http://example.org/p"));
  SerdNode* o = serd_new_blank(serd_string("start"));

  st = serd_sink_write(sink, SERD_ANON_O, s, p, o, NULL);
  assert(!st);

  // Write the start of several nested anonymous objects
  for (unsigned i = 0U; !st && i < 8U; ++i) {
    char buf[12] = {0};
    snprintf(buf, sizeof(buf), "b%u", i);

    SerdNode* next_o = serd_new_blank(serd_string(buf));

    // The previous object becomes the subject of the next nested statement
    st = serd_sink_write(sink, SERD_ANON_O, o, p, next_o, NULL);

    serd_node_free(o);
    o = next_o;
  }

  // A failure in the loop would otherwise be overwritten by the status below
  assert(!st);

  // Finish writing without terminating nodes
  st = serd_writer_finish(writer);
  assert(!st);

  // Set the base to an empty URI
  SerdNode* empty_uri = serd_new_uri(serd_string(""));
  st                  = serd_sink_write_base(sink, empty_uri);
  assert(!st);
  serd_node_free(empty_uri);

  // Free (which could leak if the writer doesn't clean up the stack properly)
  serd_node_free(o);
  serd_node_free(p);
  serd_node_free(s);
  serd_writer_free(writer);
  serd_env_free(env);
  serd_world_free(world);
}

static void
test_strict_write(void)
{
  const char* const path = "serd_strict_write_test.ttl";
  FILE* const       fd   = fopen(path, "wb");
  assert(fd);

  SerdWorld*        world = serd_world_new();
  SerdEnv* const    env   = serd_env_new(serd_empty_string());
  SerdOutputStream  out   = serd_open_output_stream(null_sink, NULL, NULL, fd);
  SerdWriter* const writer =
    serd_writer_new(world, SERD_TURTLE, 0U, env, &out, 1U);

  assert(writer);

  const SerdSink* const sink = serd_writer_sink(writer);

  const uint8_t bad_str[] = {0xFF, 0x90, 'h', 'i', 0};

  SerdNode* s = serd_new_uri(serd_string("http://example.org/s"));
  SerdNode* p = serd_new_uri(serd_string("http://example.org/p"));

  SerdNode* bad_lit = serd_new_string(serd_string((const char*)bad_str));
  SerdNode* bad_uri = serd_new_uri(serd_string((const char*)bad_str));

  assert(serd_sink_write(sink, 0, s, p, bad_lit, NULL) == SERD_BAD_TEXT);
  assert(serd_sink_write(sink, 0, s, p, bad_uri, NULL) == SERD_BAD_TEXT);

  serd_node_free(bad_uri);
  serd_node_free(bad_lit);
  serd_node_free(p);
  serd_node_free(s);
  serd_writer_free(writer);
  serd_env_free(env);
  serd_world_free(world);
  fclose(fd);
  remove(path);
}

// Write function that always fails: it reports that nothing was written, and
// never touches errno
static size_t
error_sink(const void* const buf,
           const size_t      size,
           const size_t      len,
           void* const       stream)
{
  // All arguments are deliberately ignored
  (void)stream;
  (void)len;
  (void)size;
  (void)buf;

  const size_t n_written = 0U;
  return n_written;
}

// Check that a failing write function makes a statement write return
// SERD_BAD_WRITE
static void
test_write_error(void)
{
  SerdWorld* const world = serd_world_new();
  SerdEnv* const   env   = serd_env_new(serd_empty_string());

  // error_sink reports zero bytes written for every call
  SerdOutputStream failing_out =
    serd_open_output_stream(error_sink, NULL, NULL, NULL);

  // The same node serves as subject, predicate, and object
  SerdNode* const node = serd_new_uri(serd_string("http://example.com/u"));

  SerdWriter* const writer = serd_writer_new(
    world, SERD_TURTLE, (SerdWriterFlags)0, env, &failing_out, 1U);

  assert(writer);

  const SerdSink* const sink = serd_writer_sink(writer);
  const SerdStatus      st   = serd_sink_write(sink, 0U, node, node, node, NULL);

  assert(st == SERD_BAD_WRITE);

  serd_writer_free(writer);
  serd_node_free(node);
  serd_env_free(env);
  serd_world_free(world);
}

// Check that nesting anonymous objects too deeply yields SERD_BAD_STACK
static void
test_writer_stack_overflow(void)
{
  SerdWorld* world = serd_world_new();
  SerdEnv*   env   = serd_env_new(serd_empty_string());

  SerdOutputStream output =
    serd_open_output_stream(null_sink, NULL, NULL, NULL);

  SerdWriter* writer =
    serd_writer_new(world, SERD_TURTLE, 0U, env, &output, 1U);

  // Fail clearly here rather than when the sink is first used
  assert(writer);

  const SerdSink* sink = serd_writer_sink(writer);

  SerdNode* const s = serd_new_uri(serd_string("http://example.org/s"));
  SerdNode* const p = serd_new_uri(serd_string("http://example.org/p"));

  SerdNode*  o  = serd_new_blank(serd_string("blank"));
  SerdStatus st = serd_sink_write(sink, SERD_ANON_O, s, p, o, NULL);
  assert(!st);

  // Repeatedly write nested anonymous objects until the writer stack overflows
  for (unsigned i = 0U; i < 512U; ++i) {
    char buf[1024] = {0};
    snprintf(buf, sizeof(buf), "b%u", i);

    SerdNode* next_o = serd_new_blank(serd_string(buf));

    // The previous object becomes the subject of the next nested statement
    st = serd_sink_write(sink, SERD_ANON_O, o, p, next_o, NULL);

    serd_node_free(o);
    o = next_o;

    // The only failure this test accepts is a stack overflow
    if (st) {
      assert(st == SERD_BAD_STACK);
      break;
    }
  }

  // The loop must have ended by overflowing, not by running out of iterations
  assert(st == SERD_BAD_STACK);

  serd_node_free(o);
  serd_node_free(p);
  serd_node_free(s);
  serd_writer_free(writer);
  serd_close_output(&output);
  serd_env_free(env);
  serd_world_free(world);
}

// Check that writing a statement with SERD_SYNTAX_EMPTY succeeds but produces
// no text
static void
test_write_empty_syntax(void)
{
  SerdWorld* world = serd_world_new();
  SerdEnv*   env   = serd_env_new(serd_empty_string());

  SerdNode* s = serd_new_uri(serd_string("http://example.org/s"));
  SerdNode* p = serd_new_uri(serd_string("http://example.org/p"));
  SerdNode* o = serd_new_curie(serd_string("eg:o"));

  SerdBuffer       buffer = {NULL, 0};
  SerdOutputStream output = serd_open_output_buffer(&buffer);

  SerdWriter* writer =
    serd_writer_new(world, SERD_SYNTAX_EMPTY, 0U, env, &output, 1U);

  assert(writer);

  assert(!serd_sink_write(serd_writer_sink(writer), 0U, s, p, o, NULL));

  // Close the output exactly once, here, before the buffer is read and freed
  assert(!serd_close_output(&output));

  // The buffer must exist, but hold only an empty string
  char* const out = (char*)buffer.buf;
  assert(out);
  assert(strlen(out) == 0);
  serd_free(out);

  serd_writer_free(writer);
  serd_node_free(o);
  serd_node_free(p);
  serd_node_free(s);
  serd_env_free(env);
  serd_world_free(world);
}

/*
  Write a statement whose object is "http://example.org/" followed by `lname`,
  with the prefix "eg" bound to "http://example.org/", and check that the
  resulting Turtle document is exactly `expected`.
*/
static void
check_pname_escape(const char* const lname, const char* const expected)
{
  SerdWorld*       world  = serd_world_new();
  SerdEnv*         env    = serd_env_new(serd_empty_string());
  SerdBuffer       buffer = {NULL, 0};
  SerdOutputStream output = serd_open_output_buffer(&buffer);

  SerdWriter* writer =
    serd_writer_new(world, SERD_TURTLE, 0U, env, &output, 1U);
  assert(writer);

  static const char* const prefix     = "http://example.org/";
  const size_t             prefix_len = strlen(prefix);
  const size_t             lname_len  = strlen(lname);

  // Without this binding the expected "eg:" names could not be produced
  assert(!serd_env_set_prefix(env, serd_string("eg"), serd_string(prefix)));

  SerdNode* s = serd_new_uri(serd_string("http://example.org/s"));
  SerdNode* p = serd_new_uri(serd_string("http://example.org/p"));

  // Build the full object URI: the prefix, then the local name and its NUL
  char* const uri = (char*)calloc(1, prefix_len + lname_len + 1);
  assert(uri);
  memcpy(uri, prefix, prefix_len);
  memcpy(uri + prefix_len, lname, lname_len + 1);

  SerdNode* node = serd_new_uri(serd_string(uri));
  assert(!serd_sink_write(serd_writer_sink(writer), 0, s, p, node, NULL));
  serd_node_free(node);

  // Tear down the writer and output before inspecting what was written
  free(uri);
  serd_node_free(p);
  serd_node_free(s);
  serd_writer_free(writer);
  serd_close_output(&output);
  serd_env_free(env);
  serd_buffer_close(&buffer);

  // Check the buffer exists so a failure is reported here, not inside strcmp()
  char* const out = (char*)buffer.buf;
  assert(out);
  assert(!strcmp(out, expected));
  serd_free(out);

  serd_world_free(world);
}

// Check how '.', ':', '~', and a two-byte UTF-8 character are written at the
// start, middle, and end of a prefixed name's local part
static void
test_write_pname_escapes(void)
{
  // Local names containing the byte pair 0xC3 0xB7 (U+00F7 in UTF-8)
  static const char first_escape[] = {(char)0xC3U, (char)0xB7U, 'y', 'z', 0};
  static const char mid_escape[]   = {'w', (char)0xC3U, (char)0xB7U, 'z', 0};
  static const char last_escape[]  = {'w', 'x', (char)0xC3U, (char)0xB7U, 0};

  typedef struct {
    const char* lname;    // Local name appended to the namespace URI
    const char* expected; // Exact document the writer should produce
  } PnameCase;

  static const PnameCase cases[] = {
    // '.' is escaped only at the start and end
    {".xyz", "eg:s\n\teg:p eg:\\.xyz .\n"},
    {"w.yz", "eg:s\n\teg:p eg:w.yz .\n"},
    {"wx.z", "eg:s\n\teg:p eg:wx.z .\n"},
    {"wxy.", "eg:s\n\teg:p eg:wxy\\. .\n"},

    // ':' is not escaped anywhere
    {":xyz", "eg:s\n\teg:p eg::xyz .\n"},
    {"w:yz", "eg:s\n\teg:p eg:w:yz .\n"},
    {"wx:z", "eg:s\n\teg:p eg:wx:z .\n"},
    {"wxy:", "eg:s\n\teg:p eg:wxy: .\n"},

    // Special characters like '~' are escaped everywhere
    {"~xyz", "eg:s\n\teg:p eg:\\~xyz .\n"},
    {"w~yz", "eg:s\n\teg:p eg:w\\~yz .\n"},
    {"wx~z", "eg:s\n\teg:p eg:wx\\~z .\n"},
    {"wxy~", "eg:s\n\teg:p eg:wxy\\~ .\n"},

    // Out of range multi-byte characters are percent-encoded everywhere
    {first_escape, "eg:s\n\teg:p eg:%C3%B7yz .\n"},
    {mid_escape, "eg:s\n\teg:p eg:w%C3%B7z .\n"},
    {last_escape, "eg:s\n\teg:p eg:wx%C3%B7 .\n"},
  };

  const size_t n_cases = sizeof(cases) / sizeof(cases[0]);
  for (size_t i = 0U; i < n_cases; ++i) {
    check_pname_escape(cases[i].lname, cases[i].expected);
  }
}

int
main(void)
{
  test_writer_new();
  test_write_bad_event();
  test_write_bad_prefix();
  test_write_long_literal();
  test_writer_cleanup();
  test_strict_write();
  test_write_error();
  test_writer_stack_overflow();
  test_write_empty_syntax();
  test_write_pname_escapes();

  return 0;
}