
#include <math.h>
#include <iostream.h>
#include <iomanip.h>

#include "mod.h"
#include "modm.h"
#include "auxbit.h"

// Parse an unsigned integer from the C string str.
// Leading blanks (' ') are skipped.  A "0x" or "0X" prefix directly after the
// blanks selects hexadecimal (digits 0-9, a-f, A-F); otherwise the number is
// read as decimal.  An invalid character triggers an assertion failure; there
// is no error return.  Overflow past the width of umod_t is not detected
// (the value silently wraps).
umod_t
strtol64(const char *str)
{
    umod_t z = 0;
    int k = 0;
    umod_t rx = 10;

    while ( str[k]==' ' )  k++;

    // Test the prefix at the first non-blank position (not at str[0]),
    // otherwise " 0x.." would be rejected as a bad decimal number.
    // str[k+1] is safe: if str[k]=='0' then str[k+1] is at worst the NUL.
    if ( str[k]=='0' && (str[k+1]=='x' || str[k+1]=='X') )
    {
        rx = 16;
        k += 2;
    }

    while ( str[k] )
    {
        z *= rx;
        unsigned char c = str[k];
        if ( c>='0' && c<='9' )  c -= '0';
        else
        {
            // "0 && msg" keeps the message visible in the assert output
            // without casting a pointer to int (ill-formed on LP64).
            if ( rx==10 )  assert( 0 && "error in (dec) strtol64()" );
            if ( c>='a' && c<='f' )  c -= 'a';
            else
                if ( c>='A' && c<='F' )   c -= 'A';
                else  assert( 0 && "error in (hex) strtol64()" );

            c += 10;
        }

        z += c;
        k++;
    }

    return z;
}
//-------------------------------------------


// Render the 64 bits of n, most significant first.
// Positions before the highest set bit are shown as '.', the highest set bit
// and everything below it as '1'/'0'.  n==0 gives 64 dots.
// The result lives in a static buffer: it is overwritten by the next call.
const char *
dual(umod_t n)
{
    static const int nbits = 64;
    static char out[nbits+1];   // static => out[nbits] stays '\0'

    bool leading = true;   // still inside the run of leading zero bits?
    for (int pos = 0;  pos < nbits;  ++pos, n <<= 1)
    {
        const bool top_set = ( (smod_t)n < 0 );
        if ( top_set )  leading = false;

        if ( leading )  out[pos] = '.';
        else            out[pos] = ( top_set ? '1' : '0' );
    }

    return out;
}
//-------------------------------------------

// Print a verbose trace (in hexadecimal) of one strong-pseudoprime
// (Miller-Rabin) round for modulus n with base a, then print
// " prime " if n passes for this base or " COMPOSITE " if it fails.
// NOTE(review): n2qt(x,q,t) is assumed to split x-1 as q * 2^t with q odd,
// and order(a) is assumed to be taken w.r.t. the current mod modulus;
// TODO confirm against the project headers.
void
rmpass(const umod_t n, const umod_t a)
{
    umod_t q;
    int t;
    n2qt(n,q,t);
    cout << setbase(16) << endl;
    cout << "n = " << n << "  t=" << t << endl;
    cout << "a = " << a << "   n-a = " << n-a << endl;
    umod_t r = order(a);
    umod_t qr;  int tr; n2qt(r+1,qr,tr);
    cout << "#(a) = " << r << "  = " << qr << " * 2^" << tr << endl;
    umod_t b = pow_mod(a,q,n);
    cout << " 0:  b = " << setw(16) << b << " = " << dual(b) << endl;

    bool passed = ( 1==b );
    if ( !passed )
    {
        // At most t-1 squarings, stopping as soon as b becomes 1 or n-1.
        // The label e is the number of squarings done so far (1, 2, ...),
        // continuing the " 0:" line printed above.
        for (int e=1;  (e<t) && (b!=1) && (b!=(n-1));  ++e)
        {
            b = mul_mod(b,b,n);
            cout << setw(2) << e
                 << ":  b = " << setw(16) << b
                 << " = " << dual(b) << endl;
        }

        // Reaching 1 without passing through n-1, or never reaching n-1,
        // proves n composite.
        passed = ( b==(n-1) );
    }

    if ( passed )  cout << " prime " << endl;
    else           cout << " COMPOSITE " << endl;

    cout << setbase(10) << endl;
}
//-------------------------------------------


