/*  Plzip - A parallel compressor compatible with lzip
    Copyright (C) 2009 Laszlo Ersek.
    Copyright (C) 2009, 2010 Antonio Diaz Diaz.

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#define _FILE_OFFSET_BITS 64

#include <algorithm>
#include <cerrno>
#include <climits>
#include <csignal>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <queue>
#include <string>
#include <vector>
#include <inttypes.h>
#include <pthread.h>
#include <unistd.h>
#include <lzlib.h>

#include "plzip.h"

#ifndef LLONG_MAX
#define LLONG_MAX  0x7FFFFFFFFFFFFFFFLL
#endif
#ifndef LLONG_MIN
#define LLONG_MIN  (-LLONG_MAX - 1LL)
#endif
#ifndef ULLONG_MAX
#define ULLONG_MAX 0xFFFFFFFFFFFFFFFFULL
#endif


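// Wrappers around the pthread primitives used below. Any failure of these
// primitives is unrecoverable, so the wrappers report the error and call
// fatal() instead of returning an error code to the caller.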
void xinit( pthread_cond_t * cond, pthread_mutex_t * mutex )
  {
  int errcode = pthread_mutex_init( mutex, 0 );
  if( errcode ) { show_error( "pthread_mutex_init", errcode ); fatal(); }

  errcode = pthread_cond_init( cond, 0 );
  if( errcode ) { show_error( "pthread_cond_init", errcode ); fatal(); }
  }


void xdestroy( pthread_cond_t * cond, pthread_mutex_t * mutex )
  {
  int errcode = pthread_cond_destroy( cond );
  if( errcode ) { show_error( "pthread_cond_destroy", errcode ); fatal(); }

  errcode = pthread_mutex_destroy( mutex );
  if( errcode ) { show_error( "pthread_mutex_destroy", errcode ); fatal(); }
  }


void xlock( pthread_mutex_t * mutex )
  {
  int errcode = pthread_mutex_lock( mutex );
  if( errcode ) { show_error( "pthread_mutex_lock", errcode ); fatal(); }
  }


void xunlock( pthread_mutex_t * mutex )
  {
  int errcode = pthread_mutex_unlock( mutex );
  if( errcode ) { show_error( "pthread_mutex_unlock", errcode ); fatal(); }
  }


void xwait( pthread_cond_t * cond, pthread_mutex_t * mutex )
  {
  int errcode = pthread_cond_wait( cond, mutex );
  if( errcode ) { show_error( "pthread_cond_wait", errcode ); fatal(); }
  }


void xsignal( pthread_cond_t * cond )
  {
  int errcode = pthread_cond_signal( cond );
  if( errcode ) { show_error( "pthread_cond_signal", errcode ); fatal(); }
  }


void xbroadcast( pthread_cond_t * cond )
  {
  int errcode = pthread_cond_broadcast( cond );
  if( errcode ) { show_error( "pthread_cond_broadcast", errcode ); fatal(); }
  }


namespace {

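// Running totals for the final statistics. in_size is updated only by the
// splitter thread and out_size only by the muxer, and compress() reads them
// after joining all threads, so no locking is needed.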
long long in_size = 0;
long long out_size = 0;


struct Packet			// data block with a serial number
  {
  unsigned long long id;	// serial number assigned as received
  uint8_t * data;
  int size;			// number of bytes in data (if any)
  };


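// The Packet_courier mediates between the three kinds of threads:
//
//   splitter --> receive_packet --> packet_queue    --> distribute_packet --> workers
//   workers  --> collect_packet --> circular_buffer --> deliver_packet    --> muxer
//
// The slot tally limits the number of packets in circulation, so the
// splitter can't run arbitrarily far ahead of the muxer.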
class Packet_courier			// moves packets around
  {
public:
  unsigned long icheck_counter;
  unsigned long iwait_counter;
  unsigned long ocheck_counter;
  unsigned long owait_counter;
private:
  unsigned long long receive_id;	// id assigned to next packet received
  unsigned long long deliver_id;	// id of next packet to be delivered
  Slot_tally slot_tally;
  std::queue< Packet * > packet_queue;
  std::vector< Packet * > circular_buffer;
  int num_working;			// number of workers still running
  const int num_slots;			// max packets in circulation
  pthread_mutex_t imutex;
  pthread_cond_t iav_or_eof;	// input packet available or splitter done
  pthread_mutex_t omutex;
  pthread_cond_t oav_or_exit;	// output packet available or all workers exited
  bool eof;			// splitter done

public:
  Packet_courier( const int num_workers, const int slots )
    : icheck_counter( 0 ), iwait_counter( 0 ),
      ocheck_counter( 0 ), owait_counter( 0 ),
      receive_id( 0 ), deliver_id( 0 ),
      slot_tally( slots ), circular_buffer( slots, (Packet *) 0 ),
      num_working( num_workers ), num_slots( slots ), eof( false )
    { xinit( &iav_or_eof, &imutex ); xinit( &oav_or_exit, &omutex ); }

  ~Packet_courier()
    { xdestroy( &iav_or_eof, &imutex ); xdestroy( &oav_or_exit, &omutex ); }

  // make a packet with data received from splitter
  void receive_packet( uint8_t * const data, const int size )
    {
    Packet * ipacket = new Packet;
    ipacket->id = receive_id++;
    ipacket->data = data;
    ipacket->size = size;
    slot_tally.get_slot();		// wait for a free slot
    xlock( &imutex );
    packet_queue.push( ipacket );
    xsignal( &iav_or_eof );
    xunlock( &imutex );
    }

  // distribute a packet to a worker
  Packet * distribute_packet()
    {
    Packet * ipacket = 0;
    xlock( &imutex );
    ++icheck_counter;
    while( packet_queue.empty() && !eof )
      {
      ++iwait_counter;
      xwait( &iav_or_eof, &imutex );
      ++icheck_counter;
      }
    if( !packet_queue.empty() )
      {
      ipacket = packet_queue.front();
      packet_queue.pop();
      }
    xunlock( &imutex );
    if( ipacket == 0 )
      {
      // Notify muxer when last worker exits
      xlock( &omutex );
      if( --num_working == 0 )
        xsignal( &oav_or_exit );
      xunlock( &omutex );
      }
    return ipacket;
    }

  // collect a packet from a worker
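  // Workers may finish out of order; the circular buffer, indexed by
  // id modulo num_slots, holds each packet until its turn to be delivered.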
  void collect_packet( Packet * const opacket )
    {
    xlock( &omutex );
    // id collision shouldn't happen
    if( circular_buffer[opacket->id%num_slots] != 0 )
      internal_error( "id collision in collect_packet" );
    // Merge packet into circular buffer
    circular_buffer[opacket->id%num_slots] = opacket;
    if( opacket->id == deliver_id ) xsignal( &oav_or_exit );
    xunlock( &omutex );
    }

  // deliver a packet to muxer
  Packet * deliver_packet()
    {
    xlock( &omutex );
    ++ocheck_counter;
    while( circular_buffer[deliver_id%num_slots] == 0 && num_working > 0 )
      {
      ++owait_counter;
      xwait( &oav_or_exit, &omutex );
      ++ocheck_counter;
      }
    Packet * opacket = circular_buffer[deliver_id%num_slots];
    circular_buffer[deliver_id%num_slots] = 0;
    ++deliver_id;
    xunlock( &omutex );
    if( opacket != 0 )
      slot_tally.leave_slot();		// return a slot to the tally
    return opacket;
    }

  void finish()			// splitter has no more packets to send
    {
    xlock( &imutex );
    eof = true;
    xbroadcast( &iav_or_eof );
    xunlock( &imutex );
    }

  bool finished()		// all packets delivered to muxer
    {
    if( !slot_tally.all_free() || !eof || !packet_queue.empty() ||
        num_working != 0 ) return false;
    for( int i = 0; i < num_slots; ++i )
      if( circular_buffer[i] != 0 ) return false;
    return true;
    }

  const Slot_tally & tally() const { return slot_tally; }
  };


struct Splitter_arg
  {
  Packet_courier * courier;
  const Pretty_print * pp;
  int infd;
  int data_size;
  };


       // split data from input file into chunks and pass them to
       // courier for packaging and distribution to workers.
extern "C" void * csplitter( void * arg )
  {
  const Splitter_arg & tmp = *(Splitter_arg *)arg;
  Packet_courier & courier = *tmp.courier;
  const Pretty_print & pp = *tmp.pp;
  const int infd = tmp.infd;
  const int data_size = tmp.data_size;

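  // The first packet is sent even if it is empty, so that the workers
  // produce at least one (empty) compressed member when the input is empty.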
  for( bool first_post = true; ; first_post = false )
    {
    uint8_t * data = new( std::nothrow ) uint8_t[data_size];
    if( data == 0 ) { pp( "not enough memory" ); fatal(); }
    const int size = readblock( infd, data, data_size );
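    // a short read with errno still 0 just means end of file; readblock is
    // expected to set errno only on a real read error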
    if( size != data_size && errno )
      { pp(); show_error( "read error", errno ); fatal(); }

    if( size > 0 || first_post )	// first packet can be empty
      {
      in_size += size;
      courier.receive_packet( data, size );
      }
    else
      {
      delete[] data;
      courier.finish();			// no more packets to send
      break;
      }
    }
  return 0;
  }


struct Worker_arg
  {
  Packet_courier * courier;
  const Pretty_print * pp;
  int dictionary_size;
  int match_len_limit;
  };


       // get packets from the courier, replace their uncompressed data
       // with the compressed version, and return them to the courier.
extern "C" void * cworker( void * arg )
  {
  const Worker_arg & tmp = *(Worker_arg *)arg;
  Packet_courier & courier = *tmp.courier;
  const Pretty_print & pp = *tmp.pp;
  const int dictionary_size = tmp.dictionary_size;
  const int match_len_limit = tmp.match_len_limit;

  while( true )
    {
    Packet * packet = courier.distribute_packet();
    if( packet == 0 ) break;		// no more packets to process

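    // upper bound for the size of the compressed packet: the lzip header
    // and trailer (plus margin) and the data expanded by at most one eighth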
    const int compr_size = 42 + packet->size + ( ( packet->size + 7 ) / 8 );
    uint8_t * const new_data = new( std::nothrow ) uint8_t[compr_size];
    if( new_data == 0 ) { pp( "not enough memory" ); fatal(); }
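    // shrink the dictionary to the packet size (but not below the library
    // minimum); a dictionary larger than the data would only waste memory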
    const int dict_size = std::max( LZ_min_dictionary_size(),
                                    std::min( dictionary_size, packet->size ) );
    LZ_Encoder * const encoder =
      LZ_compress_open( dict_size, match_len_limit, LLONG_MAX );
    if( !encoder || LZ_compress_errno( encoder ) != LZ_ok )
      {
      if( !encoder || LZ_compress_errno( encoder ) == LZ_mem_error )
        pp( "not enough memory. Try a smaller dictionary size" );
      else
        internal_error( "invalid argument to encoder" );
      fatal();
      }

    int written = 0;
    int new_size = 0;
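    // alternately feed the encoder and drain its output until the whole
    // packet has been compressed and the member is finished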
    while( true )
      {
      if( LZ_compress_write_size( encoder ) > 0 )
        {
        if( written < packet->size )
          {
          const int wr = LZ_compress_write( encoder, packet->data + written,
                                            packet->size - written );
          if( wr < 0 ) internal_error( "library error (LZ_compress_write)" );
          written += wr;
          }
        if( written >= packet->size ) LZ_compress_finish( encoder );
        }
      const int rd = LZ_compress_read( encoder, new_data + new_size,
                                       compr_size - new_size );
      if( rd < 0 )
        {
        pp();
        if( verbosity >= 0 )
          std::fprintf( stderr, "LZ_compress_read error: %s.\n",
                        LZ_strerror( LZ_compress_errno( encoder ) ) );
        fatal();
        }
      new_size += rd;
      if( new_size > compr_size )
        internal_error( "packet size exceeded in worker" );
      if( LZ_compress_finished( encoder ) == 1 ) break;
      }

    if( LZ_compress_close( encoder ) < 0 )
      { pp( "LZ_compress_close failed" ); fatal(); }

    delete[] packet->data;
    packet->data = new_data;
    packet->size = new_size;
    courier.collect_packet( packet );
    }
  return 0;
  }


     // get the processed packets from the courier in the original order
     // and write their contents to the output file.
void muxer( Packet_courier & courier, const Pretty_print & pp, const int outfd )
  {
  while( true )
    {
    Packet * opacket = courier.deliver_packet();
    if( opacket == 0 ) break;	// queue is empty. all workers exited

    out_size += opacket->size;

    if( outfd >= 0 )
      {
      const int wr = writeblock( outfd, opacket->data, opacket->size );
      if( wr != opacket->size )
        { pp(); show_error( "write error", errno ); fatal(); }
      }
    delete[] opacket->data;
    delete opacket;
    }
  }

} // end namespace


    // init the courier, then create the splitter and worker threads and
    // run the muxer in the calling thread.
int compress( const int data_size, const int dictionary_size,
              const int match_len_limit, const int num_workers,
              const int num_slots, const int infd, const int outfd,
              const Pretty_print & pp, const int debug_level )
  {
  in_size = 0;
  out_size = 0;
  Packet_courier courier( num_workers, num_slots );

  Splitter_arg splitter_arg;
  splitter_arg.courier = &courier;
  splitter_arg.pp = &pp;
  splitter_arg.infd = infd;
  splitter_arg.data_size = data_size;

  pthread_t splitter_thread;
  int errcode = pthread_create( &splitter_thread, 0, csplitter, &splitter_arg );
  if( errcode )
    { show_error( "can't create splitter thread", errcode ); fatal(); }

  Worker_arg worker_arg;
  worker_arg.courier = &courier;
  worker_arg.pp = &pp;
  worker_arg.dictionary_size = dictionary_size;
  worker_arg.match_len_limit = match_len_limit;

  pthread_t * worker_threads = new( std::nothrow ) pthread_t[num_workers];
  if( worker_threads == 0 )
    { pp( "not enough memory" ); fatal(); }
  for( int i = 0; i < num_workers; ++i )
    {
    errcode = pthread_create( &worker_threads[i], 0, cworker, &worker_arg );
    if( errcode )
      { show_error( "can't create worker threads", errcode ); fatal(); }
    }

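  // the muxer runs in this thread; deliver_packet() returns 0 (and the
  // muxer returns) only after the last worker has exited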
  muxer( courier, pp, outfd );

  for( int i = num_workers - 1; i >= 0; --i )
    {
    errcode = pthread_join( worker_threads[i], 0 );
    if( errcode )
      { show_error( "can't join worker threads", errcode ); fatal(); }
    }
  delete[] worker_threads; worker_threads = 0;

  errcode = pthread_join( splitter_thread, 0 );
  if( errcode )
    { show_error( "can't join splitter thread", errcode ); fatal(); }

  if( verbosity >= 1 )
    {
    if( in_size <= 0 || out_size <= 0 )
      std::fprintf( stderr, "no data compressed.\n" );
    else
      std::fprintf( stderr, "%6.3f:1, %6.3f bits/byte, "
                            "%5.2f%% saved, %lld in, %lld out.\n",
                    (double)in_size / out_size,
                    ( 8.0 * out_size ) / in_size,
                    100.0 * ( 1.0 - ( (double)out_size / in_size ) ),
                    in_size, out_size );
    }

  if( debug_level & 1 )
    std::fprintf( stderr,
      "splitter tried to send a packet           %8lu times\n"
      "splitter had to wait                      %8lu times\n"
      "any worker tried to consume from splitter %8lu times\n"
      "any worker had to wait                    %8lu times\n"
      "muxer tried to consume from workers       %8lu times\n"
      "muxer had to wait                         %8lu times\n",
      courier.tally().check_counter,
      courier.tally().wait_counter,
      courier.icheck_counter,
      courier.iwait_counter,
      courier.ocheck_counter,
      courier.owait_counter );

  if( !courier.finished() ) internal_error( "courier not finished" );
  return 0;
  }