/*
  Copyright (C) 2018 David Robillard <d@drobilla.net>

  This program is free software: you can redistribute it and/or modify it under
  the terms of the GNU General Public License as published by the Free Software
  Foundation, either version 2 of the License, or (at your option) any later
  version.

  This program is distributed in the hope that it will be useful, but WITHOUT
  ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
  FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
  details.

  You should have received a copy of the GNU General Public License along with
  this program.  If not, see <https://www.gnu.org/licenses/>.
*/

#undef NDEBUG

#include "test_utils.hpp"

#include "chilbert/BoundedBitVec.hpp"
#include "chilbert/DynamicBitVec.hpp"
#include "chilbert/SmallBitVec.hpp"
#include "chilbert/StaticBitVec.hpp"

#include <cassert>
#include <cstddef>

/**
   Check that `&` and `&=` agree, and that every bit of the result is the
   logical AND of the corresponding input bits.
*/
template <class T, size_t N>
void
test_and(Context& ctx)
{
	const T lhs = make_random_bitvec<T, N>(ctx);
	const T rhs = make_random_bitvec<T, N>(ctx);

	// Apply the compound form outside the assertion so no side effect
	// depends on assert() being enabled
	T acc = lhs;
	acc &= rhs;
	assert((lhs & rhs) == acc);

	for (size_t bit = 0; bit != N; ++bit) {
		const bool expected = lhs.test(bit) && rhs.test(bit);
		assert(acc.test(bit) == expected);
	}
}

/**
   Check that `|` and `|=` agree, and that every bit of the result is the
   logical OR of the corresponding input bits.
*/
template <class T, size_t N>
void
test_or(Context& ctx)
{
	const T lhs = make_random_bitvec<T, N>(ctx);
	const T rhs = make_random_bitvec<T, N>(ctx);

	// Apply the compound form outside the assertion so no side effect
	// depends on assert() being enabled
	T acc = lhs;
	acc |= rhs;
	assert((lhs | rhs) == acc);

	for (size_t bit = 0; bit != N; ++bit) {
		const bool expected = lhs.test(bit) || rhs.test(bit);
		assert(acc.test(bit) == expected);
	}
}

/**
   Check that `^` and `^=` agree, and that every bit of the result is set
   exactly when the corresponding input bits differ.
*/
template <class T, size_t N>
void
test_xor(Context& ctx)
{
	const T lhs = make_random_bitvec<T, N>(ctx);
	const T rhs = make_random_bitvec<T, N>(ctx);

	// Apply the compound form outside the assertion so no side effect
	// depends on assert() being enabled
	T acc = lhs;
	acc ^= rhs;
	assert((lhs ^ rhs) == acc);

	for (size_t bit = 0; bit != N; ++bit) {
		const bool expected = lhs.test(bit) != rhs.test(bit);
		assert(acc.test(bit) == expected);
	}
}

/**
   Check that `~` produces a vector whose every bit is the negation of the
   corresponding input bit.
*/
template <class T, size_t N>
void
test_not(Context& ctx)
{
	const T original   = make_random_bitvec<T, N>(ctx);
	const T complement = ~original;

	for (size_t bit = 0; bit != N; ++bit) {
		const bool expected = !original.test(bit);
		assert(complement.test(bit) == expected);
	}
}

/**
   Check flip(i) on each position of an all-zero vector.

   Flipping once must set exactly that one bit, and flipping it again must
   restore the all-zero state.
*/
template <class T, size_t N>
void
test_flip_one(Context&)
{
	T vec = make_zero_bitvec<T, N>();

	for (size_t target = 0; target != N; ++target) {
		assert(vec.none());

		vec.flip(target);
		for (size_t k = 0; k != N; ++k) {
			const bool should_be_set = (k == target);
			assert(vec.test(k) == should_be_set);
		}

		vec.flip(target);
		assert(vec.none());
	}
}

/**
   Check that flip() with no arguments inverts every bit in place.
*/
template <class T, size_t N>
void
test_flip_all(Context& ctx)
{
	const T original = make_random_bitvec<T, N>(ctx);

	T flipped = original;
	flipped.flip();

	for (size_t k = 0; k != N; ++k) {
		const bool expected = !original.test(k);
		assert(flipped.test(k) == expected);
	}
}

