
#include <iostream.h>
#include <assert.h>

#include "mod.h"
#include "modarith.h"

// debug:
#define MOD_ASSERT(z)    //assert( !((z).check()) )


//---------------------------

// Default constructor.  Deliberately does no work: the residue x is not
// assigned here and must be set before the object is used.
mod::mod() {}

#ifdef USE_64BIT_MOD_T
// Conversions from the native long / unsigned long types (64-bit builds
// only).  The constructors forward to the matching assignment operator.
// The assignments re-dispatch to the smod_t / umod_t assignments, which
// do the actual reduction.

mod::mod(const ulong i)
{
    operator = (i);
}

mod::mod(const long i)
{
    operator = (i);
}

mod &
mod::operator = (const long i)
{
    const smod_t si = static_cast<smod_t>(i);
    return operator = (si);
}

mod &
mod::operator = (const ulong i)
{
    const umod_t ui = static_cast<umod_t>(i);
    return operator = (ui);
}

//void 
//mod::set_x(const ulong z)
//{
//    (*this).set_x((umod_t)z);  // without check !
//}

#endif // USE_64BIT_MOD_T


void 
mod::set_x(const umod_t z)
{
    (*this).x = (mod_t)z;  // without check !
}


// Construct from an unsigned integer; the value is reduced by
// u_mod_mod_while(), the same helper used by operator=(umod_t).
mod::mod(const umod_t i)
    : x(u_mod_mod_while(i))
{
}

// Construct from a signed integer; the value is reduced by
// s_mod_mod_while(), the same helper used by operator=(smod_t).
mod::mod(const smod_t i)
    : x(s_mod_mod_while(i))
{
}

// Copy constructor: copies the stored residue verbatim.
mod::mod(const mod &m)
    : x(m.x)
{
}

// Destructor: performs no cleanup.
mod::~mod() {}


// Return the stored residue by value.
mod_t
mod::get_x() const
{
    return x;
}

// Copy assignment: copies the stored residue.  Self-assignment is harmless
// because the value is simply written back onto itself.
mod &
mod::operator = (const mod &h)
{
    this->x = h.x;
    return (*this);
}

// Assign from a signed integer, reducing it with s_mod_mod_while().
mod &
mod::operator = (const smod_t i)
{
    this->x = s_mod_mod_while(i);
    return (*this);
}

// Assign from an unsigned integer, reducing it with u_mod_mod_while().
mod &
mod::operator = (const umod_t i)
{
    this->x = u_mod_mod_while(i);
    return (*this);
}


// Unary plus: returns an unchanged copy of h.
mod
operator + (const mod &h)
{
    return mod(h);
}

// Unary minus: returns the additive inverse of h, i.e. the residue r with
// h + r == 0 (mod modulus).
// Zero is its own negative.  Computing `modulus - 0` would store
// `modulus` itself, which is outside the range accepted by check()
// (x >= modulus is flagged).  It would also make -mod(0) compare unequal
// to mod(0), because operator== compares the raw residues.
mod
operator - (const mod &h)
{
    mod n;
    if ( h.x==0 )  n.x = 0;
    else           n.x = mod::modulus - h.x;
    return n;
}


// Add h into *this using the modular-addition helper that matches the
// configured signedness of mod_t.
// NOTE(review): assumes the helper returns a value already reduced into
// the range that check() accepts.
mod &
mod::operator += (const mod &h)
{
#ifndef USE_UNSIGNED_MOD_T
    x = s_add_mod_m(x, h.x, modulus);
#else
    x = u_add_mod_m(x, h.x, modulus);
#endif
    MOD_ASSERT(*this);
    return *this;
}

// Subtract h from *this using the modular-subtraction helper that matches
// the configured signedness of mod_t.
// NOTE(review): assumes the helper returns a value already reduced into
// the range that check() accepts.
mod &
mod::operator -= (const mod &h)
{
#ifndef USE_UNSIGNED_MOD_T
    x = s_sub_mod_m(x, h.x, modulus);
#else
    x = u_sub_mod_m(x, h.x, modulus);
#endif
    MOD_ASSERT(*this);
    return *this;
}

