
#include <iostream.h>
#include "jjassert.h"

#include "modm.h"


umod_t
gcd(umod_t a, umod_t b)
//
// Return the greatest common divisor of a and b,
// computed with the Euclidean algorithm (repeated remainders).
// gcd(a,0)==a and gcd(0,b)==b, so gcd(0,0)==0.
//
{
    // Replace (a,b) by (b, a mod b) until the second entry vanishes;
    // the first entry then holds the result.
    while ( b!=0 )
    {
        const umod_t rem = a % b;
        a = b;
        b = rem;
    }
    return a;
}
// ============= end ===========


smod_t
binary_gcd(smod_t a, smod_t b)
//
// Return the greatest common divisor of a and b using the
// binary (shift and subtract) gcd algorithm, after one ordinary
// remainder step that brings both operands to similar size.
//
// Negative arguments are replaced by their absolute value, so the
// result is never negative.  The most negative smod_t value has no
// absolute value in smod_t and must not be passed.
// binary_gcd(a,0)==|a|, binary_gcd(0,b)==|b|.
//
{
    // The shift/subtract loop below only works for non-negative values.
    if ( a < 0 )  a = -a;
    if ( b < 0 )  b = -b;

    // Order the operands so that a >= b.
    if ( a < b )
    {
        const smod_t t = a;
        a = b;
        b = t;
    }

    if ( b==0 )  return a;

    // One division step: cheap, and it shrinks a large operand quickly.
    const smod_t r = a % b;
    a = b;
    b = r;

    if ( b==0 )  return a;

    // Here a>0 and b>0.
    // Remove the power of two common to both operands and remember it.
    int k = 0;
    while ( !((a|b)&1) )
    {
        ++k;
        a >>= 1;
        b >>= 1;
    }

    // Remaining factors of two belong to only one operand
    // and therefore cannot be part of the gcd.
    while ( !(a&1) )  a >>= 1;
    while ( !(b&1) )  b >>= 1;

    // Both operands are odd and positive now.  The difference of two
    // odd numbers is even, so subtract the smaller from the larger and
    // strip the factors of two again.  Only positive values are shifted.
    while ( a!=b )
    {
        if ( a > b )
        {
            a -= b;
            do  { a >>= 1; }  while ( !(a&1) );
        }
        else
        {
            b -= a;
            do  { b >>= 1; }  while ( !(b&1) );
        }
    }

    // Restore the common power of two.  Shift a itself (type smod_t):
    // shifting the int constant 1 would overflow once k reaches the
    // width of int.  The result divides both inputs, so it fits.
    return  a << k;
}
// ============= end ===========


umod_t
lcm(umod_t a, umod_t b)
//
// Return the least common multiple of a and b.
// If either argument is zero the result is zero.
//
// The division is done before the multiplication so that the
// intermediate value never exceeds the result.  No check is made
// whether the result itself fits into umod_t.
//
{
    // Also covers a==b==0, where gcd(a,b)==0 would give a division by zero.
    if ( a==0 || b==0 )  return 0;

    return  a / gcd(a,b) * b;
}
// ============= end ===========



// EGCD_CHECK: compile-time switch for egcd().
// When set to 1, egcd() additionally asserts its invariants in every
// loop iteration and compares its result against gcd(u,v).
#define  EGCD_CHECK  0  // 0 (no check, default) or 1 (debug)
#if  ( EGCD_CHECK==1 )
#else
// Optional reminder (disabled) that the checks are switched off:
//#warning  "FYI: egcd is not checked "
#endif // EGCD_CHECK

// EGCD_PRINT: compile-time switch for egcd().
// When set to 1, egcd() prints its arguments and intermediate values to cout.
#define  EGCD_PRINT  0  // 0 or 1 (for debug)


smod_t
egcd(smod_t u, smod_t v, smod_t &u1, smod_t &u2)
//
// Extended Euclidean algorithm.
//
// return u3 and set u1,u2 so that
//   u3 == u*u1+v*u2
//
// u3 is the last nonzero remainder of the remainder sequence
// that starts with u, v.  If v==0 the loop is skipped and the
// function returns u with u1=1, u2=0.
//
// u1 and u2 are output-only: their values on entry are overwritten.
//
// The function works on two triples (u1,u2,u3) and (v1,v2,v3);
// the checks enabled by EGCD_CHECK assert that after every step
//   u*u1+v*u2==u3   and   u*v1+v*v2==v3
//
// NOTE(review): all products are computed in smod_t without any
// overflow guard; confirm the intended range of u and v.
//
// cf. knuth2, p.325
{
    // Second triple starts as (v1,v2,v3) = (0,1,v); v2 is set just before the loop.
    smod_t v1 = 0;
    smod_t v3 = v;

    // First triple starts as (u1,u2,u3) = (1,0,u).
    u1 = 1;
    u2 = 0;
    smod_t u3 = u;

#if  ( EGCD_PRINT==1 )
    cout << "\n egcd():  u= " << u;
    cout << "\n egcd():  v= " << v << endl;
#endif

    smod_t v2 = 1;
    while ( v3!=0 )
    {
        // One Euclidean step with quotient q, applied to each component:
        //   (u_i, v_i)  :=  (v_i, u_i - v_i*q)
        smod_t q = u3/v3;

        smod_t t1 = u1-v1*q;
        u1 = v1;
        v1 = t1;

        smod_t t3 = u3-v3*q; // u3%v3
        u3 = v3;
        v3 = t3;

#if  ( EGCD_PRINT==1 )
        // Debug output after the first and third components were updated
        // but before the second: u2 as printed here is derived from u1 and u3.
        cout << "\n  q= " << q;
        cout << "\n  u1= " << u1;
        cout << "\n  u*u1= " << u*u1;
        cout << "\n  u3= " << u3 << endl;
        cout << "\n  u3-u*u1= " << u3-u*u1;
        cout << "\n  u2:=(u3-u*u1)/v= " << (u3-u*u1)/v << endl;
//        jjassert( u*u1+v*((u3-u*u1)/v)==u3 );
#endif

        smod_t t2 = u2-v2*q;
        u2 = v2;
        v2 = t2;

#if  ( EGCD_PRINT==1 )
        cout << "\n  u2= " << u2;
        cout << "\n  v*u2= " << v*u2;
#endif

#if  ( EGCD_CHECK==1 )
        // Both triples must still be linear combinations of u and v
        // (t1,t2,t3 equal v1,v2,v3 at this point).
        jjassert( u*t1+v*t2==t3 );
        jjassert( u*u1+v*u2==u3 );
        jjassert( u*v1+v*v2==v3 );
#endif // ( EGCD_CHECK )
    }


#if  ( EGCD_CHECK==1 )
    // Cross-check the result against the plain gcd() routine.
    jjassert( (umod_t)u3==gcd(u,v) );
// Abandoned alternative: derive u2 from u1 and u3 instead of tracking it.
//#else
//    jjassert( v == gcd(u3-u*u1,v) );
//    u2 = (u3-u*u1)/v;  // apparently does not work if u and v >= 2**63, why ?
#endif // ( EGCD_CHECK )

    // This check sits outside the EGCD_CHECK section, so it is compiled
    // regardless of that switch.  jjassert is declared in jjassert.h
    // (not visible here); presumably it reports a failed condition.
    jjassert( u*u1+v*u2==u3 );

    return u3;
}
// ============= end ===========