/**
   Check none() on the empty, full, and single-bit vectors.

   A fully set vector reports none() only when it has no bits at all.
   Clearing one bit of a full vector, or setting one bit of an empty one,
   must leave it non-empty.
*/
template <class T, size_t N>
void
test_none(Context&)
{
	T vec = make_zero_bitvec<T, N>();
	assert(vec.none());

	vec.set();
	assert(vec.none() == (N == 0));

	// The single-bit probes below need at least two bits to be meaningful
	if (N <= 1) {
		return;
	}

	const size_t mid = N / 2;

	vec.reset(mid);
	assert(!vec.none());

	vec.reset();
	vec.set(mid);
	assert(!vec.none());
}

/**
   Check set(i) and reset(i) on each position of an all-zero vector.

   Setting a bit must affect only that bit, and resetting it must restore the
   all-zero state.
*/
template <class T, size_t N>
void
test_set_reset_one(Context&)
{
	T vec = make_zero_bitvec<T, N>();

	for (size_t target = 0; target != N; ++target) {
		assert(vec.none());

		vec.set(target);
		for (size_t k = 0; k != N; ++k) {
			const bool should_be_set = (k == target);
			assert(vec.test(k) == should_be_set);
		}

		vec.reset(target);
		assert(vec.none());
	}
}

/** Check that set() with no arguments sets every bit. */
template <class T, size_t N>
void
test_set_all(Context&)
{
	T vec = make_zero_bitvec<T, N>();
	vec.set();

	for (size_t k = 0; k != N; ++k) {
		assert(vec.test(k));
	}
}

/** Check that reset() with no arguments clears a fully set vector. */
template <class T, size_t N>
void
test_reset_all(Context&)
{
	T vec = make_zero_bitvec<T, N>();
	vec.set();
	vec.reset();

	for (size_t k = 0; k != N; ++k) {
		assert(!vec.test(k));
	}
}

/**
   Check `<<` and `<<=` for every shift amount s in [0, N).

   After shifting left by s:
   - the vacated low bits [0, s) must be clear;
   - bit i + s of the result must equal bit i of the input, for every i in
     [0, N - s).

   The previous version only checked i in [s, N - s). It never examined the
   lowest s input bits, and it checked nothing at all once s >= N / 2.
*/
template <class T, size_t N>
void
test_left_shift(Context& ctx)
{
	for (size_t s = 0; s < N; ++s) {
		const T v = make_random_bitvec<T, N>(ctx);
		T       r = v;
		r <<= s;
		assert((v << s) == r);

		// Bits shifted in from below are zero
		for (size_t i = 0; i < s; ++i) {
			assert(!r.test(i));
		}

		// Every input bit that stays in range moves up by exactly s
		for (size_t i = 0; i + s < N; ++i) {
			assert(r.test(i + s) == v.test(i));
		}
	}
}

/**
   Check `>>` and `>>=` for every shift amount s in [0, N).

   After shifting right by s:
   - bit i - s of the result must equal bit i of the input, for every i in
     [s, N);
   - the vacated high bits [N - s, N) must be clear.

   The previous version only checked i in [s, N - s). It never examined the
   highest s input bits, and it checked nothing at all once s >= N / 2.
*/
template <class T, size_t N>
void
test_right_shift(Context& ctx)
{
	for (size_t s = 0; s < N; ++s) {
		const T v = make_random_bitvec<T, N>(ctx);
		T       r = v;
		r >>= s;
		assert((v >> s) == r);

		// Every input bit at or above s moves down by exactly s
		for (size_t i = s; i < N; ++i) {
			assert(r.test(i - s) == v.test(i));
		}

		// Bits shifted in from above are zero
		for (size_t i = N - s; i < N; ++i) {
			assert(!r.test(i));
		}
	}
}

