
#include <assert.h>

#include "mfft.h"
#include "modaux.h"

// FFT used by every routine below; arguments are (data, ldn, sign).
#define MOD_FFT(a_, ldn_, is_)   mod_fft_dit2((a_), (ldn_), (is_))



void
fft_mod_convolution(mod *f, mod *g, ulong ldn)
//
// _cyclic_ convolution of f[] and g[], both of length n = 2**ldn
// (use zero padded data for usual, i.e. linear, convolution)
//
// result in g[]
// f[] is left in its transformed state (it is NOT transformed back).
// f and g must be distinct arrays.
{
    assert( f!=g );

    const int is = +1;  // sign of the forward transform; -is gives the inverse

    MOD_FFT(f,ldn,is);
    MOD_FFT(g,ldn,is);

    // Shift an ulong, not an int: (1<<ldn) overflows int (UB) for ldn >= 31.
    const ulong n = ((ulong)1<<ldn);
    for(ulong i=0; i<n; ++i)  g[i] *= f[i];  // pointwise product of the transforms

    MOD_FFT(g,ldn,-is);

    // scale by inv(n) (presumably the modular inverse of n) to normalize
    multiply(g,n,inv(n));
}
//============== end FFT_MOD_CONVOLUTION =================


void
fft_mod_convolution(mod *f, mod *g, mod *h, ulong ldn)
//
// _cyclic_ convolution of f[] and g[], all arrays of length n = 2**ldn
// (use zero padded data for usual, i.e. linear, convolution)
//
// result in h[]
// f[] and g[] are left intact: they are transformed, used, and then
// transformed back and rescaled.
// h must not alias f or g; otherwise the result and the restored inputs
// would be corrupted.
{
    assert( f!=g );
    assert( h!=f );
    assert( h!=g );

    const int is = +1;  // sign of the forward transform; -is gives the inverse

    MOD_FFT(f,ldn,is);
    MOD_FFT(g,ldn,is);

    // Shift an ulong, not an int: (1<<ldn) overflows int (UB) for ldn >= 31.
    const ulong n = ((ulong)1<<ldn);
    for(ulong i=0; i<n; ++i)  h[i] = f[i]*g[i];  // pointwise product of the transforms

    MOD_FFT(h,ldn,-is);
    multiply(h,n,inv(n));  // normalize the inverse transform

    // Undo the forward transforms so the caller gets f[] and g[] back.
    MOD_FFT(f,ldn,-is);
    multiply(f,n,inv(n));
    MOD_FFT(g,ldn,-is);
    multiply(g,n,inv(n));
}
//============== end FFT_MOD_CONVOLUTION =================


void
fft_mod_auto_convolution(mod *f, ulong ldn)
//
// _cyclic_ (self-)convolution of f[], of length n = 2**ldn
// (use zero padded data for usual, i.e. linear, convolution)
//
// result in f[] (computed in place)
{
    const int is = +1;  // sign of the forward transform; -is gives the inverse

    MOD_FFT(f,ldn,is);

    // Shift an ulong, not an int: (1<<ldn) overflows int (UB) for ldn >= 31.
    const ulong n = ((ulong)1<<ldn);
    for(ulong i=0; i<n; ++i)  f[i] *= f[i];  // square the transform pointwise

    MOD_FFT(f,ldn,-is);

    // scale by inv(n) (presumably the modular inverse of n) to normalize
    multiply(f,n,inv(n));
}
//============== end FFT_MOD_AUTO_CONVOLUTION =================


void
slow_mod_convolution(mod *f, mod *g, ulong n)
//
// _cyclic_ convolution of f[] and g[] (length n), straight from the
// O(n^2) definition:
//     g[tau] = sum_{k=0}^{n-1} f[(tau-k) mod n] * g[k]
// (use zero padded data for usual, i.e. linear, convolution)
//
// result in g[]; f[] is only read.
// The sums are collected in a temporary buffer, so f may equal g.
{
    mod *acc = new mod[n];

    for (ulong tau = 0; tau < n; ++tau)
    {
        mod sum;
        sum.set_x(0);

        for (ulong k = 0; k < n; ++k)
        {
            // wrap the index into f[] around, which makes the convolution cyclic
            const ulong j = (tau >= k) ? (tau - k) : (n + tau - k);
            mod prod = f[j] * g[k];
            sum += prod;
        }

        acc[tau] = sum;
    }

    for (ulong k = 0; k < n; ++k)  g[k] = acc[k];

    delete [] acc;
}
//============== end SLOW_MOD_CONVOLUTION =================


// the rest is for hfloat:

#include <math.h>  // for floor()

void
double_to_mod_by_force(double *f, ulong n)
//
// just forces doubles into mods
//
{
    assert( sizeof(mod)==sizeof(double) );

    mod *m=(mod *)f;

    for (ulong i=0; i<n; ++i)  m[i].set_x((mod_t)floor(f[i]+0.5));
}
//-------------------------------------------

void
mod_to_double_by_force(mod *m, ulong n)
//
// Converts n mods into doubles in place, reusing the same storage.
// Each slot ends up holding (double)m[i].get_x().
// Requires sizeof(mod)==sizeof(double).
// NOTE(review): this reinterprets mod storage as double (type punning).
{
    assert( sizeof(mod)==sizeof(double) );

    double *dst = (double *)m;
    ulong i = 0;
    while ( i<n )
    {
        dst[i] = (double)m[i].get_x();
        ++i;
    }
}
//-------------------------------------------

void
fft_mod_convolution(double *f, double *g, ulong ldn)
//
// _cyclic_ convolution of f[] and g[] (length n = 2**ldn) using mod arithmetic.
// The doubles are rounded and stored as mods in place, convolved, and
// converted back to doubles.
//
// result in g[]
// f[] is clobbered: it ends up holding its transform converted to double.
{
    // Shift an ulong, not an int: (1<<ldn) overflows int (UB) for ldn >= 31.
    const ulong n = ((ulong)1<<ldn);

    double_to_mod_by_force(f,n);
    double_to_mod_by_force(g,n);

    fft_mod_convolution((mod *)f,(mod *)g,ldn);

    mod_to_double_by_force((mod *)f,n);
    mod_to_double_by_force((mod *)g,n);
}
//-------------------------------------------

void
fft_mod_auto_convolution(double *f, ulong ldn)
//
// _cyclic_ self-convolution of f[] (length n = 2**ldn) using mod arithmetic.
// The doubles are rounded and stored as mods in place, convolved, and
// converted back to doubles.
//
// result in f[]
{
    // Shift an ulong, not an int: (1<<ldn) overflows int (UB) for ldn >= 31.
    const ulong n = ((ulong)1<<ldn);

    double_to_mod_by_force(f,n);

    fft_mod_auto_convolution((mod *)f,ldn);

    mod_to_double_by_force((mod *)f,n);
}
//-------------------------------------------
