
#include "mtypes.h"
#include "modm.h"
#include "jjassert.h"

//
// ARITHMETIC modulo some arbitrary modulus
// (cf. modarith.h)
//


umod_t
u_set_mod(umod_t x, const umod_t &m)
{
    if ( x>=m )  x %= m;
    return x;
}
// ============ end U_SET_MOD ===========


umod_t
incr_mod(umod_t a, const umod_t &m)
{
    a++;
    if ( a==m )  a = 0;
    return a;
}
// ============ end INCR_MOD ===========


umod_t
decr_mod(umod_t a, const umod_t &m)
{
    if ( a==0 )  a = m-1;
    else         a--;
    return a;
}
// ============ end DECR_MOD ===========



umod_t
add_mod(const umod_t &a, const umod_t &b, const umod_t &m)
{
    umod_t s = a+b;

#if  ( ! USE_LEQ_63BIT_MODULUS )
#error " add_mod() might fail. "
    // must subtract m if:
    // 1.  (sum<summand1) AND (sum<summand2), i.e. result 'wrapped around'
    if ( (s<a)&&(s<b) )   return s-m;
#endif

    // 2.  sum >= m
    if ( s>=m )   return  s-m;

    return s;
}
// ============ end ADD_MOD_MOD ===========


umod_t
sub_mod_mod(const umod_t &a, const umod_t &b, const umod_t &m)
{
    if ( a>=b )  return  a-b;
    else         return  m-b+a;
}
// ============ end SUB_MOD_MOD ===========


#if  ( ! USE_LEQ_63BIT_MODULUS )
#error  "no 64bit multiply yet"
#endif

umod_t
mul_mod(const umod_t &a, const umod_t &b, const umod_t &m)
{
    ldouble m1dd = (ldouble)1/m;
    umod_t qt = (umod_t)((ldouble)a*(ldouble)b*(m1dd)+(ldouble)1/2 );
    umod_t ab = (umod_t)a * (umod_t)b;
    umod_t mq = (umod_t)m * (umod_t)qt;
    smod_t rem = (smod_t)(ab-mq);
    if ( rem<0 )  { rem += m; qt--; }
    return  (umod_t)rem;
}
// ============= end MUL_MOD ===========


umod_t
pow_mod(umod_t a, umod_t ex, const umod_t &m)
{
//    cout << " pow_mod(): a=" << a << "  ex=" << ex << "   m= " << m << " \n";

    if ( 0==ex )  
    {
        return 1;
    }
    else
    {
        umod_t z = a;
        umod_t y = 1;

        while ( 1 )
        {
            if ( ex&1 )  y = mul_mod(y,z,m);  // y *= z;
            
            ex /= 2;

            if ( 0==ex )  break;

            z = mul_mod(z,z,m);  // z *= z;
        }

        return y;
    }
}
// ============= end POW_MOD ===========


umod_t
inv_modp(const umod_t &a, const umod_t &p)
{
    return pow_mod(a,p-2,p);
}
// ============= end INV_MODP ===========


umod_t
inv_mod(const umod_t &x, const umod_t &m)
{
#if  ( ! USE_LEQ_63BIT_MODULUS )
#error " inv_mod(): egcd_core() will fail. "
#endif

    smod_t u,v;
    smod_t d = egcd(m,x,u,v);    // d==m*u+x*v
    jjassert( d==1 );      // inversion possible <==> gcd(x,m)==1
//    jjassert( d==m*u+x*v );

    if ( v<0 )  v = m+v;

//    cout << " inv_mod(): x=" << x << "  v=" << v << " x*v=" << mul_mod(x,v,m) << "\n";

    if ( mul_mod(x,v,m)!=1 )
    {
	jjassert( 0*(int)"inv_mod() failed" );
    }

    return (umod_t)v;
}
// ============= end INV_MOD ===========


umod_t
div_mod(const umod_t &a, const umod_t &b, const umod_t &m)
{
    return mul_mod(a,inv_mod(b,m),m);
}
// ============= end DIV_MOD ===========
