
#include "mfft.h"
#include "modm.h"
#include "jjassert.h"

#define MOD_FFT(x,y,z)   mod_fft_dit2(x,y,z) 



void
fft_mod_convolution(mod *f, mod *g, ulong ldn)
//
// _cyclic_ convolution of f[] and g[] (length n=2**ldn each),
// computed with the modular FFT MOD_FFT.
// (use zero padded data for the usual linear convolution)
//
// result in g[]
// f[] is overwritten with its forward transform.
// f and g must be distinct arrays.
{
    jjassert( ldn<=mod::max2pow );  // n must be invertible

    jjassert( f!=g );  // otherwise the same array would be transformed twice

    int is=+1;

    MOD_FFT(f,ldn,is);
    MOD_FFT(g,ldn,is);

    // 1UL: shifting an int overflows (undefined behavior) for ldn>=31
    const ulong n = (1UL<<ldn);
    for (ulong i=0; i<n; ++i)  g[i]*=f[i];

    MOD_FFT(g,ldn,-is);

    multiply(g,n,inv(n));  // normalize by 1/n
}
//============== end FFT_MOD_CONVOLUTION =================



void
fft_mod_auto_convolution(mod *f, ulong ldn)
//
// _cyclic_ (self-)convolution of f[] (length n=2**ldn),
// computed with the modular FFT MOD_FFT.
// (use zero padded data for the usual linear convolution)
//
// result in f[]
{
    jjassert( ldn<=mod::max2pow );  // n must be invertible

    int is = +1;

    MOD_FFT(f,ldn,is);

    // 1UL: shifting an int overflows (undefined behavior) for ldn>=31
    const ulong n = (1UL<<ldn);
    for (ulong i=0; i<n; ++i)  f[i]*=f[i];

    MOD_FFT(f,ldn,-is);

    multiply(f,n,inv(n));  // normalize by 1/n
}
//============== end FFT_MOD_AUTO_CONVOLUTION =================


void
fft_mod_convolution(mod *f, mod *g, mod *h, ulong ldn)
//
// _cyclic_ convolution of f[] and g[] (length n=2**ldn each)
// (use zero padded data for the usual linear convolution)
//
// result in h[]
// f[] and g[] are transformed and then transformed back (with
// normalization), so on return they hold their original values.
//
// f, g and h must be three distinct arrays.
{
    jjassert( ldn<=mod::max2pow );  // n must be invertible

    jjassert( f!=g );  // otherwise the same array would be transformed twice
    // If h aliased f or g, transforming that input back below would
    // re-transform (and thereby destroy) the finished convolution:
    jjassert( h!=f );
    jjassert( h!=g );

    int is=+1;

    MOD_FFT(f,ldn,is);
    MOD_FFT(g,ldn,is);

    // 1UL: shifting an int overflows (undefined behavior) for ldn>=31
    const ulong n = (1UL<<ldn);
    for(ulong i=0; i<n; ++i)  h[i]= f[i]*g[i];

    MOD_FFT(h,ldn,-is);

    multiply(h,n,inv(n));  // normalize by 1/n

    // restore the inputs: inverse transform plus normalization
    MOD_FFT(f,ldn,-is);
    multiply(f,n,inv(n));
    MOD_FFT(g,ldn,-is);
    multiply(g,n,inv(n));
}
//============== end FFT_MOD_CONVOLUTION =================



// the rest is for hfloat:

#include <math.h>  // for floor()

void
double_to_mod_by_force(double *f, ulong n)
//
// just forces doubles into mods
//
{
    jjassert( sizeof(mod)==sizeof(double) );

    mod *m = (mod *)f;

    for (ulong i=0; i<n; ++i)
    {
        umod_t t = (umod_t)floor(f[i]+0.5);
//        jjassert( double(t)==f[i] );
        m[i] = t;
    }
}
//-------------------------------------------

void
mod_to_double_by_force(mod *m, ulong n)
//
// Converts the n mods of m[] in place into doubles (via make_double()),
// writing into the same memory reinterpreted as double[]
// (hence sizeof(mod) must equal sizeof(double)).
//
{
    jjassert( sizeof(mod)==sizeof(double) );

    double *dst = (double *)m;

    ulong k = 0;
    while ( k<n )
    {
        dst[k] = make_double( m[k] );
        ++k;
    }
}
//-------------------------------------------

void
fft_mod_convolution(double *f, double *g, ulong ldn)
//
// _cyclic_ convolution of f[] and g[] (length n=2**ldn each) given as
// doubles: the data is converted in place to mods, convolved with the
// mod-array version of fft_mod_convolution(), and converted back.
//
// result in g[]
// f[] is overwritten (it ends up holding its converted forward transform).
{
    // 1UL: shifting an int overflows (undefined behavior) for ldn>=31
    const ulong n=(1UL<<ldn);

    double_to_mod_by_force(f,n);
    double_to_mod_by_force(g,n);

    fft_mod_convolution((mod *)f,(mod *)g,ldn);

    mod_to_double_by_force((mod *)f,n);
    mod_to_double_by_force((mod *)g,n);
}
//-------------------------------------------

void
fft_mod_auto_convolution(double *f, ulong ldn)
//
// _cyclic_ (self-)convolution of f[] (length n=2**ldn) given as
// doubles: the data is converted in place to mods, convolved with the
// mod-array version of fft_mod_auto_convolution(), and converted back.
//
// result in f[]
{
    // 1UL: shifting an int overflows (undefined behavior) for ldn>=31
    const ulong n=(1UL<<ldn);

    double_to_mod_by_force(f,n);

    fft_mod_auto_convolution((mod *)f,ldn);

    mod_to_double_by_force((mod *)f,n);
}
//-------------------------------------------

