
#include "ntt.h"
#include "modm.h"
#include "jjassert.h"

#define MOD_FFT(x,y,z)   ntt_dit2(x,y,z)


inline void
check_two_invertible(ulong ldn)
//
// Abort (via jjassert2) unless mod::modulus is odd, i.e. unless 2 is
// invertible modulo the modulus.  The convolution routines below scale
// their result by inv(n) with n = 2**ldn, which needs exactly that.
//
// NOTE(review): this does not check that a 2**ldn-th root of unity exists
// modulo mod::modulus; ldn is currently not inspected.
//
{
    jjassert2( mod::modulus & 1, "need odd modulus to do length 2**ldn convolution" );

    (void)ldn;  // unused; cast silences the warning without writing to the parameter
}
//============== end =================


void
ntt_auto_convolution(mod *f, ulong ldn)
//
// _cyclic_ (self-)convolution of f[] with itself, computed in place:
// forward transform, pointwise squaring, inverse transform, scaling by 1/n.
// (use zero padded data for the usual, linear convolution)
//
// f[]: n = 2**ldn elements; overwritten with the result.
//
{
    check_two_invertible(ldn);

    const int is = +1;  // transform sign; the inverse transform uses -is

    MOD_FFT(f,ldn,is);

    // (ulong)1: with a plain int 1 the shift overflows (UB) for ldn >= 31
    const ulong n = ((ulong)1<<ldn);
    for (ulong i=0; i<n; ++i)  f[i] *= f[i];

    MOD_FFT(f,ldn,-is);

    // normalize the forward/inverse transform pair
    multiply(f,n,inv(n));
}
//============== end =================


void
ntt_convolution(mod *f, mod *g, ulong ldn)
//
// _cyclic_ convolution of f[] and g[], each of length n = 2**ldn
// (use zero padded data for the usual, linear convolution)
//
// result in g[]
// f[] is overwritten with its forward transform (it is NOT restored).
// f and g must be distinct arrays; for f==g use ntt_auto_convolution().
{
    check_two_invertible(ldn);

    jjassert( f!=g );

    const int is = +1;  // transform sign; the inverse transform uses -is

    MOD_FFT(f,ldn,is);
    MOD_FFT(g,ldn,is);

    // (ulong)1: with a plain int 1 the shift overflows (UB) for ldn >= 31
    const ulong n = ((ulong)1<<ldn);
    for (ulong i=0; i<n; ++i)  g[i] *= f[i];

    MOD_FFT(g,ldn,-is);

    // normalize the forward/inverse transform pair
    multiply(g,n,inv(n));
}
//============== end =================


void
ntt_convolution(mod *f, mod *g, mod *h, ulong ldn)
//
// _cyclic_ convolution of f[] and g[], each of length n = 2**ldn
// (use zero padded data for the usual, linear convolution)
//
// result in h[]
// f[] and g[] are transformed and then transformed back, so on return
// they hold their original values again.
// f, g and h must be three distinct arrays: if h aliased f or g, the
// restoring inverse transform would be applied to the result as well.
{
    check_two_invertible(ldn);

    jjassert( f!=g );
    jjassert( h!=f );
    jjassert( h!=g );

    const int is = +1;  // transform sign; the inverse transform uses -is

    MOD_FFT(f,ldn,is);
    MOD_FFT(g,ldn,is);

    // (ulong)1: with a plain int 1 the shift overflows (UB) for ldn >= 31
    const ulong n = ((ulong)1<<ldn);
    for (ulong i=0; i<n; ++i)  h[i] = f[i] * g[i];

    MOD_FFT(h,ldn,-is);
    multiply(h,n,inv(n));

    // undo the forward transforms so the inputs are left intact
    MOD_FFT(f,ldn,-is);
    multiply(f,n,inv(n));
    MOD_FFT(g,ldn,-is);
    multiply(g,n,inv(n));
}
//============== end =================



// the rest is for hfloat:  // jjnote: code LIMB_to_mod()

#include <math.h>  // for floor()

void
double_to_mod_by_force(double *f, ulong n)
//
// just forces doubles into mods, in place:
// each of the n doubles f[i] is rounded to the nearest integer
// (floor(x+0.5)) and stored as the .x field of the mod that occupies
// the same memory.  Values are not reduced modulo mod::modulus here.
// Requires sizeof(mod)==sizeof(double) (checked at run time).
//
// NOTE(review): reading/writing the same storage as double and as mod
// through casted pointers violates the strict aliasing rule; this relies
// on the compiler tolerating it (memcpy-based punning would be portable).
// NOTE(review): presumably mod consists of just its umod_t x — confirm.
// NOTE(review): a negative or out-of-range value makes the
// double -> umod_t conversion undefined; presumably callers only pass
// small non-negative integers — confirm.
//
{
    jjassert( sizeof(mod)==sizeof(double) );

    mod *m = (mod *)f;
    for (ulong i=0; i<n; ++i)  m[i].x = (umod_t)floor(f[i]+0.5);
}
//-------------------------------------------

void
mod_to_double_by_force(mod *m, ulong n)
//
// just forces mods into doubles, in place:
// the .x value of each of the n mods m[i] is converted to double and
// written over the same memory.
// Requires sizeof(mod)==sizeof(double) (checked at run time).
//
// NOTE(review): same strict-aliasing caveat as double_to_mod_by_force():
// the storage is accessed both as mod and as double via pointer casts.
//
{
    jjassert( sizeof(mod)==sizeof(double) );

    double *f=(double *)m;
    for (ulong i=0; i<n; ++i)  f[i] = (double)(m[i].x);
}
//-------------------------------------------

void
ntt_convolution(double *f, double *g, ulong ldn)
//
// _cyclic_ convolution of double data via the mod-based NTT.
// f[] and g[] (n = 2**ldn elements each) are rounded and reinterpreted
// in place as mods; the result is left in g[] as doubles.
// f[] is destroyed: on return it holds its forward transform, as doubles.
//
// NOTE(review): the result values are mod residues; presumably callers
// ensure the true convolution values stay below mod::modulus — confirm.
//
{
    // (ulong)1: with a plain int 1 the shift overflows (UB) for ldn >= 31
    const ulong n = ((ulong)1<<ldn);

    double_to_mod_by_force(f,n);
    double_to_mod_by_force(g,n);

    ntt_convolution((mod *)f,(mod *)g,ldn);

    mod_to_double_by_force((mod *)f,n);
    mod_to_double_by_force((mod *)g,n);
}
//-------------------------------------------

void
ntt_auto_convolution(double *f, ulong ldn)
//
// _cyclic_ self-convolution of double data via the mod-based NTT.
// f[] (n = 2**ldn elements) is rounded and reinterpreted in place as
// mods, convolved with itself, and converted back; result in f[].
//
// NOTE(review): the result values are mod residues; presumably callers
// ensure the true convolution values stay below mod::modulus — confirm.
//
{
    // (ulong)1: with a plain int 1 the shift overflows (UB) for ldn >= 31
    const ulong n = ((ulong)1<<ldn);

    double_to_mod_by_force(f,n);

    ntt_auto_convolution((mod *)f,ldn);

    mod_to_double_by_force((mod *)f,n);
}
//-------------------------------------------