// Trace one Montgomery multiplication of x*y modulo m with radix r = ra*rb,
// printing every intermediate value.  Alongside the full-radix values, the
// per-factor variants (mod ra, mod rb) are printed for comparison.
// Requires gcd(m, r) == 1 (asserted).
// At the end the result z = zp * r^-1 mod m is compared with x*y mod m
// ("***** OOPS" on mismatch), and zp with x*y*r mod m ("*****" when zp is
// not reduced, i.e. zp >= m or wrong).
// NOTE(review): p + v*m is computed in plain umod_t arithmetic, so only
// small m and r are safe from overflow (do_mont uses m=257, r=410).
void
mont(umod_t x, umod_t y, umod_t m, umod_t ra, umod_t rb)
{
    umod_t r = ra * rb;
    cout << " ----- mont(): ----- " << endl;
    cout << "  x=" << x;
    cout << "  y=" << y;
    cout << "  m=" << m << endl;
    cout << "  r=" << r << "  ra=" << ra << "  rb=" << rb << endl;
    cout << endl;

    // mp = -m^-1 mod r,  r1 = r^-1 mod m
    cout << " ---  mp = r-m^-1  r1 = r^-1 --- " << endl;
    assert( 1==gcd(m,r) );
    umod_t mp = r - inv_mod(m,r);
    umod_t mpa = mp % ra,  mpb = mp % rb;
    cout << "  mp=" << mp << "  mpa=" << mpa << "  mpb=" << mpb << endl;
    cout << "  cf. mpa=" << ra-inv_mod(m,ra) << "  mpb=" << rb-inv_mod(m,rb) << endl;

    umod_t r1 = inv_mod(r,m);
    umod_t r1a = inv_mod(ra,m),  r1b = inv_mod(rb,m);
    cout << "  r1=" << r1 << "  r1a=" << r1a << "  r1b=" << r1b << endl;
//    cout << "  cf. r1a=" << inv_mod(ra,m) << "  r1b=" << inv_mod(rb,m) << endl;
    cout << endl;

    // Montgomery representatives: xp = x*r mod m, yp = y*r mod m
    cout << " ---  xp = x * r etc. --- " << endl;
    umod_t xp = mul_mod(x,r,m);
    umod_t yp = mul_mod(y,r,m);
    umod_t xpa = mul_mod(x,ra,m),  xpb = mul_mod(x,rb,m);
    umod_t ypa = mul_mod(y,ra,m),  ypb = mul_mod(y,rb,m);
    cout << "  xp=" << xp << "  xpa=" << xpa << "  xpb=" << xpb << endl;
    cout << "  yp=" << yp << "  ypa=" << ypa << "  ypb=" << ypb << endl;
    cout << endl;

    cout << " ---  p = xp * yp --- " << endl;
    umod_t p = xp * yp;
    // Per-factor products, parallel to p = xp * yp.
    umod_t pa = xpa*ypa,  pb = xpb*ypb;
    cout << "p=" << p << "  pa=" << pa << "   pb=" << pb << endl;

    // v = p * mp mod r, chosen so that p + v*m is divisible by r
    cout << " ---  v = p * mp --- " << endl;
    umod_t v = mul_mod(p,mp,r);
    umod_t va = mul_mod(pa,mpa,ra),  vb = mul_mod(pb,mpb,rb);
    cout << "v=" << v << "  va=" << va << "  vb=" << vb << endl;

    cout << " ---  t = v * m --- " << endl;
    umod_t tn = p+v*m;
    umod_t tna = tn % (ra*ra),  tnb = tn % (rb*rb);
    cout << "tn=" << tn << "  tna=" << tna << "  tnb=" << tnb << endl;
    assert( r==gcd(tn,r) );     // r divides tn
    assert( ra==gcd(tna,ra) );
    assert( rb==gcd(tnb,rb) );

    umod_t t = tn / r;   // exact division, t == x*y*r (mod m)
    umod_t ta = t / ra,  tb = t / rb;
    cout << "  t=" << t << "  ta=" << ta << "  tb=" << tb << endl;
    cout << endl;
//    while ( t>=m )  {  cout << " ! t=" << t << endl;  t -= m;  }

    umod_t zp = t,  zpa = ta,  zpb = tb;
    umod_t z = mul_mod(zp,r1,m);   // leave Montgomery form: z = zp / r mod m
    cout << "  zp=" << zp << "  zpa=" << zpa << "  zpb=" << zpb << endl;
    cout << "  cf  zp%ra=" << zp%ra << "  zp%rb=" << zp%rb << endl;
    cout << endl;

    umod_t xy = mul_mod(x,y,m);
    umod_t xyp = mul_mod(xy,r,m);
    cout << "xy=" << xy << "  z=" << z;
    if ( xy!=z )  cout << "   ***** OOPS";
    cout << endl;
    cout << "xyp=" << xyp << "  zp=" << zp << "  zp%m=" << zp%m;
    if ( xyp!=zp )  cout << "   ***** ";
    cout << endl;
//    cout << "xyp%r=" << xyp%r << "  zp%r=" << zp%r;  cout << endl;

    cout << endl;
}
//-------------------------------------------

