// Copyright 2018-2022 David Robillard <d@drobilla.net>
// Copyright 2006-2007 Chris Hamilton <chamilton@cs.dal.ca>
// SPDX-License-Identifier: GPL-2.0-or-later

#ifndef CHILBERT_DETAIL_MULTIBITVEC_HPP
#define CHILBERT_DETAIL_MULTIBITVEC_HPP

#include "chilbert/detail/BitVecIndex.hpp"
#include "chilbert/detail/BitVecIterator.hpp"
#include "chilbert/detail/BitVecMask.hpp"
#include "chilbert/detail/operations.hpp"

#include <algorithm>
#include <array>
#include <cassert>
#include <climits>
#include <cstddef>
#include <cstdint>
#include <cstring>

namespace chilbert {
namespace detail {

/// CRTP base implementing bit-vector operations over an array of "racks".
///
/// `Derived` owns the storage and must provide `rack(index)`, `data()`,
/// `num_racks()`, `size()` and `data_size()`; everything here is expressed in
/// terms of those accessors.  Bits in the last rack at or beyond `size()` are
/// padding; the whole-vector mutators (`set()`, `flip()`, shifts, rotations)
/// keep that padding cleared so that rack-wise queries such as `count()`,
/// `none()` and `operator==` are not affected by it.
template<class Derived>
class MultiBitVec
{
public:
  using Rack = uintptr_t;
  using Mask = BitVecMask<Rack>;

  using iterator       = BitVecIterator<MultiBitVec<Derived>>;
  using const_iterator = ConstBitVecIterator<MultiBitVec<Derived>>;

  /// Number of bits held by a single rack
  static constexpr size_t bits_per_rack = sizeof(Rack) * CHAR_BIT;

  /// Return the value of the bit covered by `mask`
  bool test(const Mask mask) const { return rack(mask.rack) & mask.mask; }

  /// Return the value of the `index`th bit
  bool test(const size_t index) const { return test(mask(index)); }

  /// Set all bits to one
  Derived& set()
  {
    // Skip empty vectors: there is nothing to fill, and data() is not
    // required to be a valid pointer in that case
    if (size()) {
      memset(data(), 0xFF, data_size());
      self()->truncate();
    }

    return *self();
  }

  /// Set the bit covered by `mask` to 1
  Derived& set(const Mask mask)
  {
    rack(mask.rack) |= mask.mask;
    return *self();
  }

  /// Set the `index`th bit to 1
  Derived& set(const size_t index) { return set(mask(index)); }

  /// Set the bit covered by `mask` to `value`
  Derived& set(const Mask mask, const bool value)
  {
    auto& r = rack(mask.rack);

    // Branchless conditional set: -Rack{value} is all ones or all zeros, so
    // the XOR below changes the masked bit only if it differs from `value`
    r ^= (-Rack{value} ^ r) & mask.mask;
    return *self();
  }

  /// Set the `index`th bit to `value`
  Derived& set(const size_t index, const bool value)
  {
    return set(mask(index), value);
  }

  /// Set all bits to zero
  Derived& reset()
  {
    // Same guard as set(): never hand memset() the storage of an empty vector
    if (size()) {
      memset(data(), 0, data_size());
    }

    return *self();
  }

  /// Reset the bit covered by `mask` to 0
  Derived& reset(const Mask mask)
  {
    rack(mask.rack) &= ~mask.mask;
    return *self();
  }

  /// Reset the `index`th bit to 0
  Derived& reset(const size_t index) { return reset(mask(index)); }

  /// Flip all bits (one's complement)
  Derived& flip()
  {
    for (size_t i = 0; i < num_racks(); ++i) {
      rack(i) = ~rack(i);
    }

    // Complementing whole racks turns the padding into ones, clear it again
    truncate();
    return *self();
  }

  /// Flip the value of the bit covered by `mask`
  Derived& flip(const Mask mask)
  {
    rack(mask.rack) ^= mask.mask;
    return *self();
  }

  /// Flip the value of the `index`th bit
  Derived& flip(const size_t index) { return flip(mask(index)); }

  /// Clear any bits in storage outside the valid range if necessary
  void truncate()
  {
    // `pad` is the number of unused high bits in the last rack
    if (const auto pad = num_racks() * bits_per_rack - size()) {
      rack(num_racks() - 1) &= ~Rack{0} >> pad;
    }
  }

