// Copyright 2018-2022 David Robillard <d@drobilla.net>
// SPDX-License-Identifier: GPL-2.0-or-later

#undef NDEBUG

#include "test_utils.hpp"

#include "chilbert/BoundedBitVec.hpp"
#include "chilbert/DynamicBitVec.hpp"
#include "chilbert/SmallBitVec.hpp"
#include "chilbert/StaticBitVec.hpp"
#include "chilbert/detail/gray_code_rank.hpp"
#include "chilbert/detail/operations.hpp"
#include "chilbert/operators.hpp"

#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <iterator>

namespace {

template<class T, size_t Max, size_t D>
T
get_mask(const std::array<size_t, D>& ms, const size_t d, const size_t step)
{
  T      mask = make_zero_bitvec<T, D>();
  size_t b    = 0U;
  chilbert::detail::extract_mask(ms.data(), D, d, step, mask, b);

  assert(b == mask.count());

  return mask;
}

template<class T, size_t Max, size_t D>
void
test_extract_mask(Context& ctx)
{
  for (size_t d = 0; d < D; ++d) {
    const auto   ms    = make_random_precisions<Max, D>(ctx);
    const auto   max_m = *std::max_element(std::begin(ms), std::end(ms));
    const size_t step  = rand_between(ctx, 0, max_m);
    const auto   mask  = get_mask<T, D>(ms, d, step);

    for (size_t i = 0; i < D; ++i) {
      assert(mask.test(i) == (ms[(d + i) % D] > step));
    }
  }
}

template<class T, size_t Max, size_t D>
void
test_gray_code_rank(Context& ctx)
{
  for (size_t d = 0; d < D; ++d) {
    // Generate random mask
    const auto mask = make_random_bitvec<T, D>(ctx);

    // Generate two random values and their gray codes
    const auto a = make_random_bitvec<T, D>(ctx);
    const auto b = make_random_bitvec<T, D>(ctx);

    auto ga = a;
    chilbert::detail::gray_code(ga);

    auto gb = b;
    chilbert::detail::gray_code(gb);

    // Calculate gray code ranks
    auto ra = make_zero_bitvec<T, D>();
    chilbert::detail::gray_code_rank(mask, ga, D, ra);

    auto rb = make_zero_bitvec<T, D>();
    chilbert::detail::gray_code_rank(mask, gb, D, rb);

    // Ensure ranks have at most mask.count() bits
    auto max = make_zero_bitvec<T, D>();
    chilbert::detail::set_bit(max, mask.count(), 1);
    assert(ra < max);
    assert(rb < max);

    // Test fundamental property of gray code ranks
    const auto mga = ga & mask;
    const auto mgb = gb & mask;
    assert((mga < mgb) == (ra < rb));

    // Test inversion
    const auto pat     = ~mask;
    auto       ga_out  = make_zero_bitvec<T, D>();
    auto       gag_out = make_zero_bitvec<T, D>();
    chilbert::detail::gray_code_rank_inv(
      mask, pat, ra, D, mask.count(), gag_out, ga_out);
    assert((ga_out & mask) == (ga & mask));

    auto gag_check = ga_out;
    chilbert::detail::gray_code(gag_check);
    assert(gag_check == gag_out);
  }
}

template<class T, size_t Max, size_t D>
void
test(Context& ctx)
{
  test_extract_mask<T, Max, D>(ctx);
  test_gray_code_rank<T, Max, D>(ctx);
}

} // namespace

int
main()
{
  Context ctx;

  test<chilbert::SmallBitVec, 64, 1>(ctx);
  test<chilbert::SmallBitVec, 64, 31>(ctx);
  test<chilbert::SmallBitVec, 64, 32>(ctx);
  test<chilbert::SmallBitVec, 64, 33>(ctx);
  test<chilbert::SmallBitVec, 64, 60>(ctx);
  test<chilbert::SmallBitVec, 64, 64>(ctx);

  test<chilbert::DynamicBitVec, 64, 1>(ctx);
  test<chilbert::DynamicBitVec, 64, 31>(ctx);
  test<chilbert::DynamicBitVec, 64, 32>(ctx);
  test<chilbert::DynamicBitVec, 64, 33>(ctx);
  test<chilbert::DynamicBitVec, 64, 60>(ctx);
  test<chilbert::DynamicBitVec, 64, 64>(ctx);
  test<chilbert::DynamicBitVec, 96, 65>(ctx);
  test<chilbert::DynamicBitVec, 1024, 997>(ctx);

  test<chilbert::StaticBitVec<1>, 64, 1>(ctx);
  test<chilbert::StaticBitVec<31>, 64, 31>(ctx);
  test<chilbert::StaticBitVec<32>, 64, 32>(ctx);
  test<chilbert::StaticBitVec<33>, 64, 33>(ctx);
  test<chilbert::StaticBitVec<60>, 64, 60>(ctx);
  test<chilbert::StaticBitVec<64>, 64, 64>(ctx);
  test<chilbert::StaticBitVec<65>, 96, 65>(ctx);
  test<chilbert::StaticBitVec<997>, 1024, 997>(ctx);

  test<chilbert::BoundedBitVec<1>, 64, 1>(ctx);
  test<chilbert::BoundedBitVec<31>, 64, 31>(ctx);
  test<chilbert::BoundedBitVec<32>, 64, 32>(ctx);
  test<chilbert::BoundedBitVec<64>, 64, 33>(ctx);
  test<chilbert::BoundedBitVec<64>, 64, 60>(ctx);
  test<chilbert::BoundedBitVec<64>, 64, 64>(ctx);
  test<chilbert::BoundedBitVec<128>, 96, 65>(ctx);
  test<chilbert::BoundedBitVec<997>, 1024, 997>(ctx);
  test<chilbert::BoundedBitVec<2048>, 1024, 997>(ctx);

  return 0;
}