/*-------------------------------------------------------------------
 * Time domain harmonic scaling by the
 * Pointer Interval Controlled OverLap and Add (PICOLA) method
 *		C version by IKEDA Mikio
 *		original algorithm developed by MORITA Naotaka
 *		for details, see the original paper.
 *-------------------------------------------------------------------
 * Usage
 *  PICOLA <source signal>
 *         <companded (destination) signal>
 *	   <compansion ratio>
 *         <window length>
 *         <pitch minimum>
 *         <pitch maximum>
 * The last three arguments may be omitted. If the source, destination
 * or ratio argument is missing, it is prompted for on stdin.
 *
 * Note
 *   This software handles only signed 16-bit raw audio data.
 *   The procedure is NOT identical to that of
 *   http://keizai.yokkaichi-u.ac.jp/%7Eikeda/research/picola.html
 *-------------------------------------------------------------------
 */

#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Capacity, in samples, of the working input buffer; main() requires
 * length + pitmax to fit in it. */
#define L_BUF 2048

/* Cross-fade one period; the result is stored in the second array. */
void ola(int, short [], short []);
/* NOTE(review): no definition or caller of ola_compress is visible in this file. */
void ola_compress(int, short []);
/* Copy lcp buffered samples to the output and refill the buffer from the input. */
int  pic(short [], int, int, FILE *, FILE *);
/* Period estimate: lag with the smallest sum of absolute differences. */
int  amdfpitch(int, int, int, short []);
/* Period estimate by a covariance score; not called by main(). */
int  covpitch(int, int, int, short []);

/*----*/

/*
 * PICOLA driver.
 *
 * Reads signed 16-bit samples from the source file into is[]
 * (length + pitmax samples) and, per iteration:
 *   1. estimates a period Tp with amdfpitch();
 *   2. compression (rate < 1): cross-fades the first two periods into
 *      one and drops the other, i.e. removes Tp samples;
 *      expansion (rate > 1): writes the first period, then replaces it
 *      in is[] with a cross-fade, i.e. inserts Tp samples;
 *   3. copies lcp samples through pic(). lcp is about Tp * rcomp, and
 *      the fractional part is carried in err so that the mean
 *      output/input length ratio equals rate.
 * When the input runs out, the samples still buffered are written as-is.
 *
 * Exits with EXIT_SUCCESS, or EXIT_FAILURE on bad arguments or I/O errors.
 */
