
#include <iostream.h>
#include <math.h>

#include "mtypes.h"
#include "factor.h"
#include "modm.h"
#include "jjassert.h"


#define  FACT_DEBUG  0  // 0 or 1 (for debug)
#if  ( FACT_DEBUG==1 )
#warning 'FYI:  FACT_DEBUG == 1'
#endif

const int factorization::maxprimes = 30;
// Capacity of the per-prime tables prime[], expon[], fact[]
// (they are allocated with maxprimes+1 slots in ctor_core()).
// worst case: 2*3*5*7*11*...  (the number with the most distinct primes)
// 16 primes are enough for 64 bit
// 27 primes are enough for 128 bit

void
factorization::ctor_core()
//
// Shared constructor part: allocate the three per-prime tables
// (maxprimes+1 slots each) and clear them via reset().
//
{
    const int nslots = maxprimes + 1;
    prime = new umod_t[nslots];
    expon = new long[nslots];
    fact  = new umod_t[nslots];
    reset();
}
// ----------------


factorization::factorization()
    : prime(0), expon(0), npr(0), fact(0), prod(0)
//
// Empty factorization: tables are allocated and cleared, npr==0.
//
{
    this->ctor_core();
}
// ----------------


factorization::factorization(umod_t n)
    : prime(0), expon(0), npr(0), fact(0), prod(0)
//
// Factor n (see make_factorization(umod_t)) and sort by prime.
//
{
    this->ctor_core();
    this->make_factorization(n);
    this->sort();
}
// ----------------


factorization::factorization(umod_t n, umod_t *f)
    : prime(0), expon(0), npr(0), fact(0), prod(0)
//
// Construct from n and a zero-terminated (complete!) list f of
// the primes dividing n; the exponents are determined here.
// A null f falls back to trial division.  Result is sorted by prime.
//
{
    this->ctor_core();
    this->make_factorization(n, f);
    this->sort();
}
// ----------------


factorization::~factorization()
//
// Release the tables allocated in ctor_core().
//
{
    delete [] fact;
    delete [] expon;
    delete [] prime;
}
// ----------------


void
factorization::reset()
//
// Clear the factorization: zero every table slot and set npr and prod to 0.
// The tables have maxprimes+1 slots (see ctor_core()), so all of
// them are cleared, including the spare last slot.
//
{
    for (int k=0; k<=maxprimes; ++k)
    {
        prime[k] = 0;
        expon[k] = 0;
        fact[k] = 0;
    }
    npr = 0;
    prod = 0;
}
// ----------------


void
factorization::make_factorization(umod_t n, umod_t *f)
//
// Build the factorization of n from f, a zero-terminated list of the
// primes dividing n.  The list must be complete: after dividing out
// every listed prime, n must be reduced to 1.
// If f is null, fall back to make_factorization(n).
// Errors are reported on cerr and via jjassert2():
//   - a listed v that does not divide n (or v==1),
//   - a list that ends before n is reduced to 1,
//   - more than maxprimes entries (would overflow the tables).
// The result is not sorted.
//
{
    if ( 0==f )
    {
        make_factorization(n);
        return;
    }

    reset();
    prod = n;

    int k = 0;
    while ( n!=1 )  // n==1 has an empty factorization (f may be just {0})
    {
        umod_t v = f[k];
        if ( 0==v )  // terminator reached: would otherwise divide by zero
        {
            cerr << " n=" << n << endl;
            cerr << " list of factors is incomplete " << endl;
            jjassert2( 0, "incomplete list of factors for n" );
            break;
        }

        if ( k>=maxprimes )  // would write past prime[], expon[], fact[]
        {
            jjassert2( 0, "too many factors for n" );
            break;
        }

        // v==1 would make divide_out_factor() loop forever: reject it
        long ex = ( v<2 ? 0 : divide_out_factor(n,v) );
        if ( 0==ex )
        {
            cerr << " n=" << n << "  v=" << v << endl;
            cerr << " v does not divide n " << endl;
            jjassert2( 0, "got false factor v for n" );
            break;
        }

        prime[k] = v;
        expon[k] = ex;
        fact[k] = ipow(v,ex);
        k++;
    }

    npr = k;

    jjassert( 1==n );

#if  ( FACT_DEBUG==1 )
    check();
#endif
}
// ----------------

