// Copyright 2018-2022 David Robillard <d@drobilla.net>
// SPDX-License-Identifier: GPL-2.0-or-later

#undef NDEBUG

#include "test_utils.hpp"

#include "chilbert/BoundedBitVec.hpp"
#include "chilbert/DynamicBitVec.hpp"
#include "chilbert/SmallBitVec.hpp"
#include "chilbert/StaticBitVec.hpp"
#include "chilbert/chilbert.hpp"

#if defined(__clang__)
_Pragma("clang diagnostic push")
_Pragma("clang diagnostic ignored \"-Wcovered-switch-default\"")
_Pragma("clang diagnostic ignored \"-Wdouble-promotion\"")
_Pragma("clang diagnostic ignored \"-Wreserved-id-macro\"")
_Pragma("clang diagnostic ignored \"-Wsign-conversion\"")
_Pragma("clang diagnostic ignored \"-Wzero-as-null-pointer-constant\"")
#endif

#include <gmp.h>
#include <gmpxx.h>

#if defined(__clang__)
_Pragma("clang diagnostic pop")
#endif

#include <array>
#include <cassert>
#include <climits>
#include <cstddef>
#include <cstdint>

/// Return a `D`-dimensional point within `ms` per-dimension precision
///
/// Coordinate `i` is drawn with `rand_between(ctx, 0, max_i)`, where `max_i`
/// is the largest value representable in `ms[i]` bits (saturating at
/// `UINT64_MAX` for precisions of 64 bits or more).
template<size_t D>
std::array<uint64_t, D>
make_random_point(Context& ctx, const std::array<size_t, D>& ms)
{
  constexpr size_t max_bits = sizeof(uint64_t) * CHAR_BIT;

  std::array<uint64_t, D> p{};
  for (size_t i = 0; i < D; ++i) {
    // Use a 64-bit one (`1UL` is only 32 bits on LLP64 targets), and never
    // shift by the full type width, which is undefined behaviour.
    const uint64_t max_value =
      ms[i] >= max_bits ? UINT64_MAX : (uint64_t{1} << ms[i]) - 1U;

    p[i] = rand_between(ctx, 0, max_value);
  }
  return p;
}

/// Return the squared Euclidean distance between points `a` and `b`
///
/// Each per-axis difference is taken as larger-minus-smaller, so unsigned
/// coordinate types never wrap around.
template<class T, size_t D>
T
squared_distance(const std::array<T, D>& a, const std::array<T, D>& b)
{
  T total = 0;
  for (size_t axis = 0; axis < D; ++axis) {
    const T& lhs = a[axis];
    const T& rhs = b[axis];

    T delta = rhs - lhs;
    if (lhs > rhs) {
      delta = lhs - rhs;
    }

    total += delta * delta;
  }
  return total;
}

/// Convert bit vector `vec` to a big integer
///
/// The `vec.num_racks()` racks at `vec.data()` are imported least significant
/// rack first (order -1), each rack in native byte order (endian 0), with no
/// nail bits.
template<class T>
mpz_class
to_big_int(const T& vec)
{
  using Rack = typename T::Rack;

  // Import straight into the mpz_class's own storage, which the class
  // initializes and clears itself, instead of filling a manually managed
  // mpz_t and then copying it into a new mpz_class.
  mpz_class num;
  mpz_import(num.get_mpz_t(),
             vec.num_racks(),
             -1,
             sizeof(Rack),
             0,
             0,
             vec.data());
  return num;
}

/// Convert big integer `num` to a bit vector
template<class T, size_t M>
T
from_big_int(const mpz_class& num)
{
  using Rack = typename T::Rack;

  T      vec   = make_zero_bitvec<T, M>();
  size_t count = 0;
  mpz_export(vec.data(), &count, -1, sizeof(Rack), 0, 0, num.get_mpz_t());
  assert(count <= static_cast<size_t>(vec.num_racks()));
  return vec;
}

