#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#include "mem.h"
#include "wavelet.h"
#include "rle.h"

/*
 * Debug output is compiled out: every printf (...) call below this line
 * expands to nothing, so its arguments are never evaluated.
 * Uses the ISO C99 variadic form instead of the GNU-only "args..." spelling.
 */
#define printf(...)

/* When defined, coefficients are mapped through a Gray code before coding. */
#define GRAY_CODES 1

#if defined(GRAY_CODES)
/**
 *  Convert a 16 bit binary value into its reflected Gray code:
 *  each result bit is the XOR of the input bit and its upper neighbour.
 */
static inline
uint16_t binary_to_gray (uint16_t x)
{
   const uint16_t upper_neighbours = x >> 1;

   return x ^ upper_neighbours;
}

/**
 *  Convert a 16 bit Gray code back to plain binary.
 *
 *  Folds the value onto itself with shifts of 1, 2, 4 and 8 bits, so that
 *  every bit ends up as the XOR of itself and all bits above it.
 */
static inline
uint16_t gray_to_binary (uint16_t x)
{
   int shift = 1;

   while (shift < 16) {
      x ^= x >> shift;
      shift <<= 1;
   }

   return x;
}
#endif


/**
 *  Write one coefficient into the per-bitplane entropy coders.
 *
 *  The coded value is binary_to_gray(coeff) when GRAY_CODES is defined,
 *  otherwise coeff with all bits inverted for negative numbers.
 *
 *  Starting at the top bitplane, bits go to significand_bitstream[plane]
 *  until the first 1 bit has been written or plane 0 is reached. The sign
 *  bit follows in the significand stream of that same plane. All lower
 *  bits are written to insignificand_bitstream[plane].
 */
static inline
void encode_coeff (ENTROPY_CODER significand_bitstream [],
                   ENTROPY_CODER insignificand_bitstream [],
                   TYPE coeff)
{
   int sign = (coeff >> (8*sizeof(TYPE)-1)) & 1;
#if defined(GRAY_CODES)
   TYPE significance = binary_to_gray(coeff);
#else
   static TYPE mask [2] = { 0, ~0 };
   TYPE significance = coeff ^ mask[sign];
#endif
   int plane = TYPE_BITS - 1;

   /* leading bits, down to and including the first 1 bit */
   for (;;) {
      OUTPUT_BIT(&significand_bitstream[plane], (significance >> plane) & 1);

      if (((significance >> plane) & 1) || plane <= 0)
         break;

      plane--;
   }

   /* the sign shares the stream of the plane where the scan stopped */
   OUTPUT_BIT(&significand_bitstream[plane], sign);

   /* everything below that plane is refinement data */
   for (plane--; plane >= 0; plane--)
      OUTPUT_BIT(&insignificand_bitstream[plane], (significance >> plane) & 1);
}



/**
 *  Read one coefficient back from the per-bitplane entropy coders.
 *
 *  Mirrors encode_coeff(): bits are taken from significand_bitstream[plane]
 *  from the top plane downwards until the accumulated value becomes nonzero
 *  or plane 0 is reached, then the sign bit is read from that same stream,
 *  then the lower planes are filled from insignificand_bitstream[plane].
 *
 *  End-of-stream is not checked here; the ENTROPY_CODER_EOS early-outs
 *  that used to follow each significand read are disabled.
 *
 *  @return the decoded coefficient
 */
static inline
TYPE decode_coeff (ENTROPY_CODER significand_bitstream [],
                   ENTROPY_CODER insignificand_bitstream [])
{
#if !defined(GRAY_CODES)
   static TYPE mask [2] = { 0, ~0 };
#endif
   TYPE significance = 0;
   int sign;
   int plane = TYPE_BITS - 1;

   /* leading bits, down to and including the first 1 bit */
   for (;;) {
      significance |= INPUT_BIT(&significand_bitstream[plane]) << plane;

      if (significance || plane <= 0)
         break;

      plane--;
   }

   sign = INPUT_BIT(&significand_bitstream[plane]);

   /* refinement bits of all lower planes */
   for (plane--; plane >= 0; plane--)
      significance |= INPUT_BIT(&insignificand_bitstream[plane]) << plane;

#if defined(GRAY_CODES)
   /* put the sign into the top bit before undoing the Gray mapping */
   significance |= sign << (8*sizeof(TYPE)-1);
   return gray_to_binary(significance);
#else
   return (significance ^ mask[sign]);
#endif
}