void
factorization::make_factorization(umod_t n)
//
// Factor n by trial division and store the result (not sorted).
//   - n<2 (0 or 1): nothing is stored (npr==0); prod records n.
//   - if ::is_pseudoprime(n,100) holds, n is stored as a single prime.
//   - otherwise primes from the ::prime() table are tried, then odd
//     numbers; a remaining cofactor >1 is stored as a prime.
// If the cofactor left after the table primes is > 2^31 and a
// perfect square w*w, it is replaced by w and every exponent found
// afterwards (including that of the final cofactor) is doubled (pw==2).
//
{
    reset();
    prod = n;

    // 0 and 1 have no prime factors.  Returning here also guarantees that
    // at least one table prime is tried below, so kp>0 when prime(kp-1)
    // is evaluated (n==0 previously led to prime(kp-1) with kp==0).
    if ( n<2 )  return;

    if ( ::is_pseudoprime(n,100) )
    {
        prime[0] = n;
        expon[0] = 1;
        fact[0] = n;
        npr = 1;
        return;
    }

    umod_t maxv = (umod_t)(sqrt((double)n)+1);
    umod_t v;
    long ex;
    uint kp = 0;
    uint k = 0;
    while ( 1 )
    {
        v = ::prime(kp);
        if ( 0==v )  break;
        if ( maxv<v )  break;
        kp++;

        ex = divide_out_factor(n,v);

        if ( ex )
        {
            prime[k] = v;
            expon[k] = ex;
            fact[k] = ipow(v,ex);
            k++;

            maxv = (umod_t)(sqrt((double)n)+1);
        }
    }

    long pw = 1;
    if ( n>((umod_t)1<<31) )  // check whether n is a perfect square
    {
        umod_t w = int_sqrt(n);
        if ( w*w==n )
        {
            n = w;
            pw = 2;
            maxv = (umod_t)(sqrt((double)n)+1);
        }
    }

    // continue with odd trial divisors after the last table prime tried
    v = ::prime(kp-1)+2;
    for ( ; v<=maxv; v+=2)
    {
        ex = divide_out_factor(n,v);

        if ( ex )
        {
            ex *= pw;
            prime[k] = v;
            expon[k] = ex;
            fact[k] = ipow(v,ex);
            k++;

            maxv = (umod_t)(sqrt((double)n)+1);
        }
    }

    if ( n!=1 )
    {
        // Leftover prime cofactor.  If n was replaced by its square root
        // above, it occurs pw times, so the stored prime power is n^pw
        // (storing plain n here made the product of fact[] wrong).
        prime[k] = n;
        expon[k] = pw;
        fact[k] = ipow(n,pw);
        k++;
    }

    npr = k;

#if  ( FACT_DEBUG==1 )
    check();
#endif
}
// ----------------


void
factorization::sort()
{
    static umod_t pr[maxprimes+1];
    static long   ex[maxprimes+1];
    static umod_t fc[maxprimes+1];

    for (int i=0; i<npr; ++i)
    {
        pr[i] = prime[i];
        ex[i] = expon[i];
        fc[i] = fact[i];
        prime[i] = 0;
        expon[i] = 0;
        fact[i] = 0;
    }


    for (int i=0; i<npr; ++i)
    {
//        cout << "\n i=" << i << ": " << *this;

        umod_t mi = (umod_t)(~0);
        int f=0;
        for (int j=0; j<npr; ++j)
        {
//            cout << "\n j=" << j << ": " << pr[j]; 

            if ( pr[j]<mi )
            {
                mi = pr[j];
                f = j;
            }
        }
//        cout << "\n mi=" << mi << endl;

        prime[i] = pr[f];
        expon[i] = ex[f];
        fact[i] = fc[f];
        pr[f] = (umod_t)(~0);
    }

#if  ( FACT_DEBUG==1 )
    check();
#endif
}
// ----------------