int
main(int argc, char *argv[])
{
  short is[L_BUF];      /* input signal buffer */
  double rate = 0.0;    /* compansion rate
                         * case of less than 1.0 compression,
                         * case of greater than 1.0 expansion
                         */
  double rcomp;         /* internal modified compansion ratio: lcp per Tp */
  double sl;            /* exact (fractional) copy length for this period */
  double err = 0.0;     /* compansion rate error estimate (rounding carry) */
  int pitmin = 32;      /* minimal pitch period (250Hz for 8kHz) */
  int pitmax = 128;     /* maximal pitch period (62.5Hz for 8kHz) */
  int pitch;            /* detected pitch period */
  int length = 256;     /* analysis length; forced to be >= pitmax */
  int nread;            /* number of samples obtained by the latest fread */
  int n_rest;           /* 0 while is[] is full; any other value ends the
                         * main loop, and a positive value is flushed as
                         * the number of samples still held in is[] */
  int lcp;              /* number of copy samples */
  int point;            /* first free slot of is[] after a shift */
  int total;            /* samples held in is[]: length + pitmax */
  int nchars;           /* snprintf result, used to detect truncation */
  int io_failed;        /* nonzero if any stream reported an error */
  const char *prog = (argc > 0 && argv[0] != NULL) ? argv[0] : "picola";
  char  srcfile[256], dstfile[256];
  FILE  *srcfd, *dstfd;

  /*
   *---------------
   * get arguments from command line or stdin
   */
  if (argc >= 2) {
    /* snprintf always NUL-terminates, unlike strncpy */
    nchars = snprintf(srcfile, sizeof srcfile, "%s", argv[1]);
    if (nchars < 0 || (size_t)nchars >= sizeof srcfile) {
      printf("Error from %s: too long source file name !!\n", prog);
      return EXIT_FAILURE;
    }
  } else {
    printf("source signal file    = ");
    fflush(stdout);
    if (scanf("%255s", srcfile) != 1) {
      printf("Error from %s: cannot read source file name !!\n", prog);
      return EXIT_FAILURE;
    }
  }

  if (argc >= 3) {
    nchars = snprintf(dstfile, sizeof dstfile, "%s", argv[2]);
    if (nchars < 0 || (size_t)nchars >= sizeof dstfile) {
      printf("Error from %s: too long companded file name !!\n", prog);
      return EXIT_FAILURE;
    }
  } else {
    printf("companded signal file = ");
    fflush(stdout);
    if (scanf("%255s", dstfile) != 1) {
      printf("Error from %s: cannot read companded file name !!\n", prog);
      return EXIT_FAILURE;
    }
  }

  if (argc >= 4) {
    rate = atof(argv[3]);
  } else {
    printf("compansion rate       = ");
    fflush(stdout);
    if (scanf("%lf", &rate) != 1) {
      rate = 0.0;   /* unreadable input: rejected by the check below */
    }
  }
  printf("%10.4f\n", rate);

  if (argc >= 5) length = atoi(argv[4]);
  if (argc >= 6) pitmin = atoi(argv[5]);
  if (argc >= 7) pitmax = atoi(argv[6]);
  /*
   *-------------- error check and initialize ---------------------
   */
  /* isfinite() also rejects NaN and infinity, which would break lcp */
  if (!isfinite(rate) || rate <= 0.0 || rate == 1.0) {
    printf("illeagal compansion rate !!\n");
    return EXIT_FAILURE;
  }
  if (pitmin < 16) {
    printf("pitch detection minimum threshold modified !!\n");
    pitmin = 16;
  }
  if (pitmax > L_BUF/2) {
    printf("pitch detection maximum threshold modified !!\n");
    pitmax = L_BUF/2;
  }
  if (pitmin > pitmax) {
    /* amdfpitch() would read is[pitmin .. pitmin+length), past the
     * length + pitmax samples actually loaded */
    printf("Error from %s: pitch minimum exceeds pitch maximum !!\n", prog);
    return EXIT_FAILURE;
  }
  if (length < pitmax) {    /* i.e. length + pitmax < pitmax*2, overflow-free */
    length = pitmax;
    printf("frame length have modified !!\n");
  }
  if (length > L_BUF - pitmax) {
    printf("Error from %s: too long frame length !!\n", prog);
    return EXIT_FAILURE;
  }

  /* rate > 1: (Tp + lcp) / lcp == rate;  rate < 1: lcp / (Tp + lcp) == rate */
  rcomp = (rate > 1.0) ? 1.0 / (rate - 1.0) : rate / (1.0 - rate);
  if ((double)pitmax * rcomp >= (double)INT_MAX) {
    /* lcp would not fit in an int */
    printf("Error from %s: compansion rate too close to 1.0 !!\n", prog);
    return EXIT_FAILURE;
  }

  srcfd = fopen(srcfile, "rb");   /* binary mode: raw 16-bit samples */
  if (srcfd == NULL) {
    printf("Error from %s: cannot open %s !!\n", prog, srcfile);
    return EXIT_FAILURE;
  }
  dstfd = fopen(dstfile, "wb");
  if (dstfd == NULL) {
    printf("Error from %s: cannot open %s !!\n", prog, dstfile);
    fclose(srcfd);
    return EXIT_FAILURE;
  }

  /*
   *------------------- body ---------------
   */

  total = length + pitmax;
  nread = (int)fread(is, sizeof(*is), (size_t)total, srcfd);
  if (nread == total) {
    n_rest = 0;
  } else {
    /* input shorter than one buffer: pass what was read through as-is */
    n_rest = (nread > 0) ? nread : -1;
  }

  while (n_rest == 0) {

    /*---- pitch extraction ----*/

    pitch = amdfpitch(pitmin, pitmax, length, is);

    /*---- compensate compansion rate ----*/

    sl = (double)pitch * rcomp;
    lcp = (int)sl;
    err += lcp - sl;            /* carry the dropped fraction forward */
    if (err >= 0.5) {
      --lcp;
      err -= 1.0;
    } else if (err <= -0.5) {
      ++lcp;
      err += 1.0;
    }

    if (rate < 1.0) {
      /* compression: is[pitch..2*pitch) := cross-fade of the first two
       * periods; then drop the FIRST period so the cross-fade leads. */
      ola(pitch, is, is + pitch);
      point = total - pitch;
      memmove(is, is + pitch, (size_t)point * sizeof(*is));
      nread = (int)fread(is + point, sizeof(*is), (size_t)pitch, srcfd);
      if (nread != pitch) {
        n_rest = point + nread;   /* valid samples now held in is[] */
        break;
      }
    } else {
      /* expansion: emit the first period, then replace it in is[] with a
       * cross-fade from the second period into the first */
      fwrite(is, sizeof(*is), (size_t)pitch, dstfd);
      ola(pitch, is + pitch, is);
    }
    n_rest = pic(is, lcp, total, srcfd, dstfd);
  }
  io_failed = ferror(srcfd);
  fclose(srcfd);

  /* flush left data within input buffer */
  if (n_rest > 0) {
    fwrite(is, sizeof(*is), (size_t)n_rest, dstfd);
  }
  if (ferror(dstfd)) {
    io_failed = 1;
  }
  if (fclose(dstfd) != 0) {
    io_failed = 1;
  }
  if (io_failed) {
    printf("Error from %s: I/O error !!\n", prog);
    return EXIT_FAILURE;
  }
  return EXIT_SUCCESS;
}