// Subtract a plain mod_t.  i is first turned into a mod through the
// constructor overload matching mod_t, then subtracted as a mod.
mod &
mod::operator -= (const mod_t i)
{
    const mod rhs(i);
    return (*this) -= rhs;
}

// Multiply *this by h modulo `modulus`.
// Aliasing-safe: both operands are read before x is written, which
// mod::sqr() relies on when it passes *this as h.
mod &
mod::operator *= (const mod &h)
{
#ifdef USE_64BIT_MOD_T
    // The 64-bit path uses the signed helper s_mul_mod_mod(), so an
    // unsigned mod_t is rejected at compile time.  (The diagnostic
    // previously said "unsigned", contradicting the condition it guards.)
#ifdef USE_UNSIGNED_MOD_T
#error "must use signed mod_t for 64 bit !"
#endif
     x = s_mul_mod_mod(x,h.x);
#else
     // 32-bit path: the unsigned helper is used whatever the signedness of mod_t.
     x = u_mul_mod_mod32(x,h.x);
#endif

    MOD_ASSERT(*this);
    return *this;
}


// Square in place (*this = *this * *this) and return *this.
mod &
mod::sqr()
{
    operator *= (*this);
    return *this;
}

// Validate the stored residue.
// Returns -1 if x is negative (only tested when mod_t is signed), 1 if
// x >= modulus, and 0 if x is in range.
int
mod::check() const
{
#ifndef USE_UNSIGNED_MOD_T
    if ( x < 0 )  return -1;
#endif
    return ( x >= modulus ) ? 1 : 0;
}

// Equality: compares the stored residues as-is (no reduction here).
int
operator == (const mod &h1, const mod &h2)
{
    return h1.get_x() == h2.get_x();
}

// Inequality: the logical negation of operator==.
int
operator != (const mod &h1, const mod &h2)
{
    return !(h1 == h2);
}

// i + h: convert i to a mod, then add h to that temporary.
mod
operator + (const mod_t i, const mod &h)
{
    return mod(i) += h;
}

// h + i: addition commutes, so forward to the (mod_t, mod) overload,
// which performs exactly the same steps.
mod
operator + (const mod &h, const mod_t i)
{
    return i + h;
}

// h1 + h2, computed on a copy of h1.
mod
operator + (const mod &h1, const mod &h2)
{
    return mod(h1) += h2;
}

// h1 - h2, computed on a copy of h1.
mod
operator - (const mod &h1, const mod &h2)
{
    return mod(h1) -= h2;
}

// i - h: convert i to a mod, then subtract h from that temporary.
mod
operator - (const mod_t i, const mod &h)
{
    return mod(i) -= h;
}

// h - i, computed on a copy of h via operator-=(mod_t).
mod
operator - (const mod &h, const mod_t i)
{
    return mod(h) -= i;
}

// h1 * h2, computed on a copy of h1.
mod
operator * (const mod &h1, const mod &h2)
{
    return mod(h1) *= h2;
}

// h * i: multiplication commutes, so forward to the (mod_t, mod) overload,
// which performs exactly the same steps.
mod
operator * (const mod &h, const mod_t i)
{
    return i * h;
}

// i * h: convert i to a mod, then multiply that temporary by h.
mod
operator * (const mod_t i, const mod &h)
{
    return mod(i) *= h;
}


// Read a residue from `is` straight into h.x.
// NOTE(review): the value is stored as read, with no reduction or
// check(), so input outside [0, modulus) yields an invalid mod.
istream &
operator >> (istream& is, mod& h)
{
    return is >> h.x;
}

// Write the stored residue of h to `os`.
ostream &
operator << (ostream& os, const mod& h)
{
    return os << h.x;
}