/// Check `coords_to_index` / `index_to_coords` for `D` dimensions of `M` bits
///
/// Maps a random point to its Hilbert index stored in an `H`, checks that the
/// index maps back to the same point, then checks that an adjacent index maps
/// to a point at unit squared distance from the first.
template<class H, size_t M, size_t D>
void
test_standard(Context& ctx)
{
  static_assert(M < sizeof(typename H::Rack) * CHAR_BIT, "");

  // Generate random point and its hilbert index
  const auto pa = make_random_point<M, D>(ctx);

  H ha = make_zero_bitvec<H, D * M>();
  assert(ha.size() >= D * M);
  chilbert::coords_to_index(pa, M, D, ha);

  {
    // Ensure unmapping results in the original point
    auto pa_out = make_random_point<M, D>(ctx);
    chilbert::index_to_coords(pa_out, M, D, ha);
    assert(pa_out == pa);
  }

  // Convert hilbert indices to a big integer for manipulation/comparison
  const mpz_class ia = to_big_int(ha);

  // Pick an adjacent hilbert index.  The index has D * M bits, so when `ia`
  // is the last index `ia + 1` is out of range; step backwards instead, since
  // adjacency along the curve holds in both directions.  Values are held in
  // explicit mpz_class objects because `auto` would capture a gmpxx
  // expression template rather than a value.
  const mpz_class last = (mpz_class{1} << (D * M)) - 1;
  const mpz_class ib   = (ia == last) ? mpz_class{ia - 1} : mpz_class{ia + 1};
  const auto      hb   = from_big_int<H, D * M>(ib);

  // Unmap the adjacent hilbert index to a point
  auto pb = make_random_point<M, D>(ctx);
  chilbert::index_to_coords(pb, M, D, hb);

  // Ensure the adjacent point is 1 unit of distance away from the first
  assert(squared_distance(pa, pb) == 1);
}

/// Check `coords_to_compact_index` / `compact_index_to_coords`
///
/// Draws random per-dimension precisions `ms` with
/// `make_random_precisions<D * M, D>`, maps a random point within them to its
/// compact Hilbert index stored in a `T`, checks that the index maps back to
/// the same point, then checks that an adjacent index maps to a point at unit
/// squared distance from the first.
template<class T, size_t M, size_t D>
void
test_compact(Context& ctx)
{
  static_assert(M < sizeof(typename T::Rack) * CHAR_BIT, "");

  // Generate random point and its hilbert index
  const auto ms = make_random_precisions<D * M, D>(ctx);
  const auto pa = make_random_point<D>(ctx, ms);

  T ha = make_zero_bitvec<T, D * M>();
  assert(ha.size() >= D * M);
  chilbert::coords_to_compact_index(pa, ms.data(), D, ha);

  {
    // Ensure unmapping results in the original point
    auto pa_out = make_random_point<M, D>(ctx);
    chilbert::compact_index_to_coords(pa_out, ms.data(), D, ha);
    assert(pa_out == pa);
  }

  // Convert hilbert indices to a big integer for manipulation/comparison
  const mpz_class ia = to_big_int(ha);

  // A compact index has sum(ms) bits, so when `ia` is the last index
  // `ia + 1` is out of range; step backwards instead, since adjacency along
  // the curve holds in both directions.  Explicit mpz_class objects avoid
  // holding gmpxx expression templates in `auto`.
  size_t index_bits = 0;
  for (const auto m : ms) {
    index_bits += m;
  }

  const mpz_class last = (mpz_class{1} << index_bits) - 1;
  const mpz_class ib   = (ia == last) ? mpz_class{ia - 1} : mpz_class{ia + 1};
  const auto      hb   = from_big_int<T, D * M>(ib);

  // Unmap the adjacent hilbert index to a point
  auto pb = make_random_point<M, D>(ctx);
  chilbert::compact_index_to_coords(pb, ms.data(), D, hb);

  // Ensure the adjacent point is 1 unit of distance away from the first
  assert(squared_distance(pa, pb) == 1);
}