/**
   Check rotl() for every rotation amount in [0, N], including a full
   rotation.

   After rotating left by `amount`, input bit `src` must appear at position
   (src + amount) mod N.
*/
template <class T, size_t N>
void
test_left_rotate(Context& ctx)
{
	const T v = make_random_bitvec<T, N>(ctx);

	for (size_t amount = 0; amount <= N; ++amount) {
		T rotated = v;
		rotated.rotl(amount);

		// Modular positions are undefined for an empty vector
		if (N == 0) {
			continue;
		}

		for (size_t src = 0; src != N; ++src) {
			const size_t dst = (src + amount) % N;
			assert(rotated.test(dst) == v.test(src));
		}
	}
}

/**
   Check rotr() for every rotation amount in [0, N], including a full
   rotation.

   After rotating right by `amount`, result bit `dst` must equal input bit
   (dst + amount) mod N.
*/
template <class T, size_t N>
void
test_right_rotate(Context& ctx)
{
	const T v = make_random_bitvec<T, N>(ctx);

	for (size_t amount = 0; amount <= N; ++amount) {
		T rotated = v;
		rotated.rotr(amount);

		// Modular positions are undefined for an empty vector
		if (N == 0) {
			continue;
		}

		for (size_t dst = 0; dst != N; ++dst) {
			const size_t src = (dst + amount) % N;
			assert(rotated.test(dst) == v.test(src));
		}
	}
}

/**
   Check find_first(), which reports 1 + the index of the lowest set bit.

   An all-zero vector must report 0. Then, for each position i, bit i is set,
   all lower bits are clear, and higher bits are random, so the result must
   be exactly i + 1.
*/
template <class T, size_t N>
void
test_find_first(Context&)
{
	T v = make_zero_bitvec<T, N>();

	// With the 1-based convention, 0 is the "no set bit" result
	assert(size_t(v.find_first()) == 0);

	for (size_t i = 0; i < N; ++i) {
		v.reset();
		v.set(i);
		// Randomize only the bits above i so the lowest set bit stays at i
		for (size_t j = i + 1; j < N; ++j) {
			v.set(j, rand() & 1);
		}
		assert(size_t(v.find_first()) == i + 1);
	}
}

/**
   Check gray_code() against the closed form v ^ (v >> 1), and check that
   gray_code_inv() undoes it.
*/
template <class T, size_t N>
void
test_gray_code(Context& ctx)
{
	const T v    = make_random_bitvec<T, N>(ctx);
	T       gray = v;
	chilbert::detail::gray_code(gray);

	if (N == 0) {
		return;
	}

	// For N == 1, v >> 1 is a shift by the full width, and the closed-form
	// comparison is skipped; the short-circuit keeps it unevaluated
	assert(N == 1 || gray == (v ^ (v >> 1)));

	T decoded = gray;
	chilbert::detail::gray_code_inv(decoded);
	assert(decoded == v);
}

/**
   Check operator< on pairs whose most significant differing bit is known.

   For each position `bit`, starting at 0:
   - `a` has `bit` set and `b` does not;
   - both have random bits below `bit`;
   - neither has any bit above `bit` set.

   So `b` must compare less than `a`, and `a` must not compare less than `b`.

   Starting at bit 0, rather than 1, means single-bit vectors are actually
   exercised.
*/
template <class T, size_t N>
void
test_comparison(Context&)
{
	T a = make_zero_bitvec<T, N>();
	T b = make_zero_bitvec<T, N>();

	for (size_t bit = 0; bit < N; ++bit) {
		chilbert::detail::set_bit(a, bit, 1);

		for (size_t i = 0; i < bit; ++i) {
			chilbert::detail::set_bit(a, i, rand() % 2 == 0);
			chilbert::detail::set_bit(b, i, rand() % 2 == 0);
		}

		assert(b < a);
		assert(!(a < b));
	}
}

/**
   Check range-based iteration over a bit vector.

   An all-zero vector must yield only clear bits. After flip(), it must yield
   only set bits.
*/
template <class T, size_t N>
void
test_iteration(Context&)
{
	T v = make_zero_bitvec<T, N>();

	size_t visited = 0;
	for (const auto bit : v) {
		assert(!bit);
		++visited;
	}
	// NOTE(review): the number of visited bits is counted but not compared
	// with N (the assertion is disabled); confirm whether iteration is meant
	// to cover exactly N bits before enabling it.
	// assert(visited == N);

	v.flip();

	visited = 0;
	for (const auto bit : v) {
		assert(bit);
		++visited;
	}
	// assert(visited == N);

	(void)visited;
}

