
#include <iostream.h>
#include <math.h>

#include "mtypes.h"
#include "factor.h"
#include "modm.h"
#include "jjassert.h"


const int factorization::maxprimes = 64;
// Capacity for distinct prime factors.  The worst case is a primorial
// 2*3*5*7*11*...: the product of the first 16 primes already exceeds 2^64,
// so (assuming umod_t is at most 64 bits wide) a number has at most 15
// distinct prime factors and 64 leaves ample headroom.

void
factorization::ctor_core()
// Shared constructor part: allocate room for maxprimes+1
// (prime, exponent) pairs and start with no primes stored.
// The arrays are not initialized here; only entries [0, npr) are ever read.
{
    const int len = maxprimes + 1;
    npr = 0;
    prime    = new umod_t[len];
    exponent = new long[len];
}
// ----------------


factorization::factorization()
// Empty factorization (npr == 0); storage is allocated but not filled.
{
    this->ctor_core();
}
// ----------------


factorization::~factorization()
// Release the two arrays allocated by ctor_core().
// NOTE(review): this class owns raw arrays; unless the header defines or
// disables the copy constructor and copy assignment, copying an object
// leads to a double delete here -- confirm against factor.h.
{
    delete [] exponent;
    delete [] prime;
}
// ----------------


factorization::factorization(umod_t n)
// Factor n by trial division (see make_factorization(umod_t))
// and store the primes in ascending order.
{
    this->ctor_core();
    this->make_factorization(n);
    this->sort();
}
// ----------------


factorization::factorization(umod_t n, umod_t *f)
//
// Build the factorization of n from f, a zero-terminated list of
// ALL distinct prime factors of n (the list must be complete).
// Exponents are found by dividing each listed prime out of n;
// the resulting pairs are sorted by ascending prime.
//
{
    this->ctor_core();
    this->make_factorization(n, f);
    this->sort();
}
// ----------------

void
factorization::sort()
{
    static umod_t pr[maxprimes+1];
    static long   ex[maxprimes+1];

    for (int i=0; i<npr; ++i)
    {
        pr[i] = prime[i];
        ex[i] = exponent[i];
        prime[i] = 0;
        exponent[i] = 0;
    }


    for (int i=0; i<npr; ++i)
    {
//        cout << "\n i=" << i << ": " << *this;

        umod_t mi = (~0);
        int f=0;
        for (int j=0; j<npr; ++j)
        {
//            cout << "\n j=" << j << ": " << pr[j]; 

            if( pr[j]<mi )
            {
                mi = pr[j];
                f = j;
            }
        }
//        cout << "\n mi=" << mi << endl;

        prime[i] = pr[f];
        exponent[i] = ex[f];
        pr[f] = (~0);
    }
}
// ----------------

int
factorization::has_prime_factor(umod_t f) const
// Return 1 if f is one of the stored primes, otherwise 0.
{
    int i = npr;
    while ( i-- > 0 )
    {
        if ( prime[i]==f )  return 1;
    }
    return 0;
}
// ----------------


umod_t
factorization::factor(int i) const
// Return ipow(prime[i], exponent[i]), i.e. the i-th prime power
// (assuming ipow is integer exponentiation).  i is expected to lie
// in [0, npr); it is not range-checked.
{
    const umod_t p = prime[i];
    const long   e = exponent[i];
    return  ipow(p, e);
}
// ----------------


umod_t
factorization::product() const
// Multiply all stored prime powers together; for a correct
// factorization this gives back the original number.
// An empty factorization yields 1.
{
    umod_t prod = 1;
    int i = 0;
    while ( i<npr )  prod *= factor(i++);
    return  prod;
}
// ----------------


int
factorization::is_factorization_of(umod_t n) const
// Return nonzero iff the stored prime powers multiply to n.
{
    const umod_t p = product();
    return  ( p==n );
}
// ----------------