// Run the standard and compact round-trip/adjacency tests for every bit
// vector type, with a range of precisions (M) and dimension counts (D).
int
main()
{
  Context ctx;

  test_standard<chilbert::SmallBitVec, 4, 2>(ctx);
  test_standard<chilbert::SmallBitVec, 32, 2>(ctx);
  test_standard<chilbert::SmallBitVec, 16, 4>(ctx);
  test_standard<chilbert::SmallBitVec, 8, 8>(ctx);
  test_standard<chilbert::SmallBitVec, 4, 16>(ctx);
  test_standard<chilbert::SmallBitVec, 2, 32>(ctx);
  test_standard<chilbert::SmallBitVec, 1, 64>(ctx);

  test_standard<chilbert::DynamicBitVec, 4, 65>(ctx);
  test_standard<chilbert::DynamicBitVec, 32, 64>(ctx);
  test_standard<chilbert::DynamicBitVec, 63, 128>(ctx);

  test_standard<chilbert::StaticBitVec<4 * 2>, 4, 2>(ctx);
  test_standard<chilbert::StaticBitVec<32 * 2>, 32, 2>(ctx);
  test_standard<chilbert::StaticBitVec<16 * 4>, 16, 4>(ctx);
  test_standard<chilbert::StaticBitVec<8 * 8>, 8, 8>(ctx);
  test_standard<chilbert::StaticBitVec<4 * 16>, 4, 16>(ctx);
  test_standard<chilbert::StaticBitVec<2 * 32>, 2, 32>(ctx);
  test_standard<chilbert::StaticBitVec<1 * 64>, 1, 64>(ctx);
  test_standard<chilbert::StaticBitVec<4 * 65>, 4, 65>(ctx);
  test_standard<chilbert::StaticBitVec<32 * 64>, 32, 64>(ctx);
  test_standard<chilbert::StaticBitVec<63 * 128>, 63, 128>(ctx);

  test_standard<chilbert::BoundedBitVec<4 * 2>, 4, 2>(ctx);
  test_standard<chilbert::BoundedBitVec<32 * 2>, 32, 2>(ctx);
  test_standard<chilbert::BoundedBitVec<16 * 4>, 16, 4>(ctx);
  test_standard<chilbert::BoundedBitVec<8 * 8>, 8, 8>(ctx);
  test_standard<chilbert::BoundedBitVec<4 * 16>, 4, 16>(ctx);
  test_standard<chilbert::BoundedBitVec<2 * 32>, 2, 32>(ctx);
  test_standard<chilbert::BoundedBitVec<1 * 64>, 1, 64>(ctx);
  test_standard<chilbert::BoundedBitVec<4 * 128>, 4, 65>(ctx);
  test_standard<chilbert::BoundedBitVec<32 * 128>, 32, 64>(ctx);
  test_standard<chilbert::BoundedBitVec<63 * 128>, 63, 128>(ctx);

  test_compact<chilbert::SmallBitVec, 4, 2>(ctx);
  test_compact<chilbert::SmallBitVec, 32, 2>(ctx);
  test_compact<chilbert::SmallBitVec, 16, 4>(ctx);
  test_compact<chilbert::SmallBitVec, 8, 8>(ctx);
  test_compact<chilbert::SmallBitVec, 4, 16>(ctx);
  test_compact<chilbert::SmallBitVec, 2, 32>(ctx);
  test_compact<chilbert::SmallBitVec, 1, 64>(ctx);

  test_compact<chilbert::DynamicBitVec, 4, 65>(ctx);
  test_compact<chilbert::DynamicBitVec, 32, 64>(ctx);
  test_compact<chilbert::DynamicBitVec, 63, 128>(ctx);

  test_compact<chilbert::StaticBitVec<4 * 2>, 4, 2>(ctx);
  test_compact<chilbert::StaticBitVec<32 * 2>, 32, 2>(ctx);
  test_compact<chilbert::StaticBitVec<16 * 4>, 16, 4>(ctx);
  test_compact<chilbert::StaticBitVec<8 * 8>, 8, 8>(ctx);
  test_compact<chilbert::StaticBitVec<4 * 16>, 4, 16>(ctx);
  test_compact<chilbert::StaticBitVec<2 * 32>, 2, 32>(ctx);
  test_compact<chilbert::StaticBitVec<1 * 64>, 1, 64>(ctx);
  test_compact<chilbert::StaticBitVec<4 * 65>, 4, 65>(ctx);
  test_compact<chilbert::StaticBitVec<32 * 64>, 32, 64>(ctx);
  test_compact<chilbert::StaticBitVec<63 * 128>, 63, 128>(ctx);

  // BoundedBitVec is covered by the standard tests above; exercise it with
  // the compact index too, using the same bounds.  These run last so the
  // random sequence seen by the earlier cases is unchanged.
  test_compact<chilbert::BoundedBitVec<4 * 2>, 4, 2>(ctx);
  test_compact<chilbert::BoundedBitVec<32 * 2>, 32, 2>(ctx);
  test_compact<chilbert::BoundedBitVec<16 * 4>, 16, 4>(ctx);
  test_compact<chilbert::BoundedBitVec<8 * 8>, 8, 8>(ctx);
  test_compact<chilbert::BoundedBitVec<4 * 16>, 4, 16>(ctx);
  test_compact<chilbert::BoundedBitVec<2 * 32>, 2, 32>(ctx);
  test_compact<chilbert::BoundedBitVec<1 * 64>, 1, 64>(ctx);
  test_compact<chilbert::BoundedBitVec<4 * 128>, 4, 65>(ctx);
  test_compact<chilbert::BoundedBitVec<32 * 128>, 32, 64>(ctx);
  test_compact<chilbert::BoundedBitVec<63 * 128>, 63, 128>(ctx);

  return 0;
}