
#include "mod.h"
#include "modm.h"
#include "jjassert.h"
#include "factor.h"

/*
int
is_quadratic_residue(const factorization &af, umod_t m)
//
// whether af is quadratic residue mod m 
//
// use (a*b/c) == (a/c) * (b/c)
//
// UNTESTED
{
    int k = 1;

    for (int i=0; i<af.npr; ++i)
    {
        int s = kronecker(af.prime[i],m);
        if ( 1&(af.expon[i]) )  s = -s;
        k *= s;
    }

    return k;
}
//================= end ================
*/

int
is_quadratic_residue(umod_t a, const factorization &mf)
//
// Check a against every (prime, exponent) entry of mf.
// Return 1 if all entries pass, else 0.
//
// Entry with prime 2 and exponent e:
//   e==1 : always passes
//   e==2 : passes if a == 1 (mod 4)
//   e>=3 : passes if a == 1 (mod 8)
//   anything else fails
// Entry with any other prime p:
//   passes only if kronecker(a,p)==1 (the exponent is not used)
//
{
    int idx = 0;
    while ( idx<mf.npr )
    {
        umod_t pr = mf.prime[idx];
        long e = mf.expon[idx];
        ++idx;

        if ( 2!=pr )
        {
            // odd prime: decided by the Kronecker symbol alone
            if ( 1!=kronecker(a,pr) )  return 0;
            continue;
        }

        // prime 2: decided by the low bits of a
        int pass = 0;
        if ( e==1 )                         pass = 1;
        else if ( (e>=3) && (1==(a&7)) )    pass = 1;
        else if ( (e==2) && (1==(a&3)) )    pass = 1;

        if ( !pass )  return 0;
    }

    return 1;
}
//================= end ================


umod_t
sqrt_modp(umod_t a, umod_t p)
//
// Square root modulo p, where p is expected to be an odd prime.
// Return 0 if kronecker(a,p) != 1 (a is not a quadratic residue mod p),
// else return x so that x*x==a (mod p).
// cf. cohen p.33
//
{
    if ( 1!=kronecker(a,p) )  return 0;  // not a quadratic residue

    // split p-1:  p == q * 2^t + 1
    umod_t q;
    int t;
    n2qt(p,q,t);

    // search the smallest n in [1,p) with kronecker(n,p)==-1
    // and set z = n^q mod p
    umod_t z = 0;
    umod_t n = 1;
    while ( (n<p) && (-1!=kronecker(n,p)) )  ++n;

    if ( n<p )
    {
        z = pow_mod(n,q,p);
    }
    else
    {
        jjassert2(0, " sqrt_modp(): no generator found ! " );
    }

    // start values:  b = a^q,  x = a^((q+1)/2),  so that a*b == x*x
    umod_t y = z;
    uint r = t;
    umod_t x = pow_mod(a,(q-1)/2,p);
    umod_t b = mul_mod(mul_mod(a,x,p),x,p);
    x = mul_mod(x,a,p);

    for (;;)
    {
        // invariants checked on every pass
        jjassert( mul_mod(a,b,p)==mul_mod(x,x,p) );
        jjassert( r>0 );
        jjassert( pow_mod(y,1ULL<<(r-1),p)==p-1 );
        jjassert( pow_mod(b,1ULL<<(r-1),p)==1ULL );

        if ( 1==b )  return x;  // a*1 == x*x: done

        // smallest m in [1,r) with b^(2^m)==1;  m==r if there is none
        uint m = 1;
        while ( (m<r) && (1!=pow_mod(b,1ULL<<m,p)) )  ++m;

        if ( m==r )
        {
            return 0;  // a is not a quadratic residue mod p
        }

        // lower the exponent r to m and adjust x, b accordingly
        umod_t v = pow_mod(y,1ULL<<(r-m-1),p);
        y = mul_mod(v,v,p);
        r = m;
        x = mul_mod(x,v,p);
        b = mul_mod(b,y,p);
    }
}
//================= end ================


umod_t
sqrt_modpp(umod_t a, umod_t p, long ex)
//
// return r with r^2 == a (mod p^ex)
// return 0 if sqrt_modp() finds no square root of a mod p
//
// For odd p a root mod p is obtained from sqrt_modp() and then
// refined mod p^ex with the Newton iteration  r <- (r + a/r)/2.
// For p==2 only ex==1 is handled: the result is the low bit of a.
// The case p==2, ex>1 is not implemented (assertion, then 0).
//
{
    umod_t r;

    if ( 2==p )
    {
        r = ( a&1 ) ? 1 : 0;
    }
    else
    {
        r = sqrt_modp(a%p,p);
        if ( r==0 )  return 0;  // no sqrt exists
    }
    // here r^2 == a (mod p)

    if ( 1==ex )  return r;

    if ( 2==p )
    {
        jjassert2( 0, " sqrt_modpp(): case 2^n, n>1 not implemented " );
        // do not fall into the iteration below without a valid modulus
        return 0;
    }

    umod_t m = ipow(p,ex);           // modulus p^ex
    umod_t h = inv_modpp(2,p,ex);    // 1/2

    // newton iteration: x doubles on every step until it reaches ex
    for (long x=1; x<ex; x*=2)
    {
        // the inverse must be taken of the current approximation r
        // (not of a): with 1/a the term a/r would always be 1
        umod_t ri = inv_modpp(r,p,ex);    // 1/r
        umod_t ar = mul_mod(a,ri,m);      // a/r
        r = add_mod(r,ar,m);              // r+a/r
        r = mul_mod(r,h,m);               // (r+a/r)/2
    }

    return r;
}
//================= end ================


umod_t
sqrt_modf(umod_t a, const factorization &mf)
//
// Call sqrt_modpp() for every (prime, exponent) entry of mf and
// combine the results with chinese().
// Return 0 as soon as one sqrt_modpp() call returns 0.
//
{
    // precondition check is disabled:
//    jjassert( is_quadratic_residue(a,mf) );

    umod_t roots[mf.npr];  // one result per entry of mf

    int k = 0;
    while ( k<mf.npr )
    {
        roots[k] = sqrt_modpp(a, mf.prime[k], mf.expon[k]);
        if ( 0==roots[k] )  return 0;  // no sqrt exists
        ++k;
    }

    return chinese(roots,mf);
}
//================= end ================