  /// Right-rotate by `bits` positions (`bits` must not exceed `size()`)
  Derived& rotr(const size_t bits)
  {
    assert(bits <= size());
    if (bits == 0 || bits == size()) {
      return *self();
    }

    // Combine the right-shifted value with the bits that wrapped around
    Derived t1(*self());
    *self() >>= bits;
    t1 <<= (size() - bits);
    *self() |= t1;

    truncate();
    return *self();
  }

  /// Left-rotate by `bits` positions (`bits` must not exceed `size()`)
  Derived& rotl(const size_t bits)
  {
    assert(bits <= size());
    if (bits == 0 || bits == size()) {
      return *self();
    }

    // Combine the left-shifted value with the bits that wrapped around
    Derived t1(*self());
    *self() <<= bits;
    t1 >>= (size() - bits);
    *self() |= t1;

    truncate();
    return *self();
  }

  /// Return true iff all bits are zero
  bool none() const
  {
    for (size_t i = 0; i < num_racks(); ++i) {
      if (rack(i)) {
        return false;
      }
    }
    return true;
  }

  /// Return 1 + the index of the first set bit, or 0 if there are none
  size_t find_first() const
  {
    for (size_t i = 0; i < num_racks(); ++i) {
      const int j = chilbert::detail::find_first(rack(i));
      if (j) {
        return (i * bits_per_rack) + static_cast<size_t>(j);
      }
    }
    return 0;
  }

  /// Return the number of set bits
  size_t count() const
  {
    size_t c = 0;
    for (size_t i = 0; i < num_racks(); ++i) {
      c += static_cast<size_t>(pop_count(rack(i)));
    }
    return c;
  }

  /// Return a mask that covers the bit with index `i`
  ///
  /// `i == size()` is accepted by the assertion (presumably so that a
  /// one-past-the-end position can be represented -- confirm against the
  /// iterator implementation).
  Mask mask(const size_t i = 0) const
  {
    assert(i <= size());
    return Mask{i};
  }

  /// Return true iff `vec` has the same length and the same bit values
  bool operator==(const Derived& vec) const
  {
    // Compare lengths in bits, not just in racks: two vectors of different
    // sizes can occupy the same number of racks
    return (size() == vec.size() && num_racks() == vec.num_racks() &&
            (num_racks() == 0 ||
             !memcmp(data(), vec.data(), num_racks() * sizeof(Rack))));
  }

  /// Return true iff `vec` differs in length or in any bit value
  bool operator!=(const Derived& vec) const { return !(*this == vec); }

  /// Compare as unsigned integers, the highest-indexed rack being the most
  /// significant (both vectors must have the same size)
  bool operator<(const Derived& vec) const
  {
    assert(size() == vec.size());

    for (size_t ri = 0; ri < num_racks(); ++ri) {
      const size_t i = num_racks() - ri - 1;
      if (rack(i) < vec.rack(i)) {
        return true;
      }

      if (rack(i) > vec.rack(i)) {
        return false;
      }
    }
    return false;
  }

  /// Bitwise AND with `vec` over the racks that both vectors have
  Derived& operator&=(const Derived& vec)
  {
    const size_t n = std::min(num_racks(), vec.num_racks());
    for (size_t i = 0; i < n; ++i) {
      rack(i) &= vec.rack(i);
    }

    return *self();
  }

  /// Bitwise OR with `vec` over the racks that both vectors have
  Derived& operator|=(const Derived& vec)
  {
    const size_t n = std::min(num_racks(), vec.num_racks());
    for (size_t i = 0; i < n; ++i) {
      rack(i) |= vec.rack(i);
    }

    return *self();
  }

  /// Bitwise XOR with `vec` over the racks that both vectors have
  Derived& operator^=(const Derived& vec)
  {
    const size_t n = std::min(num_racks(), vec.num_racks());
    for (size_t i = 0; i < n; ++i) {
      rack(i) ^= vec.rack(i);
    }

    return *self();
  }

