
#include "mod.h"
#include "modm.h"
#include "jjassert.h"
#include "factor.h"

int
is_quadratic_residue(umod_t a, const factorization &mf)
//
// Whether a is a quadratic residue modulo the number whose
// factorization is mf.
//
// Every prime power in mf is tested separately and a has to pass
// all of them:
//   - 2^1          : every a is accepted
//   - 2^2          : a == 1 (mod 4) is required
//   - 2^x, x>=3    : a == 1 (mod 8) is required
//   - odd prime p  : kronecker(a,p) must be +1
// Note that for 2^x with x>=2 only odd a can pass these tests.
//
// Returns 1 if all tests pass, else 0.
//
{
    for (int i=0; i<mf.npr; ++i)
    {
        const umod_t p = mf.prime[i];
        const long x = mf.exponent[i];

        if ( 2==p )
        {
            // 'continue' (not 'break'): a passing test for the power
            // of two must not skip the tests for the remaining primes.
            if ( x==1 )  continue;
            if ( (x>=3) && (1==(a&7)) )  continue;
            if ( (x==2) && (1==(a&3)) )  continue;

            return 0;  // not a quadratic residue modulo 2^x
        }

        if ( 1!=kronecker(a,p) )
        {
            return 0;  // not a quadratic residue modulo the prime p
        }
    }

    return 1;
}
//================= end IS_QUADRATIC_RESIDUE ================


int
is_quadratic_residue(const factorization &af, umod_t m)
//
// Evaluate the Kronecker symbol (a/m) where a is given by its
// factorization af, using multiplicativity in the upper argument:
//   (a*b/c) == (a/c) * (b/c)
// so that (p^e/m) == (p/m)^e for every prime power p^e in af.
//
// Returns the product of the symbols of all prime powers, i.e. a
// value out of -1, 0, +1 as delivered by kronecker().
// NOTE(review): for composite m a symbol of +1 is necessary but not
// sufficient for a to be a square mod m -- confirm what callers expect.
//
// UNTESTED
{
    int k = 1;

    for (int i=0; i<af.npr; ++i)
    {
        if ( 0==af.exponent[i] )  continue;  // p^0 == 1 contributes +1

        int s = kronecker(af.prime[i],m);

        // (p/m)^e: an odd power keeps the symbol,
        // an even power squares it (+1, or 0 if it was 0)
        if ( 0==(1&(af.exponent[i])) )  s *= s;

        k *= s;
    }

    return k;
}
//================= end IS_QUADRATIC_RESIDUE ================


umod_t
sqrt_modpp(umod_t a, umod_t p, long ex)
//
// Return r with r^2 == a (mod p^ex).
//
// A root modulo p is computed first (directly for p==2, via
// sqrt_modp() for other p) and, if ex>1, refined modulo p^ex with
// the Newton iteration  r <- (r + a/r)/2.
//
// The case p==2 with ex>1 is not implemented: it triggers an
// assertion and, should the assertion not stop the program,
// returns 0.
//
{
    umod_t r;

    if ( 2==p )  r = ( (a&1) ? 1 : 0 );
    else         r = sqrt_modp(a%p,p);
    // here r^2 == a (mod p)

    if ( 1==ex )  return r;

    if ( 2==p )
    {
        // '0 && "text"' keeps the message inside the asserted expression
        // without casting a pointer to int (rejected on LP64 compilers).
        jjassert( 0 && " sqrt_modpp(): case 2^n, n>1 not implemented " );
        return 0;  // never run the iteration with an undefined modulus
    }

    umod_t m = ipow(p,ex);     // the full modulus p^ex
    umod_t h = inv_mod(2,m);   // 1/2 (mod m)

    // newton iteration; x is doubled once per step until it reaches ex
    for (long x=1; x<ex; x*=2)
    {
        r = add_mod(r,div_mod(a,r,m),m);  // r+a/r
        r = mul_mod(r,h,m);               // (r+a/r)/2
    }

    return r;
}
//================= end SQRT_MODPP ================