/**
 *  Try to jump over a run of coefficients whose significand streams all
 *  currently deliver 0 symbols.
 *
 *  The run is the shortest run length reported by any bitplane, capped at
 *  limit. Plane 0 counts at half its run length because it also carries
 *  the sign bits. Runs of 2 or fewer are not skipped.
 *
 *  buf and i_stream are accepted but not used.
 *
 *  @return number of coefficients skipped, 0 if nothing was skipped
 */
static inline
uint32_t skip_0coeffs (Wavelet3DBuf* buf,
                       ENTROPY_CODER s_stream [],
                       ENTROPY_CODER i_stream [],
                       uint32_t limit)
{
   uint32_t skip = limit;
   int plane;

   for (plane = 0; plane < TYPE_BITS; plane++) {
      uint32_t runlength;

      /* a pending non-zero symbol on any plane rules out skipping */
      if (ENTROPY_CODER_SYMBOL(&s_stream[plane]) != 0)
         return 0;

      runlength = ENTROPY_CODER_RUNLENGTH(&s_stream[plane]);

      if (plane == 0)
         runlength /= 2;        /* sign bits are in this bitplane ... */

      if (runlength < skip)
         skip = runlength;

      if (skip <= 2)
         return 0;
   }

   /* plane 0 holds two bits per skipped coefficient: significance + sign */
   ENTROPY_CODER_SKIP(&s_stream[0], 2*skip);

   for (plane = 1; plane < TYPE_BITS; plane++)
      ENTROPY_CODER_SKIP(&s_stream[plane], skip);

   return skip;
}



#if 1
/**
 *  Encode a w x h x f box of coefficients.
 *
 *  The box starts at buf->offset[level][quadrant] inside buf->data and is
 *  walked frame by frame, row by row, column by column. Rows are
 *  buf->width coefficients apart, frames buf->width * buf->height.
 */
static inline
void encode_quadrant (const Wavelet3DBuf* buf,
                      int level, int quadrant, uint32_t w, uint32_t h, uint32_t f,
                      ENTROPY_CODER significand_bitstream [],
                      ENTROPY_CODER insignificand_bitstream [])
{
   uint32_t col, row, frame;

   for (frame = 0; frame < f; frame++) {
      for (row = 0; row < h; row++) {
         /* index of the first coefficient of this row */
         unsigned int row_start = buf->offset [level] [quadrant]
                                    + frame * buf->width * buf->height
                                    + row * buf->width;

         for (col = 0; col < w; col++)
            encode_coeff (significand_bitstream, insignificand_bitstream,
                          buf->data [row_start + col]);
      }
   }
}


/**
 *  Encode all coefficients of buf, one scale level after the other.
 *
 *  buf->data[0] is written first. For every level the up to seven boxes
 *  that lie between the extents of this level (w, h, f) and of the next
 *  one are encoded in the fixed order 1..7. w1, h1 and f1 are the growth
 *  in each dimension; a box is only emitted when every dimension it grows
 *  in is non-empty.
 */
static
void encode_coefficients (const Wavelet3DBuf* buf,
                          ENTROPY_CODER s_stream [],
                          ENTROPY_CODER i_stream [])
{
   int level;

   encode_coeff (s_stream, i_stream, buf->data[0]);

   for (level=0; level<buf->scales-1; level++) {
      uint32_t w, h, f, w1, h1, f1;

      w = buf->w [level];
      h = buf->h [level];
      f = buf->f [level];
      w1 = buf->w [level+1] - w;
      h1 = buf->h [level+1] - h;
      f1 = buf->f [level+1] - f;

      if (w1 > 0) encode_quadrant(buf,level,1,w1,h,f,s_stream,i_stream);
      if (h1 > 0) encode_quadrant(buf,level,2,w,h1,f,s_stream,i_stream);
      if (f1 > 0) encode_quadrant(buf,level,3,w,h,f1,s_stream,i_stream);
      if (w1 > 0 && h1 > 0) encode_quadrant(buf,level,4,w1,h1,f,s_stream,i_stream);
      if (w1 > 0 && f1 > 0) encode_quadrant(buf,level,5,w1,h,f1,s_stream,i_stream);
      if (h1 > 0 && f1 > 0) encode_quadrant(buf,level,6,w,h1,f1,s_stream,i_stream);
      /* box 7 grows in all three dimensions, so all three must be tested
         (the guard used to test f1 twice and w1 not at all) */
      if (w1 > 0 && h1 > 0 && f1 > 0)
         encode_quadrant (buf,level,7,w1,h1,f1,s_stream,i_stream);
   }
}