void
factorization::make_factorization(umod_t n, umod_t *f)
//
// Set up the factorization of n from f, a zero-terminated list of the
// distinct prime factors of n (the list must be complete).
// Each listed prime is divided out of n to obtain its exponent.
// n==1 yields the empty factorization (npr==0) without reading f[0].
//
// The terminating 0 is now honoured: the old loop never looked at it,
// so an empty list (n==1) or an incomplete list made it call
// divide_out_factor(n,0).  The number of entries is also bounded by
// maxprimes.
//
{
    jjassert( f );

    for (int i=0; i<maxprimes; ++i)  prime[i]=exponent[i]=0;

    int k = 0;
    while ( 1!=n )
    {
        umod_t v = f[k];

        // End of list before n was fully factored: stop here instead of
        // dividing by the terminator; jjassert(1==n) below reports it.
        if ( 0==v )  break;

        jjassert( k<maxprimes );
        if ( k>=maxprimes )  break;  // never write past the arrays

        long ex = divide_out_factor(n,v);
        jjassert( ex );  // every listed prime must divide n

        prime[k] = v;
        exponent[k] = ex;
        ++k;
    }

    npr = k;

    jjassert( 1==n );  // the list was complete
}
// ----------------

void
factorization::make_factorization(umod_t n)
//
// Factor n by trial division and store the result in prime[]/exponent[]
// (npr entries, in the order found).
// Phase 1 tries the primes of the global table ::prime(0), ::prime(1), ...
// (a 0 entry marks the end of the table).  Phase 2 continues with odd
// candidates beyond the table.  Both phases stop once the candidate
// exceeds sqrt of the remaining cofactor.  A cofactor > 1 left over at
// the end is prime and is stored last.
//
{
    for (int i=0; i<maxprimes; ++i)  prime[i]=exponent[i]=0;

    if ( ::is_prime(n) )
    {
        prime[0] = n;
        exponent[0] = 1;
        npr = 1;
        return ;
    }

    umod_t maxv = (umod_t)(sqrt((double)n)+1);
    umod_t v;
    long ex;
    uint kp = 0;  // index of the next table prime to try
    uint k = 0;   // number of distinct prime factors found so far
    while ( 1 )
    {
        v = ::prime(kp);
        if ( 0==v )  break;  // end of the prime table
        if ( maxv<v )  break;
        kp++;

        ex = divide_out_factor(n,v);

        if ( ex )
        {
            prime[k] = v;
            exponent[k] = ex;
            k++;

            maxv = (umod_t)(sqrt((double)n)+1);
        }
    }

    // Phase 2: resume right after the last table prime actually tried,
    // i.e. ::prime(kp-1).  The old code used ::prime(k-1), but k counts
    // factors found and is not a table index.  With k==0 it evaluated
    // ::prime(UINT_MAX), and with k==1 (table starting at 2) it started
    // at 4 and stepped through even numbers only, missing every odd factor
    // beyond the table.
    if ( 0==kp )  v = 2;                          // nothing tried yet
    else          v = (::prime(kp-1) + 1) | 1;    // next odd number after it
    for ( ; v<=maxv; v += (2==v ? 1 : 2))         // 2, then odd candidates
    {
        ex = divide_out_factor(n,v);

        if ( ex )
        {
            prime[k] = v;
            exponent[k] = ex;
            k++;

            maxv = (umod_t)(sqrt((double)n)+1);
        }
    }

    if ( n!=1 )  // remaining cofactor has no factor <= its sqrt: it is prime
    {
        prime[k] = n;
        exponent[k] = 1;
        k++;
    }

    npr = k;
}
// ----------------


void
factorization::print(const char *bla/*=NULL*/, ostream &os/*=cout*/) const
// Write the factorization as p1^e1*p2^e2*... to os.  Exponents equal
// to 1 are omitted.  If bla is non-null it is written first.
{
    if ( 0!=bla )  os << bla;

    for (int j=0; j<npr; ++j)
    {
        if ( j>0 )  os << "*";  // separator goes before every entry but the first
        os << prime[j];
        if ( exponent[j]>1 )  os << "^" << exponent[j];
    }
}
// ----------------

