
#include <iostream.h>
#include <math.h>

#include "mod.h"
#include "modm.h"
#include "jjassert.h"


#define  INIT_DEBUG  0  // 0 (default) or 1 (for debug)
#if  ( INIT_DEBUG==1 )
#warning 'FYI:  INIT_DEBUG == 1'
#endif


// Return the bit length of m: the position of its highest set bit,
// counted from 1, or 0 when m == 0 (e.g. count_bits(8) == 4,
// count_bits(9) == 4, count_bits(1) == 1).
uint count_bits(uint64 m)
{
    uint nbits;
    for (nbits = 0;  m != 0;  m >>= 1)  ++nbits;
    return nbits;
}
//======== end ==========



// Tear down the global state of class mod set up by mod_init::mod_init():
// free the tables of 2**k-th roots and reset all cached quantities
// to zero / empty factorizations.
// Note: mod::modulus itself is not touched here.
mod_init::~mod_init()
{
    // tables of roots of order 2**k (allocated with new[] in the ctor);
    // null the pointers so that a later delete[] is harmless
    delete [] mod::root_2pow;
    delete [] mod::invroot_2pow;
    mod::root_2pow = 0;
    mod::invroot_2pow = 0;

    // quantities derived directly from the modulus
    mod::mm1 = 0;
    mod::m1dd = 0.0;
    mod::mbits = 0;
    mod::mbitsd = 0.0;
    mod::modfact.reset();

    // maximal order, 2-power part and Euler phi
    mod::maxorder = 0;
    mod::max2pow = 0;
    mod::xfact.reset();
    mod::phi = 0;
    mod::phifact.reset();

    // cached residues
    mod::zero.x = 0;
    mod::one.x = 0;
    mod::two.x = 0;
    mod::minus_one.x = 0;
    mod::maxordelem.x = 0;
    mod::invmaxordelem.x = 0;
}
//======== end ==========


