
#include <math.h>
#include <iostream.h>

#include "fxt.h"
#include "fxtaux.h"
#include "fxtdefs.h"  // SUMDIFF, CSQR, CMULT
#include "auxgen.h"
#include "permute.h"

// Operation tallies for the emitted code: A(c) adds c to the addition count
// `act`, M(c) adds c to the multiplication count `mct`.  Both counters are
// declared outside this file (presumably in auxgen.h -- TODO confirm).
// Argument and expansion are parenthesized so the macros are safe inside
// larger expressions.
#define  A(c)  (act += (c))
#define  M(c)  (mct += (c))

void
prfht(double *f, ulong ldn, int zp=0)
// Print (to cout) straight-line code for an in-place length-2**ldn FHT.
//
// f   : base pointer; array indices in the output come from idxf() applied
//       to pointers derived from f (main() passes f == 0).
//       NOTE(review): forming f + n from a null pointer is formally undefined
//       behavior; integer offsets would avoid it but need a matching idxf().
// ldn : log2 of the transform length n.
// zp  : if nonzero, the ldn==1 case and the 4-point initial pass (ldn even)
//       emit shorter code that never reads the odd element of each pair/block
//       (presumably for zero-padded input -- TODO confirm).  The 8-point
//       initial pass (ldn odd) ignores zp.
//
// The emitted code uses SUMDIFF2, SUMDIFF4, CMULT6, SQRT2, SQRT1_2 and the
// type name `Type`, which must be available where the output is compiled.
// Every emitted addition/multiplication is tallied via A() / M().
{
  if (ldn<=1)
  {
    if (ldn==1)
    {
      if ( zp )
      {
        cout<<"f[1]=f[0];" << endl;  // plain copy: no additions to count
      }
      else
      {
        cout<<"SUMDIFF2(f[0], f[1]);"<<endl;  A(2);
      }
    }
    return;
  }

  ulong n = (1UL<<ldn);  // 1UL: a plain int shift overflows for large ldn

  cout<<endl;
  const double *fn = f + n;
  double *fi, *gi;

  // ldk: log2 of the block length already transformed.
  ulong ldk = ldn&1;

  cout<<"{ // start initial loop"<<endl;
  if (ldk==0)  /* ldn is multiple of 2  => n is a power of 4 */
  {
    // One 4-point butterfly per block of 4 consecutive elements.
    for (fi=f; fi<fn; fi+=4)
    {
      cout<<"{ // fi = "<< (int)(fi-f) <<endl;
      if ( zp )
      {
        cout<<"SUMDIFF2("<<idxf(fi,0)<<", "<<idxf(fi,2)<<");" <<endl; A(2);
        cout<<idxf(fi,3)<<" = "<<idxf(fi,2)<<";" <<endl;
        cout<<idxf(fi,1)<<" = "<<idxf(fi,0)<<";" <<endl;
      }
      else
      {
        cout<<"Type f0, f1, f2, f3;"<<endl;

        cout<<"SUMDIFF4("<<idxf(fi,0)<<", "<<idxf(fi,1)<<", f0, f1);" <<endl; A(2);
        cout<<"SUMDIFF4("<<idxf(fi,2)<<", "<<idxf(fi,3)<<", f2, f3);" <<endl; A(2);

        cout<<"SUMDIFF4(f0, f2, "<<idxf(fi,0)<<", "<<idxf(fi,2)<<");" <<endl; A(2);
        cout<<"SUMDIFF4(f1, f3, "<<idxf(fi,1)<<", "<<idxf(fi,3)<<");" <<endl; A(2);
      }
      cout << "}" << endl;
    }
    ldk = 2;  // blocks of length 4 are now transformed
  }
  else      /* ldk==1,  ldn is no multiple of 2  => n is no power of 4 */
  {
    // One 8-point transform per block of 8; gi = fi+1 addresses the odd elements.
    for (fi=f,gi=fi+1; fi<fn; fi+=8,gi+=8)
    {
      cout<<"{ // fi = "<< (int)(fi-f) << "  gi = "<< (int)(gi-f) << endl;

      cout<<"Type g0, f0, f1, g1;"<<endl;

      cout<<"SUMDIFF4("<<idxf(fi,0)<<", "<<idxf(gi,0)<<", f0, g0);" <<endl; A(2);
      cout<<"SUMDIFF4("<<idxf(fi,2)<<", "<<idxf(gi,2)<<", f1, g1);" <<endl; A(2);
      cout<<"SUMDIFF2(f0, f1);" <<endl; A(2);
      cout<<"SUMDIFF2(g0, g1);" <<endl; A(2);

      cout<<"Type s1, c1, s2, c2;"<<endl;
      cout<<"SUMDIFF4("<<idxf(fi,4)<<", "<<idxf(gi,4)<<", s1, c1);" <<endl; A(2);
      cout<<"SUMDIFF4("<<idxf(fi,6)<<", "<<idxf(gi,6)<<", s2, c2);" <<endl; A(2);

      cout<<"SUMDIFF2(s1, s2);" <<endl; A(2);

      cout<<"SUMDIFF4(f0, s1, "<<idxf(fi,0)<<", "<<idxf(fi,4)<<");" <<endl; A(2);
      cout<<"SUMDIFF4(f1, s2, "<<idxf(fi,2)<<", "<<idxf(fi,6)<<");" <<endl; A(2);

      cout<<"c1 *= SQRT2;"<<endl; M(1);
      cout<<"c2 *= SQRT2;"<<endl; M(1);

      cout<<"SUMDIFF4(g0, c1, "<<idxf(gi,0)<<", "<<idxf(gi,4)<<");" <<endl; A(2);
      cout<<"SUMDIFF4(g1, c2, "<<idxf(gi,2)<<", "<<idxf(gi,6)<<");" <<endl; A(2);
      cout << "}" << endl;
    }
    ldk = 3;  // blocks of length 8 are now transformed
  }
  cout<<"} // end initial loop"<<endl;

  if ( ldk==ldn )  return;
  cout<<endl;

  // Each remaining pass combines four length-k blocks into one length-4k block.
  for (  ; ldk<ldn;  ldk+=2)
  {
    ulong k   = 1UL << ldk;
    ulong kh  = k >> 1;
    ulong k2  = k << 1;
    ulong k3  = k2 + k;
    ulong k4  = k2 << 1;

    // Offsets 0 (fi: additions only) and kh (gi: scaled by SQRT2) of each group.
    cout<<"{ // -------- ldk=" << ldk << "  k4="<<k4<<endl;
    cout<<"Type f0, f1, f2, f3;"<<endl;
    for (double *fi=f, *gi=fi+kh;  fi<fn;  fi+=k4, gi+=k4)
    {
      cout<<"// do loop: "<<endl;

      cout<<"SUMDIFF4("<<idxf(fi,0)<<", "<<idxf(fi,k)<<", f0, f1);" <<endl; A(2);
      cout<<"SUMDIFF4("<<idxf(fi,k2)<<", "<<idxf(fi,k3)<<", f2, f3);" <<endl; A(2);

      cout<<"SUMDIFF4(f0, f2, "<<idxf(fi,0)<<", "<<idxf(fi,k2)<<");" <<endl; A(2);
      cout<<"SUMDIFF4(f1, f3, "<<idxf(fi,k)<<", "<<idxf(fi,k3)<<");" <<endl; A(2);

      cout<<"SUMDIFF4("<<idxf(gi,0)<<", "<<idxf(gi,k)<<", f0, f1);" <<endl; A(2);
      cout<<"f3 = SQRT2 * "<<idxf(gi,k3)<<";" <<endl; M(1);
      cout<<"f2 = SQRT2 * "<<idxf(gi,k2)<<";" <<endl; M(1);

      cout<<"SUMDIFF4(f0, f2, "<<idxf(gi,0)<<", "<<idxf(gi,k2)<<");" <<endl; A(2);
      cout<<"SUMDIFF4(f1, f3, "<<idxf(gi,k)<<", "<<idxf(gi,k3)<<");" <<endl; A(2);
    }
    cout<<"}"<<endl;
    cout<<endl;


    // Offsets i = 1..kh-1 pair element i (fi) with element k-i (gi).
    long double tt = M_PI/4/kh;
    cout<<"{ // kh="<<kh<<endl;
    cout<<"Type a, b, g0, f0, f1, g1, f2, g2, f3, g3;"<<endl;

    for (ulong i=1; i<kh; i++)
    {
      cout << "{ // ---- i=" << i << endl;
      // prsincos() (defined elsewhere) prints the c1/s1 values for the output.
      prsincos("c1", "s1", i, kh*4);

      double s2, c2;
      sincos(tt*2*i, &s2, &c2);
      // s2*c2 == sin(2x)/2 equals 0.5 only when c2 == s2 == sqrt(1/2); then the
      // cheaper SUMDIFF4 + SQRT1_2 scaling replaces CMULT6.
      // (Force cs2 = 0 to always emit CMULT6 with printed c2/s2.)
      int cs2 = ( fabs(s2*c2-0.5) < 1e-12 ? 1 : 0 );

      if ( cs2 )   cout << "// c2 = s2 = sqrt(1/2)" << endl;
      else         prsincos("c2", "s2", 2*i, kh*4);


      for (double *fi=f+i, *gi=f+k-i;  fi<fn;  fi+=k4, gi+=k4)
      {
        cout<<"// do loop II: "<<endl;

        if ( cs2 )
        {
          cout<<"SUMDIFF4("<<idxf(fi,k)<<", "<<idxf(gi,k)<<", a, b);"<<endl; A(2);
          cout<<"a *= SQRT1_2;"<<endl;  M(1);
          cout<<"b *= SQRT1_2;"<<endl;  M(1);
        }
        else
        {
          cout<<"CMULT6(s2, c2, "<<idxf(fi,k)<<", "<<idxf(gi,k)<<", b, a);" <<endl; M(4);A(2);
        }

        cout<<"SUMDIFF4("<<idxf(fi,0)<<", a, f0, f1);" <<endl; A(2);
        cout<<"SUMDIFF4("<<idxf(gi,0)<<", b, g0, g1);" <<endl; A(2);

        if ( cs2 )
        {
          cout<<"SUMDIFF4("<<idxf(fi,k3)<<", "<<idxf(gi,k3)<<", a, b);"<<endl; A(2);
          cout<<"a *= SQRT1_2;"<<endl;  M(1);
          cout<<"b *= SQRT1_2;"<<endl;  M(1);
        }
        else
        {
          cout<<"CMULT6(s2, c2, "<<idxf(fi,k3)<<", "<<idxf(gi,k3)<<", b, a);" <<endl; M(4); A(2);
        }

        cout<<"SUMDIFF4("<<idxf(fi,k2)<<", a, f2, f3);" <<endl; A(2);
        cout<<"SUMDIFF4("<<idxf(gi,k2)<<", b, g2, g3);" <<endl; A(2);

        cout<<"CMULT6(s1, c1, f2, g3, b, a);" <<endl; M(4);A(2);
        cout<<"SUMDIFF4(f0, a, "<<idxf(fi,0)<<", "<<idxf(fi,k2)<<");" <<endl; A(2);
        cout<<"SUMDIFF4(g1, b, "<<idxf(gi,k)<<", "<<idxf(gi,k3)<<");" <<endl; A(2);

        cout<<"CMULT6(c1, s1, g2, f3, b, a);" <<endl; M(4);A(2);
        cout<<"SUMDIFF4(g0, a, "<<idxf(gi,0)<<", "<<idxf(gi,k2)<<");" <<endl; A(2);
        cout<<"SUMDIFF4(f1, b, "<<idxf(fi,k)<<", "<<idxf(fi,k3)<<");" <<endl; A(2);
      }
      cout << "}" << endl;
    }
    cout<<"}"<<endl;
    cout<<endl;
  }
}
// ===================== end =====================


#include <stdlib.h>  // atol()
#include <stdio.h>  // sprintf()

int
main(int argc, char **argv)
{
    ulong ldn = 4;
    if ( argc>1 )  ldn = atol(argv[1]);
    ulong n = (1<<ldn);

    int scrt = 0;  // revbin_permute before, DIT (0) or after, DIF (1) main loop
    if ( argc>2 )  scrt = atol(argv[2]);

    int zp = 0;
    if ( argc>3 )  zp = atol(argv[3]);

    const char *fmt =  "template <typename Type>\ninline void\nfht_di%c_core_%lu(Type *f)";
    char fname[150];
    sprintf(fname, fmt, (scrt?'f':'t'), n);
//    cout << fname << endl;
//    prelude(n, "template <typename Type>\ninline void\nfht_(Type *f)", scrt);
    prelude(n, fname, scrt);

    if ( scrt )  revbin_permute(st,n);
    prfht(0, ldn, zp);

    finale(n, scrt);
}
// ===================== end =====================

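// Example usage ('bin' presumably being the compiled generator; arguments: ldn [scrt [zp]]):
// DIT cores for n = 2..64, then DIF cores (scrt=1):
// for ((i=1; i<=6; ++i)); do bin $i; done > shortfhtditcore.h
// for ((i=1; i<=6; ++i)); do bin $i 1; done > shortfhtdifcore.h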
