#include <math.h>
#include <iostream.h>


#include "builtin.h"
#include "mfft.h"
#include "modm.h"
#include "jjassert.h"

int
test_cnvl(int ldn)
{
    int ret = 0;

    cout << "\n testing fft-convolution: \n";
    long i;
    long n = (1<<ldn);
    mod  f[n],g[n];

    for(i=0; i<n; ++i)    f[i] = (uint)0;
    for(i=0; i<n/2; ++i)  f[i] = (uint)1;


    for (i=0; i<5; ++i)
    {
        if ( i>0 )  bogo_rand(f,n,mod::modulus);

        copy(f,g,n);

        slow_mod_convolution(f,f,n);
//        fft_mod_auto_convolution(f,ldn);

        fft_mod_auto_convolution(g,ldn);
//        fft_mod_auto_convolution_noncyclic(g,ldn);

        if ( diff_print(f,g,n,1) )  { cout << " (in i=" << i << ")\n\n"; ret++; }
        else  { cout << ".";  cout.flush(); }
    }

    return ret;
}
//-------------------------------------------

int
test_fft(int ldn)
{
    int ret = 0;

    cout << "\n testing ffts: \n";
    long i,k;
    long n = (1<<ldn);
    mod  f[n],g[n];

    int is = -1;
    mod x,w;
    mod rt = mod::root2pow( is*ldn );

    mod m = (uint)1;

    for (i=0; i<n; ++i)
//    i = 1;
    {
        for(k=0; k<n; ++k)  { f[k] = (uint)0; }  f[i] = m;

        if ( i>n/2 )  bogo_rand(f,n,mod::modulus);

        copy(f,g,n);

        /*
          // fft + fft^-1
        mod_fft_dif2(g,ldn,is);
        mod_fft_dif4(g,ldn,-is);
        multiply(g,n,inv(n));
        if ( diff_print(f,g,n,1) )  cout << " (in ii=" << i << ")\n";
        else  { cout << "%";  cout.flush(); }
        */


        // fft of delta peak at i
//        w = pow(rt,i);
//        x = m;
//        for(k=0; k<n; ++k)  { f[k] = x; x *= w; }
//        jjassert( x==m );
//        mod_sft(f,n,is);
//        mod_fft_dif2(f,ldn,-is);


        slow_mod_convolution(f,f,n);
//        is = -is;
        mod_sft(f,n,is);
//        mod_fft_dif2(f,ldn,is);
//        mod_fft_dif4(f,ldn,is);
//        mod_fft_dit2(f,ldn,is);
//        mod_fft_dit4(f,ldn,is);
//        mod_sft(f,n,is);
//        mod_fft_dif2_noncyclic(f,ldn,is);
//        multiply(f,n,inv((uint)n));
//        mod_sft(f,n,is);


//        mod_fft_dif2_noncyclic(g,ldn,is);
//        mod_fft_dif2_noncyclic(g,ldn,-is);
//        mod_fft_dif2(g,ldn,is);
//        mod_fft_dif2(g,ldn,-is);
//        mod_fft_dif4(g,ldn,is);
//        mod_fft_dit2(g,ldn,is);
//        mod_fft_dit4(g,ldn,is);
        mod_sft(g,n,is);
        for (k=0; k<n; ++k)  { g[k] *= g[k]; }
//        multiply(g,n,inv((uint)n));


        if ( diff_print(f,g,n,0) )  { ret++; cout << " (in i=" << i << ")\n"; }
        else                        { cout << ".";  cout.flush(); }

        cout << endl;
    }

    return  ret;
}
//-------------------------------------------


int
main()
{
//    mod_info();

    umod_t m = ipow(6,8)+1;
    //    umod_t m = mod::modulus;
    factorization ff(m);
    cout << "m=" << m << " == " << ff << endl;
    delete mod::mod_initializer;
    mod::mod_initializer = new mod_init(m,ff.prime);
    mod_info();

    //    exit(0);

    //    mod_info();

    uint ldn = 3;

    //    mod n = (uint)(1<<ldn);
    //    mod i = inv(n);


    cout << " testing... \n";

    int ret = 0;

//    ret += test_fft(ldn);
//    ret += test_fft(ldn+1);
    ret += test_cnvl(ldn);
    ret += test_cnvl(ldn+1);

    if ( ret )  cout << "\n !!! DESASTER !!! " << endl;
    else        cout << "\n all OK. fine." << endl;

    return 0;
}
//-------------------------------------------