  /// Shift towards higher bit indices by `bits`, filling with zeros
  Derived& operator<<=(const size_t bits)
  {
    if (bits == 0) {
      return *self();
    }

    if (bits >= size()) {
      reset();
      return *self();
    }

    const Index index{bits};

    if (index.bit == 0) {
      // Simple rack-aligned shift (index.rack >= 1 here since bits != 0, so
      // the descending loop cannot wrap below zero)
      for (size_t i = num_racks() - 1; i >= index.rack; --i) {
        rack(i) = rack(i - index.rack);
      }
    } else {
      // Rack + bit offset shift
      const size_t right_shift = bits_per_rack - index.bit;
      for (size_t i = num_racks() - index.rack - 1; i > 0; --i) {
        rack(i + index.rack) =
          (rack(i) << index.bit) | (rack(i - 1) >> right_shift);
      }

      rack(index.rack) = rack(0) << index.bit;
    }

    // Zero least significant racks
    for (size_t i = 0; i < index.rack; ++i) {
      rack(i) = 0;
    }

    // Bits may have been shifted past size() into the padding, drop them
    truncate();
    return *self();
  }

  /// Shift towards lower bit indices by `bits`, filling with zeros
  Derived& operator>>=(const size_t bits)
  {
    if (bits == 0) {
      return *self();
    }

    if (bits >= size()) {
      reset();
      return *self();
    }

    const Index index{bits};

    if (index.bit == 0) {
      // Simple rack-aligned shift
      for (size_t i = 0; i < num_racks() - index.rack; ++i) {
        rack(i) = rack(i + index.rack);
      }
    } else {
      // Rack + bit offset shift
      const size_t last       = num_racks() - 1;
      const size_t left_shift = bits_per_rack - index.bit;
      for (size_t i = index.rack; i < last; ++i) {
        rack(i - index.rack) =
          (rack(i) >> index.bit) | (rack(i + 1) << left_shift);
      }

      rack(last - index.rack) = rack(last) >> index.bit;
    }

    // Zero most significant racks
    for (size_t i = num_racks() - index.rack; i < num_racks(); ++i) {
      rack(i) = 0;
    }

    return *self();
  }

  /// Return an iterator positioned at bit `i`
  auto begin(const size_t i = 0) { return iterator(*self(), i); }

  /// Return an iterator positioned at bit `size()`
  auto end() { return iterator(*self(), size()); }

  /// Return a constant iterator positioned at bit `i`
  auto begin(const size_t i = 0) const { return const_iterator(*self(), i); }

  /// Return a constant iterator positioned at bit `size()`
  auto end() const { return const_iterator(*self(), size()); }

  // Storage accessors, all forwarded to the derived class

  const Rack& rack(const size_t index) const { return self()->rack(index); }
  Rack&       rack(const size_t index) { return self()->rack(index); }
  Rack*       data() { return self()->data(); }
  const Rack* data() const { return self()->data(); }
  size_t      num_racks() const { return self()->num_racks(); }
  size_t      size() const { return self()->size(); }
  size_t      data_size() const { return self()->data_size(); }

private:
  using Index = detail::BitVecIndex<Derived>;

  /// Return this object as the derived type
  Derived*       self() { return static_cast<Derived*>(this); }
  const Derived* self() const { return static_cast<const Derived*>(this); }
};

/// Replace `value` with its Gray code, in place.
///
/// Racks are visited from the highest index down to the lowest.  Each rack is
/// converted on its own, then the low bit that the next higher rack held
/// before its conversion is XORed into this rack's top bit.
template<class Derived>
void
gray_code(MultiBitVec<Derived>& value)
{
  using Rack = typename MultiBitVec<Derived>::Rack;

  constexpr size_t top_bit = MultiBitVec<Derived>::bits_per_rack - 1;

  // Low bit of the previously visited (higher) rack, before conversion
  Rack carry = 0;

  for (size_t i = value.num_racks(); i-- > 0;) {
    auto& rack = value.rack(i);

    // Remember the low bit now, the per-rack conversion below overwrites it
    const Rack low_bit = rack & Rack{1};

    gray_code(rack);
    rack ^= (carry << top_bit);

    carry = low_bit;
  }
}

/// Replace `value` with its inverse Gray code, in place.
///
/// Racks are visited from the highest index down to the lowest.  Each rack is
/// decoded on its own, then complemented if the low bit of the previously
/// decoded (higher) rack is set.
template<class Derived>
void
gray_code_inv(MultiBitVec<Derived>& value)
{
  // Low bit of the previously decoded (higher) rack
  bool invert = false;

  for (size_t i = value.num_racks(); i-- > 0;) {
    auto& rack = value.rack(i);

    gray_code_inv(rack);
    if (invert) {
      rack = ~rack;
    }

    invert = (rack & 1U) != 0;
  }
}

} // namespace detail
} // namespace chilbert

#endif