umod_t
sqrt_modf(umod_t a, const factorization &mf)
//
// Square root of a modulo the number whose factorization is mf.
//
// One root is computed per prime power of mf with sqrt_modpp();
// the roots are then handed to chinese() together with mf and its
// result is returned (presumably a chinese remainder combination --
// see chinese() for the exact contract).
//
// No check is made that a is a quadratic residue.
//
{
    umod_t root[mf.npr];  // root[k] is the root modulo the k-th prime power

    int k = 0;
    while ( k<mf.npr )
    {
        const umod_t pk = mf.prime[k];
        root[k] = sqrt_modpp(a,pk,mf.exponent[k]);
        ++k;
    }

    return chinese(root,mf);
}
//================= end SQRT_MODF ================


umod_t
sqrt_modp(umod_t a, umod_t p)
//
// Square root modulo an odd prime p (Tonelli-Shanks).
// If a is not a quadratic residue mod p return 0,
// else return x so that x*x==a (mod p).
// cf. cohen p.33
//
// NOTE(review): the invariant assertion b^(2^(r-1))==1 at the top of
// the main loop amounts to Euler's criterion on the first pass, so
// with active assertions a non-residue is expected to trip it before
// the 'return 0' path is reached -- confirm the behavior of jjassert.
//
{
//    jjassert( 1==kronecker(a,p) );

    // split off the powers of two:  p-1 == 2^e * q  with q odd
    uint e = 0;
    umod_t q = p-1;
    while ( 0==(q&1) )
    {
        e++;
        q >>= 1;
    }
    jjassert( (p-1)==(q*(1ULL<<e)) );

    // find generator: z = n^q for the first n that is not a
    // quadratic residue mod p
    umod_t z = 0;
    umod_t n;
    for (n=1; n<p; ++n)
    {
        if ( -1==kronecker(n,p) )
        {
            z = pow_mod(n,q,p);
            break;
        }
    }

    if ( !(n<p) )  // loop ran to its end without finding a non-residue
    {
        // '0 && "text"' keeps the message inside the asserted expression
        // without casting a pointer to int (rejected on LP64 compilers).
        jjassert( 0 && " sqrt_modp(): no generator found ! " );
    }

    // initialize:
    umod_t y = z;
    uint r = e;
    umod_t x = pow_mod(a,(q-1)/2,p);  // x == a^((q-1)/2)
    umod_t b = mul_mod(a,x,p);
    b = mul_mod(b,x,p);               // b == a*x^2
    x = mul_mod(x,a,p);               // x == a*x

    for (;;)
    {
        // loop invariants
        jjassert( mul_mod(a,b,p)==mul_mod(x,x,p) );
        jjassert( r>0 );
        jjassert( pow_mod(y,1ULL<<(r-1),p)==p-1 );
        jjassert( pow_mod(b,1ULL<<(r-1),p)==1ULL );

        if ( 1==b )  return x;  // a*b==x*x with b==1: x is the root

        // find exponent: first m in 1..r-1 with b^(2^m)==1
        uint m;
        for (m=1; m<r; ++m)
        {
            if ( 1==pow_mod(b,1ULL<<m,p) )  break;
        }

        if ( m==r )  // no such m
        {
            cout << "sqrt_modp(a,p): a not a quadratic residue mod p ! \n";
            return 0;  // a is not a quadratic residue mod p
        }

        // reduce exponent:
        umod_t t = pow_mod(y,1ULL<<(r-m-1),p);  // t == y^(2^(r-m-1))
        y = mul_mod(t,t,p);                     // y == t*t
        r = m;
        x = mul_mod(x,t,p);                     // x == x*t
        b = mul_mod(b,y,p);                     // b == b*y
    }
}
//================= end SQRT_MODP ================