/**
 *  Decode a w x h x f box of coefficients into buf->data.
 *
 *  The box starts at buf->offset[level][quadrant] and is walked in the
 *  same frame / row / column order encode_quadrant() uses. After every
 *  decoded coefficient skip_0coeffs() may jump over a run of zero
 *  coefficients; the run is limited to what is left of this box, and the
 *  skipped positions are not written.
 *
 *  An empty box (any extent 0) decodes nothing, matching encode_quadrant(),
 *  which emits nothing in that case.
 */
static inline
void decode_quadrant (Wavelet3DBuf* buf,
                      int level, int quadrant, uint32_t w, uint32_t h, uint32_t f,
                      ENTROPY_CODER s_stream [],
                      ENTROPY_CODER i_stream [])
{
   uint32_t x, y, z;

   /* the do/while loops below always run once and the limit expression
      would wrap around, so reject empty boxes up front */
   if (w == 0 || h == 0 || f == 0)
      return;

   z = 0;
   do {
      y = 0;
      do {
         x = 0;
         do {
            uint32_t skip;
            uint32_t index = buf->offset [level] [quadrant]
                               + z * buf->width * buf->height
                               + y * buf->width + x;

            buf->data [index] = decode_coeff (s_stream, i_stream);

            /* limit == number of coefficients left in this box after (x,y,z) */
            skip = skip_0coeffs (buf, s_stream, i_stream,
                                 (w-x-1)+(h-y-1)*w+(f-z-1)*w*h);
            if (skip > 0) {
               /* advance to the last skipped position, carrying into y and z */
               x += skip;
               while (x >= w) {
                  y++; x -= w;
                  while (y >= h) {
                     z++; y -= h;
                     if (z >= f)
                        return;
                  }
               }
            }
            x++;
         } while (x < w);
         y++;
      } while (y < h);
      z++;
   } while (z < f);
}


/**
 *  Decode all coefficients of buf, one scale level after the other.
 *
 *  Reads buf->data[0] first, then for every level the up to seven boxes
 *  between the extents of this level (w, h, f) and of the next one, in
 *  the fixed order 1..7. The guards must select exactly the boxes
 *  encode_coefficients() wrote, otherwise the streams run out of sync.
 */
static
void decode_coefficients (Wavelet3DBuf* buf,
                          ENTROPY_CODER s_stream [],
                          ENTROPY_CODER i_stream [])
{
   int level;

   buf->data[0] = decode_coeff (s_stream, i_stream);

   for (level=0; level<buf->scales-1; level++) {
      uint32_t w, h, f, w1, h1, f1;

      w = buf->w [level];
      h = buf->h [level];
      f = buf->f [level];
      w1 = buf->w [level+1] - w;
      h1 = buf->h [level+1] - h;
      f1 = buf->f [level+1] - f;

      if (w1 > 0) decode_quadrant(buf,level,1,w1,h,f,s_stream,i_stream);
      if (h1 > 0) decode_quadrant(buf,level,2,w,h1,f,s_stream,i_stream);
      if (f1 > 0) decode_quadrant(buf,level,3,w,h,f1,s_stream,i_stream);
      if (w1 > 0 && h1 > 0) decode_quadrant(buf,level,4,w1,h1,f,s_stream,i_stream);
      if (w1 > 0 && f1 > 0) decode_quadrant(buf,level,5,w1,h,f1,s_stream,i_stream);
      if (h1 > 0 && f1 > 0) decode_quadrant(buf,level,6,w,h1,f1,s_stream,i_stream);
      /* box 7 grows in all three dimensions. With w1 == 0 the encoder
         writes nothing for it, so decoding it anyway would consume bits
         that belong to later coefficients (the guard used to test f1
         twice and w1 not at all). */
      if (w1 > 0 && h1 > 0 && f1 > 0)
         decode_quadrant (buf,level,7,w1,h1,f1,s_stream,i_stream);
   }
}
#else

