
#include <assert.h>
#include <iostream.h>

#include "mod.h"

#define BITS (8*sizeof(umod_t))

#define M_ASSERT(x) { assert( x<(mod::modulus) ); }


umod_t
montgomery_r(ulong bits, umod_t m)
//
// 2^bits mod modulus
//
{
    if ( bits==BITS )  return (umod_t) 0-m;

    umod_t r=(umod_t)1;

    for(ulong i=0; i<bits; ++i)
    {
	r *= 2;

	if( r>=m )  r-=m;
    }

    return r;
}
//------------------------


umod_t
montgomery_r_prime(ulong bits, umod_t modulus)
//
// (2^bits)^-1 mod modulus
//
{
    umod_t r = montgomery_r(bits,modulus);

    if( r>=(mod::modulus) )  r -= (mod::modulus);

    mod rm=inv((umod_t)r);

    return (umod_t)(rm.get_x());
}
//------------------------


umod_t
montgomery_m_prime(ulong bits, umod_t m, umod_t mask)
//
//    -(modulus)^-1           mod (2^bits)
// == -(modulus)^phi(2^bits)  mod (2^bits)
// == -(modulus)^(2^(bits-1)) mod (2^bits)
// == (2^bits)-(modulus)^(2^(bits-1)) mod (2^bits)
//
{  // XXX fails !

    umod_t z = m;

    for(ulong i=0; i<bits-1; ++i)   // compute m^(2^(bits-1)) mod (2^bits)
    {
	z *= z;   
	z = ( z & mask); // mod (2^bits)
	cout<<"fds! z="<<z<<"  mask="<<mask<<endl;
    }

    z = mask+1-z;

    cout<<"fds! z="<<z<<"  m="<<m<<endl;
    cout<<"fds! -z*m="<<z<<"  m="<<m<<endl;

    assert( (umod_t)1==(m-z)*z );

    return z;
}
//------------------------

/*

umod_t 
to_mg(mod_t x, ulong bits, umod_t modulus)
//
// return (x*2^bits) mod modulus
// == x*R mod modulus
//
{
    M_ASSERT(x);

    umod_t y=(umod_t)x;
    umod_t m=modulus;

    for(ulong i=0; i<bits; ++i)
    {
	y >>= 1;

	if( y>=m )  y-=m;
    }

    return y;
}
//-------------------------


umod_t 
from_mg(mod_t x, umod_t modulus)
//
// return (x*((2^mg_bits)^-1)) mod (modulus)
// == x*R' mod (modulus)
//
{
    uint64 rp=(uint64)(mod::mg_r_prime);

    uint64 t64 = (rp*(uint64)x);

    t64 %= (uint64)modulus;
    
    return (mod_t)t64;
}
//-------------------------


umod_t
redc(umod_t x)
{
    umod_t s = x*(mod::mg_m_prime);

    uint64 t64 = ((uint64)x+(uint64)s*(uint64)(mod::modulus));

    assert( (t64 & (mod::mg_mask))==0 );

    umod_t t = (t64>>(mod::mg_bits));

    if( t<(mod::modulus) )  return t;
    else                    return (mod::modulus)-t;
}
//------------------------


umod_t
mg_mult(umod_t x, umod_t y)
{
    return redc( x*y );
}
//------------------------

*/
