
#include <iostream.h>
#include <math.h>

#include "mod.h"
#include "modm.h"
#include "jjassert.h"


// Verbosity switch for mod_init(): a nonzero value makes the constructor
// report each initialisation stage via mod_info1() .. mod_info4() and
// mod_info99().  Leave at 0 for silent initialisation.
static const int vb = 0;


// Release the tables of 2^k-th roots (and their inverses) that the
// mod_init constructor allocated with new[], and reset the static
// pointers to NULL so no dangling address is left behind in class mod.
mod_init::~mod_init()
{
    delete [] mod::root_2pow;     mod::root_2pow    = NULL;
    delete [] mod::invroot_2pow;  mod::invroot_2pow = NULL;
}
//======== end ~MOD_INIT() ==========


// Initialise the global (static) state of class mod for arithmetic modulo m.
//
// Stages, each optionally reported through mod_info*() when vb != 0:
//   1. modulus-derived constants: modulus, zero/one/minus_one, bit size
//      (mbitsd / mbits), reciprocal m1dd = 1/m, and the factorization of m;
//   2. maxorder (largest element order) and phi, plus their factorizations;
//   3. primroot, an element of maximal order (a primitive root when m is
//      prime), and its inverse;
//   4. tables root_2pow[k] / invroot_2pow[k] for k = 0..max2pow, where entry k
//      has order 2^k (checked by jjassert) -- needed for radix-2 ffts.
//
// m     : the modulus.
// facts : forwarded to mod::mfact.make_factorization(); presumably an optional
//         precomputed factorization of m -- TODO confirm accepted values (NULL?).
//
// All consistency checks are done with jjassert().
mod_init::mod_init(umod_t m, umod_t *facts)
{
    //    uint xx=
    erastothenes(66000);  // setup primetable  arg>50, use 66000 or so
    //    cout << xx << " primes ... \n";

    mod::modulus = m;
    mod_modulus = mod::modulus;

    mod::zero = (uint)0;
    mod::one =  (uint)1;
    mod::minus_one = m-1;

    mod::mbitsd = log((double)m)/log(2.0);
    mod::mbits  = (uint)ceil(mod::mbitsd);

    mod::m1dd = (long double)1/m;
    mod_m1dd = mod::m1dd;

    // only after this point we can multiply mods !


    mod::mfact.make_factorization(m,facts);
    jjassert( mod::mfact.is_factorization_of(m) );

    if ( vb )  mod_info1();

    // +++++ have modulus


    if ( mod::mfact.is_prime() )
    {
        // prime modulus: the multiplicative group is cyclic of order m-1
        mod::maxorder = m-1;
        mod::phi = m-1;
    }
    else
    {
//        if  ( mod:modulus_cyclic() )
//        {
            mod::maxorder = find_maxorder(); // needs mod::mfact;
//        }
//        else
//        {
//            umod_t test_find_maxorder();  // decl.
//            mod::maxorder = test_find_maxorder(); // needs mod::mfact;
//        }

        mod::phi = phi(mod::mfact);
    }

    mod::xfact.make_factorization(mod::maxorder);
    jjassert( mod::xfact.is_factorization_of(mod::maxorder) );

    mod::pfact.make_factorization(mod::phi);
    jjassert( mod::pfact.is_factorization_of(mod::phi) );

    if ( vb )  mod_info2();

    // +++++ have maxorder, phi


#if  ( USE_LEQ_63BIT_MODULUS )
    // Attach the message with && (a string literal is never null, so it does
    // not change the truth value).  The former form multiplied by
    // (int)"...", a pointer-to-int truncation: ill-formed on LP64 compilers
    // and spuriously zero if the low bits of the address happen to be 0.
    jjassert( ((smod_t)m>0) && " modulus must have <=63bits " );
#endif


    if ( mod::modulus_prime() )
        mod::primroot = find_primitive_root();
    else
        mod::primroot = find_maxorder_element();
//    cout << "\n XXX primroot=" << mod::primroot << endl;

    umod_t rr = order(mod::primroot);
    jjassert( rr != 0 );
    jjassert( rr == mod::maxorder );

    mod::inverse_primroot = inv(mod::primroot);

    jjassert( (mod::primroot*mod::inverse_primroot)==mod::one );

    if ( vb )  mod_info3();

    // +++++ have element of maximal order (primitive root if m prime)


//    jjassert( mod::xfact.prime[0]==2 );  // needed for (radix 2) ffts

    // t / ti: element of order 2^max2pow and its inverse.
    mod t,ti;
    if ( mod::modulus_cyclic() )
    {
        mod::max2pow = mod::xfact.exponent[0];

        // raise the maximal-order element to the odd part of maxorder
        umod_t z = mod::maxorder/pow2(mod::max2pow);
        t = pow(mod::primroot,z);
        ti = pow(mod::inverse_primroot,z);

    }
    else
    {
        // Start from -1 (order 2) and take square roots while possible;
        // each square root doubles the order, so after k steps w has
        // order 2^(k+1).
        umod_t w = m-1, x = m-1;
        uint k = 0;
        while ( is_quadratic_residue(x,mod::mfact) )
        {
            k++;
            w = sqrt_modf(x,mod::mfact);
            x = w;
        }
        mod::max2pow = k+1;

        t = w;
        ti = inv(w);
    }

    const int m2 = mod::max2pow;

    // Free tables from an earlier initialisation before allocating new ones;
    // otherwise constructing mod_init again leaks them.  delete[] on NULL is a
    // no-op (the destructor resets both pointers to NULL; assumes they start
    // out NULL as static-storage pointers -- confirm where they are defined).
    delete [] mod::root_2pow;
    delete [] mod::invroot_2pow;
    mod::root_2pow =    new mod[m2+1];
    mod::invroot_2pow = new mod[m2+1];

    mod::root_2pow[m2] =    t;
    mod::invroot_2pow[m2] = ti;


    // Fill lower entries by repeated squaring: squaring halves the order.
    for (int k=m2-1; k>=0; --k)
    {
        t = mod::root_2pow[k+1];
        t *= t;
        mod::root_2pow[k] = t;
        jjassert( order(t)==pow2((uint)k) );

        ti = mod::invroot_2pow[k+1];
        ti *= ti;
        mod::invroot_2pow[k] = ti;
        jjassert( order(ti)==pow2((uint)k) );

        jjassert( t*ti==mod::one );
    }

    for (int k=0; k<=m2; ++k)
    {
        jjassert( order(mod::root2pow(k))==pow2((uint)k) );
        jjassert( order(mod::root2pow(-k))==pow2((uint)k) );
    }



    if ( vb )  mod_info4();

    // +++++ from now on we can do ffts ...



    // ------- montgomery section:
    /*
    ulong mb = 8*sizeof( umod_t );
    mod::mg_bits = mb;
    if ( vb )  cout << " mod_initialiser: mg_bits=" << mod::mg_bits << endl;

    umod_t mm = 0;
    for (ulong k=0; k<mb; ++k)
    {
	mm <<= 1;
	mm += 1;
    }

    mod::mg_mask = mm;
    if ( vb )  cout << " mod_initialiser: mg_mask=" << mod::mg_mask << endl;

    umod_t mr = montgomery_r(mb,m);
    mod::mg_r = mr;
    if ( vb )  cout << " mod_initialiser: mg_r=" << mod::mg_r << endl;

    umod_t mrp = montgomery_r_prime(mb,m);
    mod::mg_r_prime = mrp;
    if ( vb )  cout << " mod_initialiser: mg_r_prime=" << mod::mg_r_prime << endl;
    jjassert( mod((umod_t)1)==mod(mr)*mod(mrp) );

    umod_t mmp = montgomery_m_prime(mb,m,mm);
    mod::mg_m_prime = mmp;
    if ( vb )  cout << " mod_initialiser: mg_m_prime=" << mod::mg_m_prime << endl;
    */

    if ( vb )  mod_info99();
}
//================= end MOD_INIT() ===================