/**
 *  Flat variant: encode every coefficient of buf->data in storage order,
 *  buf->width * buf->height * buf->frames values in total.
 */
static
void encode_coefficients (const Wavelet3DBuf* buf,
                          ENTROPY_CODER s_stream [],
                          ENTROPY_CODER i_stream [])
{
   uint32_t pos = 0;

   while (pos < buf->width*buf->height*buf->frames) {
      encode_coeff(s_stream, i_stream, buf->data[pos]);
      pos++;
   }
}




/**
 *  Flat variant: decode buf->width * buf->height * buf->frames coefficients
 *  into buf->data in storage order.
 *
 *  After each decoded value skip_0coeffs() may jump over a run of zero
 *  coefficients; skipped positions are not written.
 */
static
void decode_coefficients (Wavelet3DBuf* buf,
                          ENTROPY_CODER s_stream [],
                          ENTROPY_CODER i_stream [])
{
   uint32_t i;

   for (i=0; i<buf->width*buf->height*buf->frames; i++) {
      uint32_t skip;

      buf->data[i] = decode_coeff(s_stream, i_stream);

      /* Only the coefficients after position i can still be skipped, hence
         the -1: allowing one more would eat bits past the last coefficient.
         i is below the total here, so the subtraction cannot wrap. */
      skip = skip_0coeffs (buf, s_stream, i_stream,
                           buf->width*buf->height*buf->frames - i - 1);
      i += skip;
   }
}
#endif



/**
 *  Flush all encoders and decide how many bytes of each one are kept.
 *
 *  The two binary limit tables themselves cost
 *  2 * TYPE_BITS * sizeof(uint32_t) bytes, which are taken off limit first.
 *  Of the rest 7/8 is the significand budget, the remainder the
 *  insignificand budget. Going from the top bitplane down, each plane may
 *  use up to half (rounded up) of what is left of its budget, plane 0 all
 *  of it; a plane never gets more than its encoder actually produced.
 *
 *  @param significand_limittab    out: bytes kept per significand stream
 *  @param insignificand_limittab  out: bytes kept per insignificand stream
 *  @param limit  total byte budget, must exceed the size of the two tables
 *  @return total number of bytes used, limit tables included
 */
static
uint32_t setup_limittabs (ENTROPY_CODER significand_bitstream [],
                          ENTROPY_CODER insignificand_bitstream [],
                          uint32_t significand_limittab [],
                          uint32_t insignificand_limittab [],
                          uint32_t limit)
{
   uint32_t significand_limit;
   uint32_t insignificand_limit;
   uint32_t byte_count;
   int i;

   assert (limit > 2 * TYPE_BITS * sizeof(uint32_t)); /* limit too small */

   printf ("%s: limit == %u\n", __FUNCTION__, limit);
   byte_count = 2 * TYPE_BITS * sizeof(uint32_t); /* 2 binary coded limittabs */
   limit -= byte_count;
   printf ("%s: rem. limit == %u\n", __FUNCTION__, limit);

   significand_limit = limit * 7 / 8;
   insignificand_limit = limit - significand_limit;

   printf ("%s: limit == %u\n", __FUNCTION__, limit);
   printf ("significand limit == %u\n", significand_limit);
   printf ("insignificand limit == %u\n", insignificand_limit);

   i = TYPE_BITS;
   while (i-- > 0) {
      uint32_t s_bytes, i_bytes;
      uint32_t s_quota, i_quota;

      /* plane 0 may take whatever is left, the others at most half of it */
      s_quota = (i > 0) ? (significand_limit + 1) / 2 : significand_limit;
      i_quota = (i > 0) ? (insignificand_limit + 1) / 2 : insignificand_limit;

      s_bytes = ENTROPY_ENCODER_FLUSH(&significand_bitstream[i]);
      i_bytes = ENTROPY_ENCODER_FLUSH(&insignificand_bitstream[i]);

      /* never reserve more than the encoder produced */
      if (s_bytes < s_quota)
         s_quota = s_bytes;

      if (i_bytes < i_quota)
         i_quota = i_bytes;

      significand_limittab[i] = s_quota;
      insignificand_limittab[i] = i_quota;

      byte_count += s_quota + i_quota;

      printf ("insignificand_limittab[%i]  == %u / %u\n", 
              i, insignificand_limittab[i], i_bytes);
      printf ("  significand_limittab[%i]  == %u / %u\n",
              i, significand_limittab[i], s_bytes);

      significand_limit -= s_quota;
      insignificand_limit -= i_quota;
   }

   printf ("byte_count == %u\n", byte_count);

   return byte_count;
}


