
#include "jjassert.h"

#include "ntt.h"

//
// unoptimized radix 2 learner's versions
//

void 
ntt_dif2l(mod *f, ulong ldn, int is)
//
// Radix-2 decimation in frequency (DIF) number theoretic transform,
// done in place on f[0], ..., f[n-1] where n = 2^ldn.
//
//   f    data array, overwritten with the transform
//   ldn  base-2 logarithm of the transform length
//   is   sign of the transform: is>0 uses root(n), else invroot(n)
//
// The butterfly passes leave the result in bit-reversed order,
// so the data is revbin_permuted before exit.
// No scaling of the result is done here.
//
{
    // All powers of two are formed by shifting a ulong:
    // the literal 1 is an int, so (1<<k) would overflow for k>=31
    // even where ulong is wider than int.
    const ulong one = 1;

    const ulong n = one << ldn;
    const mod rn = (is>0 ? root(n) : invroot(n) );

    // Sub-transform lengths m = n, n/2, ..., 2
    for (ulong ldm=ldn; ldm>=1; --ldm)
    {
        const ulong m = one << ldm;   // m=2^ldm
        const ulong mh = m >> 1;      // mh=m/2

        // dw = rn^(n/m) "=" exp(+-2*pi*i/m), the step between twiddles;
        // w runs through dw^j "=" exp(+-2*pi*i*j/m) for j=0..mh-1
        const ulong x = one << (ldn-ldm);
        mod dw = pow(rn,x);
        mod w=(mod::one);

        for (ulong j=0; j<mh; ++j)
        {
            // same twiddle w for position j in every length-m block
            for (ulong r=0; r<n; r+=m)
            {
                // u=f[t1]
                // v=f[t2]
                // f[t1]= u+v
                // f[t2]= (u-v)*w

                const ulong t1 = r+j;      // upper butterfly index
                const ulong t2 = t1+mh;    // lower butterfly index

                mod v = f[t2];
                mod u = f[t1];

                f[t1] += v;
                f[t2] = (u-v)*w;
            }

            w *= dw;
        }
    }

    revbin_permute(f,n);            // rearrange function values
}
// =============== end =================


void 
ntt_dit2l(mod *f, int ldn, int is)
//
// Radix-2 decimation in time (DIT) number theoretic transform,
// done in place on f[0], ..., f[n-1] where n = 2^ldn.
//
//   f    data array, overwritten with the transform
//   ldn  base-2 logarithm of the transform length (must not be negative)
//   is   sign of the transform: is>0 uses root(n), else invroot(n)
//
// The data is revbin_permuted at entry, before the butterfly passes.
// No scaling of the result is done here.
//
{
    // All powers of two are formed by shifting a ulong:
    // the literal 1 is an int, so (1<<k) would overflow for k>=31
    // even where ulong is wider than int.
    const ulong one = 1;
    const ulong ld = (ulong)ldn;    // convert once, avoids mixed int/ulong arithmetic

    const ulong n = one << ld;
    revbin_permute(f,n); 

    const mod rn = (is>0?root(n):invroot(n));

    // Sub-transform lengths m = 2, 4, ..., n
    for (ulong ldm=1; ldm<=ld; ++ldm)
    {
        const ulong m = one << ldm;   // m=2^ldm
        const ulong mh = m >> 1;      // mh=m/2

        // dw = rn^(n/m) "=" exp(+-2*pi*i/m), the step between twiddles;
        // w runs through dw^j "=" exp(+-2*pi*i*j/m) for j=0..mh-1
        const ulong x = one << (ld-ldm);
        mod dw = pow(rn,x);
        mod w=(mod::one);

        for (ulong j=0; j<mh; ++j)
        {
            // same twiddle w for position j in every length-m block
            for (ulong r=0; r<n; r+=m)
            {
                // u=f[t1]
                // v=f[t2]*w
                // f[t1]= u+v
                // f[t2]= u-v

                const ulong t1 = r+j;      // upper butterfly index
                const ulong t2 = t1+mh;    // lower butterfly index

                mod v = f[t2]*w;
                mod u = f[t1];

                f[t1] = u+v;
                f[t2] = u-v;
            }

            w *= dw;
        }
    }
} 
// =============== end =================

