
#include <math.h>
#include <iostream.h>

#include "ntt.h"
#include "modm.h"
#include "jjassert.h"

int
test_fft(int ldn)
{
    int ret = 0;

    long i,k;
    long n = (1<<ldn);
    cout << "\n testing length-" << n << " fft: \n";
    mod  f[n],g[n];

    int is = +1;
    mod x,w;
    mod rt = mod::root2pow( is*ldn );

    mod m = (uint)1;

    int mi = 4;  if ( n>mi )  mi = n;
    for (i=0; i<mi; ++i)
//    i = 1;
    {
        for (k=0; k<n; ++k)  { f[k] = (uint)0; }
//        f[i] = m;
        for (k=0; k<n; ++k)  if ( (i+1)&(1<<k) )  f[k].x = 1;
        
//        if ( i>n/2 )  rand(f,n,mod::modulus);

        copy(f,g,n);

        /*
          // fft + fft^-1
        ntt_dif2(g,ldn,is);
        ntt_dif4(g,ldn,-is);
        multiply(g,n,inv(n));
        if ( diff_print(f,g,n,1) )  cout << " (in ii=" << i << ")\n";
        else  { cout << "%";  cout.flush(); }
        */


        // fft of delta peak at i
//        w = pow(rt,i);
//        x = m;
//        for (k=0; k<n; ++k)  { f[k] = x; x *= w; }
//        jjassert( x==m );
        slow_ntt(f,n,is);
        slow_ntt(f,n,-is);
//        ntt_dif2_noncyclic(f,ldn,is);
//        ntt_dif2_noncyclic(f,ldn,-is);
//        ntt_dif2(f,ldn,-is);


//        slow_mod_convolution(f,f,n);
//        is = -is;
//        ntt_dif2(f,ldn,is);
//        ntt_dif4(f,ldn,is);
//        ntt_dit2(f,ldn,is);
//        ntt_dit4(f,ldn,is);
//        slow_ntt(f,n,is);
//        ntt_dif2_noncyclic(f,ldn,is);
//        multiply(f,n,inv((uint)n));
//        slow_ntt(f,n,is);


        for (k=0; k<n; ++k)  g[k] *= n;
//        ntt_dif2_noncyclic(g,ldn,is);
//        ntt_dif2_noncyclic(g,ldn,-is);
//        ntt_dif2(g,ldn,is);
//        ntt_dif2(g,ldn,-is);
//        ntt_dif4(g,ldn,is);
//        ntt_dit2(g,ldn,is);
//        ntt_dit4(g,ldn,is);
//        slow_ntt(g,n,is);
//        for (k=0; k<n; ++k)  { g[k] *= g[k]; }
//        multiply(g,n,inv((uint)n));


        if ( diff_print(f,g,n,0) )  { ret++; }
        cout << " (in i=" << i << ")\n";
//        cout << endl;
    }

    return  ret;
}
//-------------------------------------------


int
test_cnvl(int ldn)
{
    int ret = 0;

    long i;
    long n = (1<<ldn);
    mod  f[n],g[n];

    for (i=0; i<n; ++i)  f[i] = (uint)0;
    for (i=0; i<2; ++i)  f[i] = (uint)1;

    cout << "\n testing length-" << n << " convolution: \n";

    for (i=0; i<1; ++i)
    {
        for (int k=0; k<n; ++k)  f[k] = mod::zero;
        f[i] = mod::one;
        
//        if ( i>0 )  rand(f,n/2,mod::modulus);
//        for (int k=0; k<n; ++k)  f[k] *= 1;

        copy(f,g,n);

        slow_mod_convolution(f,f,n);
        for (int k=0; k<n; ++k)  f[k] *= n;
//        ntt_auto_convolution(f,ldn);

        print("g:",g,n);
        int is = +1;
//        ntt_dif2_noncyclic(g,ldn,is);
        slow_ntt(g,n,is);
//        ntt_dif4(g,ldn,is);
        print("fft g:",g,n);
        for (int k=0; k<n; ++k)  g[k] *= g[k];
        print("(fft g)^2:",g,n);
//        ntt_dif4(g,ldn,-is);
        slow_ntt(g,n,-is);
        print("fft^-1 (fft g)^2:",g,n);

//        ntt_dif2_noncyclic(g,ldn,-is);
//        for (int k=0; k<n; ++k)  g[k].x >>= ldn;
        
//        ntt_auto_convolution(g,ldn);
//        ntt_auto_convolution_noncyclic(g,ldn);

        if ( diff_print(f,g,n,0) )  { cout << " (in i=" << i << ")\n\n"; ret++; }
        else  { cout << ".";  cout.flush(); }
    }

    return ret;
}
//-------------------------------------------


void
test_init()
{
    umod_t m[] = {
        2,4,8,16,256,
        3,9,
        8*9,
        35,
        5,25,125,
        10,20,80,800,
        6,12,18,96,
        17,257,
        1728,1729,
        0x7f000001, 0x7e000001, 0x78000001,
        0x7ffffe0000000001ULL,
        0x7fffe40000000001ULL,
        0x7fffe00000000001ULL,
        0x7fffcc0000000001ULL,
        0x7fff8c0000000001ULL,
        0x7ffedc0000000001ULL, //  ! YIKES !
        2130706433ULL*2130706433ULL,
        65537ULL*65537ULL,
        65537ULL*65537ULL*65537ULL,
        257ULL*257*257*257*257*257*257*2,
        2ULL*65537*65537*65537,
        2ULL*2130706433*2130706433,
        0};

    for (int k=0; m[k]; k++)
    {
        cout << "test_init(): m=" << m[k] << hex << " = 0x" << m[k] << dec << endl;    
        mod::reinit(m[k]);
        mod::info();
    }
}
//-------------------------------------------

void
test_egcd()
{
    smod_t k,j;
    smod_t u,v,u1,u2,d,g;

    u = 9223369837831520257LL;
    v = 7;

//    d = egcd(v,u,u2,u1);
//    jjassert( u1*u+u2*v==d );

    d = egcd(u,v,u1,u2);
    jjassert( u1*u+u2*v==d );
    
    for (k=1; k<111; k++)
    {
        for (j=1; j<=k; j++)
        {
            u = k;
            v = j;
//            cout << "\n test_egcd():  u=" << u << "  v=" << v << endl;

            d = egcd(v,u,u2,u1);
            jjassert( u1*u+u2*v==d );
            d = egcd(u,v,u1,u2);
            jjassert( u1*u+u2*v==d );
            g = gcd((umod_t)u,(umod_t)v);
            jjassert( d==g );
        }
    }
}//-------------------------------------------



int
main()
{
//    test_egcd();  exit(0);

    mod::info();
    test_init();  exit(0);
    
    mod::reinit( 16 );
    mod::info();

//    cout << " mod::xfact=" << mod::xfact << endl;
//    mod z = 1;  assert(0);

//    mod_mandel(65);

//    for (uint k=0; k<mod::modulus; k++)  cout << k << "^2=" << mod(k)*mod(k) << endl;
//    exit(0);

    //    mod::info();

    uint ldn = 1;

    //    mod n = (uint)(1<<ldn);
    //    mod i = inv(n);


    int ret = 0;
    ret += test_fft(ldn);
//    ret += test_fft(ldn+1);
//    ret += test_cnvl(ldn);
//    ret += test_cnvl(ldn+1);

    if ( ret )  cout << "\n !!! DESASTER !!! " << endl;
    else        cout << "\n all OK. fine." << endl;

    return 0;
}
//-------------------------------------------