/**
 *  Store both limit tables at the start of the bitstream: TYPE_BITS
 *  significand entries followed by TYPE_BITS insignificand entries, each
 *  as a raw uint32_t in host byte order (uncompressed for now, should be
 *  easy to compress ...).
 *
 *  The entries are copied with memcpy() because bitstream is a plain byte
 *  pointer and need not be aligned for uint32_t stores.
 *
 *  @return pointer to the first byte behind the tables
 */
static
uint8_t* write_limittabs (uint8_t *bitstream,
                          uint32_t significand_limittab [],
                          uint32_t insignificand_limittab [])
{
   const size_t tab_bytes = TYPE_BITS * sizeof(uint32_t);

   memcpy (bitstream, significand_limittab, tab_bytes);
   bitstream += tab_bytes;

   memcpy (bitstream, insignificand_limittab, tab_bytes);
   bitstream += tab_bytes;

   return bitstream;
}


/**
 *  Load both limit tables from the start of the bitstream: TYPE_BITS
 *  significand entries followed by TYPE_BITS insignificand entries, each
 *  a raw uint32_t in host byte order.
 *
 *  The entries are copied with memcpy() because bitstream is a plain byte
 *  pointer and need not be aligned for uint32_t loads.
 *
 *  @return pointer to the first byte behind the tables
 */
static
uint8_t* read_limittabs (uint8_t *bitstream,
                         uint32_t significand_limittab [],
                         uint32_t insignificand_limittab [])
{
   int i;

   for (i=0; i<TYPE_BITS; i++) {
      memcpy (&significand_limittab[i], bitstream, sizeof(uint32_t));
      printf ("significand_limittab[%i]  == %u\n", i, significand_limittab[i]);
      bitstream += sizeof(uint32_t);
   }

   for (i=0; i<TYPE_BITS; i++) {
      memcpy (&insignificand_limittab[i], bitstream, sizeof(uint32_t));
      printf ("insignificand_limittab[%i]  == %u\n", i, insignificand_limittab[i]);
      bitstream += sizeof(uint32_t);
   }

   return bitstream;
}


/**
 *  Concatenate the entropy coder outputs into one bitstream.
 *
 *  All significand streams come first, top bitplane first, followed by all
 *  insignificand streams in the same order. Of each stream only the number
 *  of bytes listed in its limit table entry is copied.
 */
static
void merge_bitstreams (uint8_t *bitstream,
                       ENTROPY_CODER significand_bitstream [],
                       ENTROPY_CODER insignificand_bitstream [],
                       uint32_t significand_limittab [],
                       uint32_t insignificand_limittab [])
{
   uint8_t *dst = bitstream;
   int plane;

   plane = TYPE_BITS;
   while (plane-- > 0) {
      const uint32_t len = significand_limittab[plane];

      memcpy (dst, ENTROPY_CODER_BITSTREAM(&significand_bitstream[plane]), len);
      dst += len;
   }

   plane = TYPE_BITS;
   while (plane-- > 0) {
      const uint32_t len = insignificand_limittab[plane];

      memcpy (dst, ENTROPY_CODER_BITSTREAM(&insignificand_bitstream[plane]), len);
      dst += len;
   }
}


/**
 *  Attach one decoder to each section of a merged bitstream.
 *
 *  Sections are laid out as merge_bitstreams() writes them: significand
 *  streams from the top bitplane down, then the insignificand streams in
 *  the same order. The section lengths are taken from the limit tables.
 */
