
#include "mtypes.h"
#include "modm.h"
#include "jjassert.h"

//
// ARITHMETIC modulo some arbitrary modulus
// (cf. modarith.h)
//


// Bring x below the modulus m.
// A value that is already smaller than m is handed back untouched;
// otherwise the remainder x % m is returned.
// Note: m==0 leads to a division by zero, the caller must avoid it.
umod_t
set_mod(umod_t x, umod_t m)
{
    return  ( x<m ) ? x : ( x % m );
}
// ============ end ===========


// Return a+1 modulo m: the successor of a, wrapping to zero
// exactly when a+1 equals m.
// No other reduction is done (presumably callers pass a < m -- confirm).
umod_t
incr_mod(umod_t a, umod_t m)
{
    const umod_t next = a + 1;
    return  ( next==m ) ? 0 : next;
}
// ============ end ===========


// Return a-1 modulo m: the predecessor of a, where zero wraps
// around to m-1.
umod_t
decr_mod(umod_t a, umod_t m)
{
    return  ( a!=0 ) ? ( a-1 ) : ( m-1 );
}
// ============ end ===========



// Return a-b modulo m.
// The operands are not reduced here (presumably callers
// guarantee a<m and b<m -- confirm).
umod_t
sub_mod(umod_t a, umod_t b, umod_t m)
{
    // a-b would drop below zero: go via m-b so the result lands in 1..m-1
    if ( a<b )  return  m-b+a;

    return  a-b;
}
// ============ end ===========



// Return a+b modulo m.
// The sum is computed as a - (m-b), i.e. by subtracting the negative
// of b modulo m.  With a<m and 0<b<m every value returned is below m
// (the operands themselves are not checked or reduced here).
umod_t
add_mod(umod_t a, umod_t b, umod_t m)
{
    if ( b==0 )  return a;  // m-b would be m itself, so handle zero separately

    const umod_t neg_b = m-b;  // -b modulo m
    return  ( a<neg_b ) ? ( m-neg_b+a ) : ( a-neg_b );
}
// ============ end ===========


#if  ( USE_64BIT_MODULUS )
#error  "no 64bit multiply yet"
// jjnote: todo: 64bit mod multiply
#else

//  Peter Montgomery: If 0 <= a, b < p < 2^31 and I want a modular product
//  a*b modulo p and the long long type is unavailable, then I can write
//
//  typedef   signed long slong;
//  typedef unsigned long ulong;
//  slong a, b, p, quot, rem;
//
//  quot = (slong) (0.5 + (double)a * (double)b / (double)p);
//  rem =  (slong)((ulong)a * (ulong)b - (ulong)p * (ulong)quot);
//  if (rem < 0) {rem += p; quot--;}
// Return a*b modulo m, using a floating point estimate of the quotient
// (the method quoted from Peter Montgomery in the comment preceding
// this function).
//
// The integer product a*b and the product m*quot are both computed
// in umod_t arithmetic; only their difference is needed, and that
// difference is small.  quot is a*b/m rounded to nearest via the
// added 1/2, so it may be one too large: the difference then shows up
// as a negative number when viewed as smod_t and is fixed by adding m.
//
// NOTE(review): correctness relies on ldouble having enough mantissa
// bits for the quotient to be off by at most one, and on m being small
// enough that the sign test on smod_t is meaningful -- confirm against
// the typedefs in mtypes.h.
umod_t
mul_mod(umod_t a, umod_t b, umod_t m)
{
    const umod_t prod = a * b;  // low part of the true product

    // rounded quotient estimate; expression kept exactly as is
    // because its rounding behavior is what the method depends on
    const umod_t quot = (umod_t)((ldouble)a*(ldouble)b/m+(ldouble)1/2);

    umod_t rem = prod - m * quot;
    if ( (smod_t)rem < 0 )  rem += m;  // quotient was rounded up: correct by one modulus

    return  rem;
}
// ============= end ===========
#endif  // USE_64BIT_MODULUS


// Return a**ex modulo m by binary exponentiation:
// the bits of ex are scanned from the least significant one upwards,
// 'sq' runs through a, a**2, a**4, ... (each reduced by mul_mod())
// and is multiplied into the result whenever the current bit is set.
//
// For ex==0 the value 1 is returned without looking at m.
// The final (unneeded) squaring is skipped.
umod_t
pow_mod(umod_t a, umod_t ex, umod_t m)
{
    umod_t result = 1;
    umod_t sq = a;

    while ( ex!=0 )
    {
        if ( ex&1 )  result = mul_mod(result,sq,m);

        ex /= 2;

        // square only if there are bits left to process
        if ( ex!=0 )  sq = mul_mod(sq,sq,m);
    }

    return result;
}
// ============= end ===========


umod_t
inv_modp(umod_t a, umod_t p)
//
// Return the inverse of a modulo p, p must be prime.
// Computed as the power a**(p-2) modulo p
// (by Fermat's little theorem a**(p-1)==1 for a not divisible by p).
//
{
    const umod_t ex = p-2;
    return pow_mod(a,ex,p);
}
// ============= end ===========

umod_t
inv_modpp(umod_t a, umod_t p, long ex)
//
// Return the inverse of a modulo p**ex, p must be prime.
// Computed as the power a**(phi-1) modulo p**ex
// where phi = (p-1) * p**(ex-1) is the order of the group
// of units modulo p**ex.
// Powers of p are obtained from ipow().
//
{
    umod_t phi = p-1;
    if ( ex>1 )  phi = phi * ipow(p,ex-1);

    return pow_mod(a,phi-1,ipow(p,ex));
}
// ============= end ===========

/*
umod_t
inv_modpp_2(umod_t a, umod_t p, umod_t pp)
{
    return pow_mod(a,pp/p*(p-1)-1,pp);
}
// ============= end ===========
*/


#if  ( USE_64BIT_MODULUS )
#error " computation of mod inverse by egcd() doesn't work for moduli >63 bits "
#else
// Return the inverse of x modulo m, computed via the
// extended gcd:  egcd() delivers d and coefficients with
// d == m*cm + x*cx,  so for d==1 the coefficient cx is the inverse.
//
// Fails (via jjassert2) if gcd(x,m)!=1, and also if the
// result does not pass the check  x*inverse == 1 (mod m).
umod_t
inv_mod(umod_t x, umod_t m)
{
    smod_t cm, cx;  // coefficients of m and x, set by egcd()
    smod_t d = egcd(m,x,cm,cx);
    jjassert2( d==1, "impossible inversion: gcd(x,m)!=1" );

    // the coefficient may come out negative: shift it up by one modulus
    if ( cx<0 )  cx += m;

    const umod_t inv = (umod_t)cx;

    // sanity check of the result
    if ( mul_mod(x,inv,m)!=1 )
    {
        jjassert2( 0,"inv_mod() failed" );
    }

    return inv;
}
// ============= end ===========
#endif // USE_64BIT_MODULUS

/*
umod_t
div_mod(umod_t a, umod_t b, umod_t m)
{
    return mul_mod(a,inv_mod(b,m),m);
}
// ============= end ===========
*/