void
do_mont()
{
    umod_t x = 37,  y = 87;
    umod_t m = 257;

    umod_t ra = 10, rb = 41;
    mont(x,y,m,ra,rb);

//    umod_t r[] = {m+1, m-1, 100, 11, 10, 9, 5, 4, 3, 2, 0};
//    for (int k=0; r[k]; ++k)  if ( 1==gcd(r[k],m) )  mont(x,y,m,r[k]);
}
//-------------------------------------------

// Test driver.
//   argv[1] (optional): modulus, decimal or "0x" hex; installed via mod::reinit().
//   argv[2] (optional): base a for the rmpass() trace (default 2).
// Currently stops right after mod::print_info(); the later sections are
// disabled by the "exit(0); ////" markers and are enabled by hand.
int
main(int argc, char **argv)
{
    cout << __FILE__ << "  .start. " << endl;

//    for (ulong i=0; i<155; ++i)
//    {
////        if ( is_pow_of_2(i) )  cout << "is_pow_of_2(" << i << ")" << endl;
////        if ( one_bit_q(i) )  cout << "one_bit_q(" << i << ")" << endl;
//        cout << i << ": lo=" << lowest_bit(i) << "  hi=" << highest_bit(i) << endl;
//    }
//    exit(0); ////

    umod_t m = mod::modulus;
    if ( argc>1 )
    {
        m = strtol64(argv[1]);
        mod::reinit(m);
    }

//    do_mont();  exit(0); ////


    mod::print_info();


    exit(0); ////

    umod_t a = 2;
    if ( argc>2 )  a = strtol64(argv[2]);

    rmpass(m,a);  exit(0); ////


    // Write m as 2^ldm2 - u + 1:
    umod_t ldm2, m2, u;
    ldm2 = ld(m);
    if ( (m-1)!=(1ULL<<ldm2) )  ldm2++;
    // For m > 2^63 ldm2 can become 64, and 1ULL<<64 is undefined behavior.
    // 2^64 wraps to 0 in (64-bit) umod_t, and u below is then still correct
    // modulo 2^64.
    m2 = ( ldm2<64 ? (1ULL<<ldm2) : 0 );
    u = m2 - m + 1;

    factorization fact(u);
    cout << "m = 2^" << ldm2 << "-" << fact << "+1" << endl;
    assert( m == (m2 - u + 1) );

//    return 0;

    // For k = 2^j < m, print the order of k written as 2^p * (odd part).
    // Stop when k wraps to 0: for m > 2^63, doubling 2^63 overflows, and
    // "k<m" alone would then loop forever.
    for (umod_t k=2, j=1;  (k!=0) && (k<m);  k*=2, ++j)
    {
        umod_t r = order(k);
        umod_t rr = r;
        int p = 0;
        while ( !(rr&1) && rr )
        {
            p++;
            rr >>= 1;
        }

        umod_t f = r/((umod_t)1<<p);
        cout << "k=2^" << j;
        cout << "  -->  = 2^" << p;
        cout << " * " << f;
        if ( f==1 )  cout << " **** ";
        cout << endl;
    }

    cout << ".done." << endl;
    return 0;
}
//-------------------------------------------

// Cross-check in PARI/GP (gp):
// j = 2^62-2^46+1
// factor(znorder(Mod(4,j)))
//