static
void split_bitstreams (uint8_t *bitstream,
                       ENTROPY_CODER significand_bitstream [],
                       ENTROPY_CODER insignificand_bitstream [],
                       uint32_t significand_limittab [],
                       uint32_t insignificand_limittab [])
{
   uint8_t *section = bitstream;
   uint32_t section_len;
   int plane;

   plane = TYPE_BITS;
   while (plane-- > 0) {
      section_len = significand_limittab[plane];
      ENTROPY_DECODER_INIT(&significand_bitstream[plane], section, section_len);
      section += section_len;
   }

   plane = TYPE_BITS;
   while (plane-- > 0) {
      section_len = insignificand_limittab[plane];
      ENTROPY_DECODER_INIT(&insignificand_bitstream[plane], section, section_len);
      section += section_len;
   }
}


/**
 *  Encode the coefficients of buf into bitstream.
 *
 *  One significand and one insignificand entropy coder is run per
 *  bitplane. Their outputs are cut down to fit into limit bytes, then
 *  written behind the two limit tables that record the size of each part.
 *
 *  @param bitstream  destination buffer
 *  @param limit      byte budget, passed to every encoder and to
 *                    setup_limittabs()
 *  @return number of bytes reported by setup_limittabs()
 */
int wavelet_3d_buf_encode_coeff (const Wavelet3DBuf* buf,
                                 uint8_t *bitstream,
                                 uint32_t limit)
{
   ENTROPY_CODER significand_bitstream [TYPE_BITS];
   ENTROPY_CODER insignificand_bitstream [TYPE_BITS];
   uint32_t significand_limittab [TYPE_BITS];
   uint32_t insignificand_limittab [TYPE_BITS];
   uint32_t total_bytes;
   int plane;

   for (plane = 0; plane < TYPE_BITS; ++plane) {
      ENTROPY_ENCODER_INIT(&significand_bitstream[plane], limit);
   }

   for (plane = 0; plane < TYPE_BITS; ++plane) {
      ENTROPY_ENCODER_INIT(&insignificand_bitstream[plane], limit);
   }

   encode_coefficients (buf, significand_bitstream, insignificand_bitstream);

   /* flushes the encoders and fixes the size of every sub-stream */
   total_bytes = setup_limittabs (significand_bitstream,
                                  insignificand_bitstream,
                                  significand_limittab,
                                  insignificand_limittab,
                                  limit);

   /* header first, payload right behind it */
   bitstream = write_limittabs (bitstream,
                                significand_limittab, insignificand_limittab);

   merge_bitstreams (bitstream, significand_bitstream, insignificand_bitstream,
                     significand_limittab, insignificand_limittab);

   for (plane = 0; plane < TYPE_BITS; ++plane) {
      ENTROPY_ENCODER_DONE(&significand_bitstream[plane]);
      ENTROPY_ENCODER_DONE(&insignificand_bitstream[plane]);
   }

   return total_bytes;
}


/**
 *  Decode a bitstream written by wavelet_3d_buf_encode_coeff() into buf.
 *
 *  buf->data is cleared first, so coefficients that are skipped during
 *  decoding read as 0. The limit tables at the start of the bitstream
 *  give the size of each per-bitplane sub-stream.
 *
 *  NOTE(review): byte_count is not referenced in this function, and the
 *  sizes read from the limit tables are not checked against it. Confirm
 *  whether callers guarantee a well-formed bitstream.
 */
void wavelet_3d_buf_decode_coeff (Wavelet3DBuf* buf,
                                  uint8_t *bitstream,
                                  uint32_t byte_count)
{
   ENTROPY_CODER significand_bitstream [TYPE_BITS];
   ENTROPY_CODER insignificand_bitstream [TYPE_BITS];
   uint32_t significand_limittab [TYPE_BITS];
   uint32_t insignificand_limittab [TYPE_BITS];
   int plane;

   memset (buf->data, 0,
           buf->width * buf->height * buf->frames * sizeof(TYPE));

   /* header first, payload right behind it */
   bitstream = read_limittabs (bitstream,
                               significand_limittab, insignificand_limittab);

   split_bitstreams (bitstream, significand_bitstream, insignificand_bitstream,
                     significand_limittab, insignificand_limittab);

   decode_coefficients (buf, significand_bitstream, insignificand_bitstream);

   for (plane = 0; plane < TYPE_BITS; ++plane) {
      ENTROPY_DECODER_DONE(&significand_bitstream[plane]);
      ENTROPY_DECODER_DONE(&insignificand_bitstream[plane]);
   }
}