/**
   Run every bit vector test for bit vector type `T` holding `N` bits.

   All tests share `ctx` and run in a fixed order. Assuming `ctx` carries the
   random state consumed by make_random_bitvec (TODO confirm), reordering
   these calls changes the data each test sees.
*/
template <class T, size_t N>
void
test(Context& ctx)
{
	// Bitwise operators
	test_and<T, N>(ctx);
	test_or<T, N>(ctx);
	test_xor<T, N>(ctx);
	test_not<T, N>(ctx);

	// Single-bit and whole-vector mutation
	test_flip_one<T, N>(ctx);
	test_flip_all<T, N>(ctx);
	test_none<T, N>(ctx);
	test_set_reset_one<T, N>(ctx);
	test_set_all<T, N>(ctx);
	test_reset_all<T, N>(ctx);

	// Shifts and rotations
	test_left_shift<T, N>(ctx);
	test_right_shift<T, N>(ctx);
	test_left_rotate<T, N>(ctx);
	test_right_rotate<T, N>(ctx);

	// Queries, Gray code, ordering, and iteration
	test_find_first<T, N>(ctx);
	test_gray_code<T, N>(ctx);
	test_comparison<T, N>(ctx);
	test_iteration<T, N>(ctx);
}

/**
   Run the test suite for each bit vector implementation.

   Sizes include 0 and 1, values on either side of the 32- and 64-bit
   boundaries (31/32/33, 63/64/65), and a large odd size (997).
*/
int
main()
{
	Context ctx;

	// NOTE(review): the zero-size SmallBitVec case is disabled; the reason is
	// not visible here.
	// test<chilbert::SmallBitVec, 0>(ctx);
	test<chilbert::SmallBitVec, 1>(ctx);
	test<chilbert::SmallBitVec, 31>(ctx);
	test<chilbert::SmallBitVec, 32>(ctx);
	test<chilbert::SmallBitVec, 33>(ctx);
	test<chilbert::SmallBitVec, 63>(ctx);
	test<chilbert::SmallBitVec, 64>(ctx);

	test<chilbert::DynamicBitVec, 0>(ctx);
	test<chilbert::DynamicBitVec, 1>(ctx);
	test<chilbert::DynamicBitVec, 31>(ctx);
	test<chilbert::DynamicBitVec, 32>(ctx);
	test<chilbert::DynamicBitVec, 33>(ctx);
	test<chilbert::DynamicBitVec, 63>(ctx);
	test<chilbert::DynamicBitVec, 64>(ctx);
	test<chilbert::DynamicBitVec, 65>(ctx);
	test<chilbert::DynamicBitVec, 997>(ctx);

	test<chilbert::StaticBitVec<0>, 0>(ctx);
	test<chilbert::StaticBitVec<1>, 1>(ctx);
	test<chilbert::StaticBitVec<31>, 31>(ctx);
	test<chilbert::StaticBitVec<32>, 32>(ctx);
	test<chilbert::StaticBitVec<33>, 33>(ctx);
	test<chilbert::StaticBitVec<63>, 63>(ctx);
	test<chilbert::StaticBitVec<64>, 64>(ctx);
	test<chilbert::StaticBitVec<65>, 65>(ctx);
	test<chilbert::StaticBitVec<997>, 997>(ctx);

	// In several cases below, the BoundedBitVec template argument (presumably
	// its maximum size) is larger than the runtime size N.
	test<chilbert::BoundedBitVec<0>, 0>(ctx);
	test<chilbert::BoundedBitVec<1>, 1>(ctx);
	test<chilbert::BoundedBitVec<31>, 31>(ctx);
	test<chilbert::BoundedBitVec<32>, 32>(ctx);
	test<chilbert::BoundedBitVec<64>, 33>(ctx);
	test<chilbert::BoundedBitVec<64>, 63>(ctx);
	test<chilbert::BoundedBitVec<64>, 64>(ctx);
	test<chilbert::BoundedBitVec<128>, 65>(ctx);
	test<chilbert::BoundedBitVec<997>, 997>(ctx);
	test<chilbert::BoundedBitVec<2048>, 997>(ctx);

	return 0;
}