
#include "modarith.h"

#include "jjassert.h"
#include <iostream.h>


#define  MUL_PRINT_LIKE_MAD  0  // 0 or 1

//inline
umod_t
mul_mod_mod_63bit(const umod_t &a, const umod_t &b)
//
// Return (a*b) % MODULUS without forming the full double-width product.
//
// Method:
//   1. Estimate the quotient qt ~= (a*b)/MODULUS in ldouble, using the
//      precomputed reciprocal M1DD, and round to nearest (the +1/2).
//   2. Form the low 64 bits of a*b and of qt*MODULUS with plain umod_t
//      multiplication. Wrap-around is intended here.
//   3. The true value a*b - qt*MODULUS is small, so its low 64 bits,
//      read as a signed smod_t, give it exactly.
//   4. If the estimate was one too large (negative difference) or one
//      too small (difference >= MODULUS), correct it by one step.
//
// Preconditions (checked only when MOD_PARANOIA is set):
//   a < MODULUS and b < MODULUS. With USE_LEQ_63BIT_MODULUS, a and b
//   must also be non-negative when viewed as smod_t.
//
// NOTE(review): this presumes MODULUS < 2^63 and M1DD ~= 1/MODULUS.
//   Confirm against modarith.h.
//
// Only the remainder is returned. The corrected quotient qt stays local
// and is used solely by the paranoia checks.
//
{
#if  ( MUL_PRINT_LIKE_MAD )
    cout << "\n MUL_MOD_MOD_63BIT(" << a << ", " << b << ")\n";
#endif


#if  ( MOD_PARANOIA )
    jjassert( a<MODULUS );
    jjassert( b<MODULUS );
#if  ( USE_LEQ_63BIT_MODULUS )
    jjassert( (smod_t)a>=0 );
    jjassert( (smod_t)b>=0 );
#endif
#endif


#if  ( MUL_PRINT_LIKE_MAD )
    ldouble ad=a, bd=b, &md=M1DD;
    umod_t m = MODULUS;
    cout.precision(21);
    cout << " a=" << "0x" << hex << a << dec << " =" << a << " =" << ad << endl;
    cout << " b=" << "0x" << hex << b << dec << " =" << b << " =" << bd << endl;
    cout << " m=" << "0x" << hex << m << dec << " =" << m << " 1/m=" << md << endl;
#endif


    // High part: floating-point quotient estimate. Rounding to nearest
    // (rather than truncating) centres the integer difference below
    // around zero.
    umod_t qt = (umod_t)((ldouble)a*(ldouble)b*(M1DD)+(ldouble)1/2 );
#if  ( MUL_PRINT_LIKE_MAD )
    cout << " qt=" << "0x" << hex << qt << dec << " =" << qt << endl;
#endif

#if  ( USE_LEQ_63BIT_MODULUS )
    jjassert( (smod_t)qt>=0 );
#endif


    // Low part: exact products modulo 2^64.
    umod_t ab = (umod_t)a * (umod_t)b;
    umod_t mq = (umod_t)MODULUS * (umod_t)qt;
#if  ( MUL_PRINT_LIKE_MAD )
    cout << " ab=" << "0x" << hex << ab << dec << " =" << ab << endl;
    cout << " mq=" << "0x" << hex << mq << dec << " =" << mq << endl;
#endif

    // The difference must be taken as a signed value. Comparing ab and mq
    // as unsigned numbers does not work: both have wrapped modulo 2^64,
    // so their order says nothing about the sign of the true difference.
    smod_t rem = (smod_t)(ab-mq);
    if ( rem<0 )
    {
        rem += MODULUS;  --qt;   // estimate was one too large
    }
    else if ( (umod_t)rem>=MODULUS )
    {
        rem -= MODULUS;  ++qt;   // estimate was one too small (rounding error)
    }
    assert ( rem>=0 );

#if  ( MUL_PRINT_LIKE_MAD )
    cout << " rem=" << "0x" << hex << rem << dec << " =" << rem << endl;
#endif

#if  ( MOD_PARANOIA )
    jjassert( (umod_t)qt<MODULUS );
    jjassert( (umod_t)rem<MODULUS );
#endif

    return  (umod_t)rem;
}
// ============ end MUL_MOD_MOD_63BIT ===========