int
factorization::exponent(umod_t f) const
//
// Exponent of f in this factorization; 0 if f is not a stored prime.
//
{
    int i = 0;
    while ( i<npr )
    {
        if ( prime[i]==f )  return expon[i];
        ++i;
    }
    return 0;
}
// ----------------


umod_t
factorization::factor(int i) const
//
// The i-th stored prime power (prime[i]^expon[i]); i is not range-checked.
//
{
    const umod_t pp = this->fact[i];
    return pp;
}
// ----------------


umod_t
factorization::product() const
//
// The number that was factored (recorded when the factorization was made).
//
{
    return this->prod;
}
// ----------------


int
factorization::is_factorization_of(umod_t n) const
//
// Nonzero iff this object is the factorization of n.
//
{
    return  product()==n;
}
// ----------------


void
factorization::print(const char *bla, ostream &os) const
//
// Write bla (if non-null), then the factorization as p1^e1*p2^e2*...
// Exponents equal to 1 are omitted; nothing follows the last factor.
//
{
    if ( bla )  os << bla;

    for (int k=0; k<npr; ++k)
    {
        if ( k>0 )  os << "*";  // separator between factors
        os << prime[k];
        if ( expon[k]>1 )  os << "^" << expon[k];
    }
}
// ----------------


//istream&  operator >> (istream& is, factorization& h);
//{ is >> ; return is; }


ostream&  operator << (ostream& os, const factorization& h)
//
// Stream the factorization without a leading text (see print()).
//
{
    const char *no_header = 0;
    h.print(no_header, os);
    return os;
}
// ----------------


void
factorization::check()
//
// Consistency check: each stored prime must pass is_pseudoprime(.,100)
// and the stored prime powers must multiply up to prod.
//
{
    umod_t m = 1;
    int i = 0;
    while ( i<npr )
    {
        jjassert( ::is_pseudoprime(prime[i],100) );
        m *= fact[i];
        ++i;
    }

    jjassert( m == prod );
}
// ----------------


// auxiliary functions:

umod_t
is_factor(umod_t n, umod_t f)
//
// if f divides n  return n/f
// else            return 0
//
// f==0 is reported as "not a factor" instead of dividing by zero.
// Note: n==0 also yields 0, i.e. it is indistinguishable from
// "f does not divide n".
//
{
    if ( 0==f )  return 0;

    umod_t q = n/f;

    if ( q*f==n )  return q;
    else           return 0;
}
// ----------------

long
divide_out_factor(umod_t &n, umod_t v)
//
// While v divides n, divide n by v.
// Return how often n was divided (n is updated in place).
// v<2 returns 0 and leaves n unchanged: v==1 would divide forever
// and v==0 would divide by zero.
// n==0 also returns 0 (consistent with is_factor()).
//
{
    if ( v<2 )  return 0;

    long ex = 0;
    while ( n!=0 && 0==n%v )
    {
        n /= v;
        ++ex;
    }

    return ex;
}
// ----------------


umod_t
int_sqrt(umod_t d)
//
// Integer square root floor(sqrt(d)) by Newton iteration (cf. Cohen p.38).
// The floating-point estimate only seeds the iteration from above.
// d==0 is answered directly: the seed would be 0 and d/x would
// divide by zero.
//
{
    if ( 0==d )  return 0;

    umod_t x = (umod_t)ceil(sqrt((double)d));
    umod_t y = (x + d/x)/2;
    while ( y<x )
    {
        x = y;
        y = (x + d/x)/2;
    }

    return x;
}
// ----------------