// Set up the global (static) state of class mod for the modulus m.
//
// Computes and stores as static members of mod:
//   modulus, 1/m as long double (m1dd), m-1 (mm1),
//   the residues zero, one, two, minus_one,
//   the bit length of m (mbits) and log2(m) (mbitsd),
//   the factorizations of m (modfact), of phi (phifact) and of the
//   maximal element order (xfact),
//   an element of maximal order (maxordelem) and its inverse,
//   and the tables root_2pow[k], invroot_2pow[k] for k = 0..max2pow
//   holding an element of order 2**k and its inverse.
// Each derived quantity is verified with jjassert2().
//
// m      : the modulus.  Must be < 2**63 unless USE_64BIT_MODULUS is set,
//          and < 2**62 when USE_LEQ_62BIT_MODULUS is set.
//          NOTE(review): m == 0 and m == 1 look unsupported (division by m,
//          empty factorization) -- presumably callers pass m >= 2.
// primes : optional (default 0), forwarded to make_factorization() for m;
//          presumably a list of known prime factors of m -- TODO confirm in mod.h
mod_init::mod_init(umod_t m, umod_t *primes/*=0*/)
{
    mod::modulus = m;
    mod::m1dd = (ldouble)1/(ldouble)m;

#if  ( INIT_DEBUG==1 )
    mod_info0(m);
#endif

#if  ( !USE_64BIT_MODULUS )
    jjassert2( !(m&((umod_t)1<<63)) ,
               " modulus must have less than 64 bits " );
#endif

#if  ( USE_LEQ_62BIT_MODULUS )
    jjassert2( !(m&((umod_t)3<<62)),
               " modulus must have less than 62 bits " );
#endif

    // +++++ only after this point we can multiply mods !
#if  ( INIT_DEBUG==1 )
    cout << "MOD_INIT(): can multiply mods" << endl;
#endif

    mod::mm1 = m-1;

    mod::zero = (uint)0;
    mod::one =  (uint)1;
    if ( m>2 ) mod::two = (uint)2;
    else       mod::two = (uint)0;   // for m<=2 the residue of 2 is 0
    mod::minus_one = m-1;

    // mbits: number of bits of m;  mbitsd: log2(m) as a double
    umod_t mb = count_bits(m);
    mod::mbits = mb;
    double mbd = log((double)m)/log(2.0);
    // For a power of two, m == 2**(mb-1), so log2(m) is exactly mb-1.
    // Use that exact value instead of the rounded floating point logarithm.
    // The test (m & (m-1))==0 avoids a shift by mb-1, which is undefined for m==0.
    if ( (m!=0) && ((m & (m-1))==0) )  mbd = (double)(mb-1);
    mod::mbitsd = mbd;

    mod::modfact.make_factorization(m,primes);

    jjassert2( mod::modfact.is_factorization_of(m),
               "factorization of the modulus failed" );


    // +++++ have modulus, modfact
#if  ( INIT_DEBUG==1 )
    cout << "MOD_INIT(): have modulus, modfact" << endl;
    mod_info1();
#endif


    // Euler phi: m-1 for prime m, otherwise computed from the factorization
    if ( mod::modfact.is_prime() )
    {
        mod::phi = m-1;
    }
    else
    {
        mod::phi = phi(mod::modfact);
    }

    mod::phifact.make_factorization(mod::phi);
    jjassert2( mod::phifact.is_factorization_of(mod::phi),
               "factorization of phi failed" );


    // +++++ have phi, phifact
#if  ( INIT_DEBUG==1 )
    cout << "MOD_INIT(): have phi, phifact" << endl;
#endif

    mod::maxorder = maxorder_mod(mod::modfact);

    mod::xfact.make_factorization(mod::maxorder);
    jjassert2( mod::xfact.is_factorization_of(mod::maxorder),
               "factorization of the maximal order failed");


    // +++++ have maxorder, xfact
#if  ( INIT_DEBUG==1 )
    cout << "MOD_INIT(): have maxorder, xfact" << endl;
    mod_info2();
#endif


    mod::maxordelem = maxorder_element_mod(mod::modfact, mod::phifact);

    umod_t rr = order(mod::maxordelem);
    jjassert2( rr != 0, "oops, order of primitive root is ==0" );
    jjassert2( rr == mod::maxorder,
               "oops, order of primitive root is != maxorder" );

    mod::invmaxordelem = inv(mod::maxordelem);
    jjassert2( (mod::maxordelem * mod::invmaxordelem)==mod::one,
               "oops, inverse(primroot)*primroot is != 1" );


    // +++++ have element of maximal order (primitive root if m cyclic)
#if  ( INIT_DEBUG==1 )
    cout << "MOD_INIT(): have element of maximal order" << endl;
    mod_info3();
#endif

    // The smallest prime factor (prime[0]) tells whether m is odd.
    bool fftq = ( mod::modfact.prime[0]!=2 );

    mod t, ti;
    if ( fftq )
    {
        // Odd m: start with z = -1 (order 2) and keep taking square roots
        // while one exists (presumably sqrt_modf() returns 0 otherwise).
        // Each square root doubles the order, so at the end z has order
        // 2**ex2 (checked by the loop further below).
        long ex2 = 1;
        umod_t w, z = mod::mm1;

        do
        {
            w = sqrt_modf(z,mod::modfact);
//            cout << "sqrt: " << ex2 << ": " << z << "  " << w << endl;
            if ( w )
            {
                z = w;
                ++ex2;
            }
        }
        while ( w );

        mod::max2pow = ex2;
        t = z;
        ti = inv_mod(t.x, m);
    }
    else
    {
        // Even m: raise the element of maximal order (and its inverse)
        // to the odd part of maxorder.
        mod::max2pow = mod::xfact.exponent(2);
        umod_t z = mod::maxorder/pow2(mod::max2pow);
        t = pow(mod::maxordelem,z);
        ti = pow(mod::invmaxordelem,z);
    }

    int m2 = mod::max2pow;

    // Release tables left over from a previous initialization; they would
    // otherwise leak.  delete[] on a null pointer is a no-op.  The pointers
    // are nulled so they never dangle should an allocation below throw.
    delete [] mod::root_2pow;     mod::root_2pow = 0;
    delete [] mod::invroot_2pow;  mod::invroot_2pow = 0;

    mod::root_2pow =    new mod[m2+1];
    mod::invroot_2pow = new mod[m2+1];

    mod::root_2pow[m2] =    t;
    mod::invroot_2pow[m2] = ti;


    // entry k is the square of entry k+1
    for (int k=m2-1; k>=0; --k)
    {
        t = mod::root_2pow[k+1];
        t *= t;
        mod::root_2pow[k] = t;

        ti = mod::invroot_2pow[k+1];
        ti *= ti;
        mod::invroot_2pow[k] = ti;
    }

    // verify: root2pow(k) and root2pow(-k) both have order 2**k
    // and their product is one
    for (int k=0; k<=m2; ++k)
    {
        umod_t r, p2k = pow2((uint)k);

        r = order(mod::root2pow(k));
        jjassert2( r==p2k, "order(root_2pow(k)) is != 2**k" );

        r = order(mod::root2pow(-k));
        jjassert2( r==p2k, "order(root2pow(-k)) is != 2**k" );

        mod prod = (mod::root2pow(k))*(mod::root2pow(-k));
        jjassert2( prod==mod::one,  "root2pow(k) * root2pow(-k) is != 1" );
    }


    // +++++ from now on we can do ffts ...
#if  ( INIT_DEBUG==1 )
    cout << "MOD_INIT(): can do ffts" << endl;
    mod_info4();
#endif


    // ------- montgomery section (disabled, kept for reference;
    //         note: re-enabling it as is would redeclare the local mb):
    /*
    ulong mb = 8*sizeof( umod_t );
    mod::mg_bits = mb;
    if ( vb )  cout << " mod_initialiser: mg_bits=" << mod::mg_bits << endl;

    umod_t mm = 0;
    for (ulong k=0; k<mb; ++k)
    {
	mm <<= 1;
	mm += 1;
    }

    mod::mg_mask = mm;
    if ( vb )  cout << " mod_initialiser: mg_mask=" << mod::mg_mask << endl;

    umod_t mr = montgomery_r(mb,m);
    mod::mg_r = mr;
    if ( vb )  cout << " mod_initialiser: mg_r=" << mod::mg_r << endl;

    umod_t mrp = montgomery_r_prime(mb,m);
    mod::mg_r_prime = mrp;
    if ( vb )  cout << " mod_initialiser: mg_r_prime=" << mod::mg_r_prime << endl;
    jjassert( mod((umod_t)1)==mod(mr)*mod(mrp) );

    umod_t mmp = montgomery_m_prime(mb,m,mm);
    mod::mg_m_prime = mmp;
    if ( vb )  cout << " mod_initialiser: mg_m_prime=" << mod::mg_m_prime << endl;
    */

#if  ( INIT_DEBUG==1 )
    mod_info99();
#endif
}
//================= end ===================