//inline
umod_t
mul_mod_mod_64bit(umod_t a, umod_t b)
//
// Return (a*b) % MODULUS for operands that may use all 64 bits. The work
// is reduced to one call of mul_mod_mod_63bit() on the halved operands.
//
// Write a = 2*ah + e and b = 2*bh + f with e, f in {0,1}. Then
//   a*b = 4*ah*bh + 2*f*ah + 2*e*bh + e*f
// The low bit of a pairs with the high half of b, and vice versa.
// ah and bh fit in 63 bits.
//
// Presumably requires a, b < MODULUS, so that ah, bh < MODULUS as
// mul_mod_mod_63bit() asserts. TODO confirm with callers.
//
// NOTE(review): the cross terms are paired as in the identity above.
// The previous code added 2*e*ah and 2*f*bh, which is wrong whenever
// exactly one operand is odd. Whether mul_mod_mod_63bit() itself is
// exact for moduli >= 2^63 is not established here, so verify before
// relying on this routine for full 64-bit moduli.
//
{
    int     e  = (int)(a & 1);
    umod_t  ah = a >> 1;        // 63 bit

    int     f  = (int)(b & 1);
    umod_t  bh = b >> 1;        // 63 bit

    umod_t  r = mul_mod_mod_63bit(ah, bh);  // ah*bh
    r = add_mod_mod(r, r);                  // 2*ah*bh
    r = add_mod_mod(r, r);                  // 4*ah*bh

    if ( e )   // low bit of a times high half of b
    {
        umod_t  t = add_mod_mod(bh, bh);    // 2*bh
        r = add_mod_mod(r, t);              // + 2*e*bh
    }

    if ( f )   // low bit of b times high half of a
    {
        umod_t  t = add_mod_mod(ah, ah);    // 2*ah
        r = add_mod_mod(r, t);              // + 2*f*ah
    }

    if ( e && f )  r = incr_mod_mod(r);     // + e*f

    return r;
}
// ============ end MUL_MOD_MOD_64BIT ===========


    /*
    const umod_t axxx = (umod_t)617673396283947;
    const umod_t bxxx = (umod_t)1853020188851841;
    const umod_t mxxx = (umod_t)922323129936642048;
    const umod_t cok  = (umod_t)301591476059826603;
    umod_t cxxx = mul_mod_mod_63bit(axxx,bxxx);

    cout << " a=" << "0x" << hex << axxx << dec << " =" << axxx << endl;
    cout << " b=" << "0x" << hex << bxxx << dec << " =" << bxxx << endl;
    cout << " m=" << "0x" << hex << mxxx << dec << " =" << mxxx << endl;
    cout << " c=" << "0x" << hex << cxxx << dec << " =" << cxxx << endl;
    cout << " d=" << "0x" << hex << cok  << dec << " =" << cok  << endl;
    jjassert( cxxx==cok );
    */


uint
nines(umod_t a, uint bits, uint mask)
//
// Generalised "casting out nines": return a mod (2^bits - 1).
//
// Since 2^bits == 1 modulo (2^bits - 1), every base-2^bits digit of a
// counts with weight 1. The residue is therefore the digit sum, reduced
// with end-around carry (ones'-complement addition). The previous
// `ai &= mask` reduced modulo 2^bits instead. That is not multiplicative,
// so nines_test() could fail on correct products.
//
// Preconditions:
//   mask == 2^bits - 1, as nines_test() sets it up.
//   0 < bits < 8*sizeof(uint), so the running sum (<= 2*mask) fits in a
//   uint and the loop terminates.
//
// Returns a canonical residue in [0, mask-1], so residues of different
// numbers can be compared with ==.
//
{
    uint  ai = 0;

    do
    {
        ai += ((uint)a & mask);            // add next digit: ai <= 2*mask
        ai  = (ai & mask) + (ai >> bits);  // end-around carry: ai <= mask
        a >>= bits;
    }
    while ( a );

    if ( ai==mask )  ai = 0;  // mask is congruent to 0 modulo mask

    return ai;
}
// --------------

void
nines_test(umod_t a, umod_t b, umod_t c)
//
// Casting-out check with base-2^8 digits.
// Compute the residues of a, b and c with nines(), print them in hex,
// and jjassert() that nines(c) equals nines(nines(a)*nines(b)).
//
{
    cout << "\n nines_test(): \n" ;

    // Keep bits < 8*sizeof(uint)/2 so the product of two residues fits in a uint.
    const uint  digit_bits = 8;
    const uint  digit_mask = (uint)(~0) >> (8*sizeof(uint) - digit_bits);  // 2^bits - 1
    cout << " mask=" << "0x" << hex << digit_mask << dec << endl;

    const uint  ra = nines(a, digit_bits, digit_mask);
    const uint  rb = nines(b, digit_bits, digit_mask);
    const uint  rc = nines(c, digit_bits, digit_mask);
    const uint  rp = nines((umod_t)(ra*rb), digit_bits, digit_mask);

    const char * const label[4] = { " ai=", " bi=", " ci=", " di=" };
    const uint         value[4] = { ra, rb, rc, rp };
    for (int k = 0; k < 4; ++k)
    {
        cout << label[k] << "0x" << hex << value[k] << dec << endl;
    }

    jjassert( rc==rp );
}
// --------------