/*-------------------------------------------------------------------
 * PICOLA overlap-and-add stage
 *
 * Cross-fades two tp-sample segments. For k = 0..tp-1 the weight
 * w = k/tp ramps upward, and is2[k] becomes is1[k]*(1-w) + is2[k]*w,
 * so the result starts like is1 and ends like is2. The result
 * overwrites is2 (converted to short, i.e. truncated toward zero);
 * is1 is only read.
 */

void
ola(int tp, short is1[], short is2[])
{
  const double step = 1.0 / tp;   /* weight increment per sample */
  const short *from = is1;
  short *to = is2;
  int k = 0;

  while (k < tp) {
    double fade_in = step * k;
    to[k] = (short)((double)from[k] * (1.0 - fade_in) + (double)to[k] * fade_in);
    ++k;
  }
}

/*-------------------------------------------------------------------
 * PICOLA pointer interval control stage (common part)
 *
 * Writes the first lcp samples of the stream held in is[0..length)
 * to dstfd and reads lcp new samples from srcfd, so that is[] again
 * holds `length` consecutive input samples.
 *
 * Per cycle, counting the work done by main() around this call:
 *
 *                   expansion  compression
 * Read              lcp        Tp+lcp
 * Write             lcp+Tp     lcp
 *
 * already read      0          Tp
 * already written   Tp         0
 *
 * Returns 0 when is[] was completely refilled. When the input ends,
 * returns the number of valid samples left at the start of is[]
 * (always > 0), or -1 when no valid samples remain. A positive return
 * value can therefore be flushed directly with fwrite(is, ..., n).
 */

int
pic(short is[], int lcp, int length, FILE *srcfd, FILE *dstfd)
{
  int lw;      /* samples still to be written */
  int kept;    /* samples kept in is[] after a partial write */
  int nread;   /* samples obtained by the latest fread */

  if (length <= 0) {
    return -1;   /* no buffer to work with; also avoids an endless loop */
  }
  if (lcp < 0) {
    lcp = 0;     /* nothing to copy; avoids a negative fwrite count */
  }

  /* Whole buffers: write all of is[] and refill it from scratch. */
  for (lw = lcp; lw >= length; lw -= length) {
    fwrite(is, sizeof(*is), (size_t)length, dstfd);
    nread = (int)fread(is, sizeof(*is), (size_t)length, srcfd);
    if (nread != length) {
      /* end of input: only is[0..nread) holds valid samples */
      return (nread > 0) ? nread : -1;
    }
  }

  /* Remainder (0 <= lw < length): write lw samples, slide the rest
   * down to the front, and top up the tail of the buffer. */
  fwrite(is, sizeof(*is), (size_t)lw, dstfd);
  kept = length - lw;
  memmove(is, is + lw, (size_t)kept * sizeof(*is));
  nread = (int)fread(is + kept, sizeof(*is), (size_t)lw, srcfd);
  if (nread != lw) {
    /* end of input: is[0..kept+nread) is valid, and kept >= 1 here */
    return kept + nread;
  }
  return 0;
}

/*-------------------------------------------------------------------
 * periodicity extraction using covariance method 
 * (Slow)
 */


int covpitch(int pitmin, int pitmax, int length, short is[])
{
  int i, j, pitch;
  double covst, covs0t, covmax = 0.0, s;

  pitch = pitmin;
  for (i = pitmin; i <= pitmax; ++i) {
    covst = 0.0;
    covs0t = 0.0;
    for (j = 0; j < length; j++) {
      s = (double)is[i+j];
      covs0t += s * s;
      covst += (double)is[j] * s;
    }
    covst = covst / sqrt(covs0t);
    if (covst >= covmax) {
      covmax = covst;
      pitch = i;
    }
  }
  return(pitch);
}

/*-------------------------------------------------------------------
 * Periodicity extraction using the Average Magnitude Difference
 * Function (AMDF)
 *
 * The returned lag minimizes the AMDF; it is NOT necessarily the
 * pitch period in the precise sense.
 */

/* Unnormalized AMDF at one lag: sum of |is[lag+k] - is[k]|, k < length. */
static int
amdf_lag_sum(int lag, int length, const short is[])
{
  const short *lagged = is + lag;
  int sum = 0;
  int k;

  for (k = 0; k < length; ++k) {
    sum += abs(lagged[k] - is[k]);
  }
  return sum;
}

/*
 * Returns the lag in [pitmin, pitmax] with the smallest AMDF sum; on a
 * tie the smaller lag wins. Returns pitmin when pitmax < pitmin.
 * Reads is[0 .. max(pitmin, pitmax) + length).
 */
int amdfpitch(int pitmin, int pitmax, int length, short is[])
{
  int best_lag = pitmin;
  int best_sum = amdf_lag_sum(pitmin, length, is);
  int lag;

  for (lag = pitmin + 1; lag <= pitmax; ++lag) {
    const int sum = amdf_lag_sum(lag, length, is);
    if (sum < best_sum) {
      best_sum = sum;
      best_lag = lag;
    }
  }
  return best_lag;
}

