
#include <math.h>
#include <iostream.h>
#include <stdlib.h>

#include "hfloat.h"
#include "hfverbosity.h"
#include "hfdatafu.h"  // dt_mantissa_to_double()
#include "inline.h"  // min(), sign()
#include "auxbit.h"  // next_pow_of_2()

#include "jjassert.h"
#include "hfparanoia.h"


// for debug:
// PR(x) expands to nothing, so the PR( ... ) trace statements in
// iroot_iteration() are compiled out; redefine it as  #define PR(x) x
// to enable them.  PRP is passed as the last argument of the print()
// calls inside those traces (presumably the number of digits shown).
#define PR(x)
#define PRP  10


int
iroot_iteration(const hfloat &dd, ulong ri, hfloat &x, ulong startprec/*=0*/)
//
//   the equation
//   1/d^(1/r)== x*(1+y/r+y^2*(1+r)/(2! r^2)+y^3*(1+r)(1+2r)/(3! r^3)+...)
//            == x * \sum_{k=0}^{\infty}{y^n*(1/r)_n}
//   (where y:=(1-d*x^r)  and the subscript n denotes the
//    pochhammer symbol (rising factorial pow) )
//
//   gives these iteration: for 1/d^(1/r):
//   2nd order: x += x*(y/r)
//   3rd order: x += x*(y/r * [1+y*(1+r)/2/r])
//
//  returns the number of iterations needed
{
#ifdef  HF_PARANOIA
    jjassert( dd.OK() );
#endif

    startprec = 0;  // jjnote: iroot_iteration() FAILS with startprec!=0

//    assert( dd.data() != x.data() );

    static hfloat d(64);   // jjnote: static hfloat
    d.prec( dd.prec() );  // possible resize
    d = dd;
#define  dd  GONE

//    if ( startprec==0 )  null(x.data()->dig(), x.prec());

    static hfloat tt(64);  // jjnote: static hfloat for third order term

    if ( hfverbosity::itbegin )  cout << "\nIT_IROOT{" << ri << ":\n";

    const ulong dp0 = d.prec();

    const ulong pg = x.prec();  // precision goal
    const ulong rpg = (pg>4?pg-4:1);     // jjnote: magic: realistic precision goal

    long  r = (long)ri;

    PR( print("\n iroot: x=\n",x,64); );
    PR( cout << "startprec=" << startprec << endl; );

    int j = 0;
    if ( startprec>=rpg )  goto done;  // nothing to do

    { // computation:
        static hfloat t(pg);  // jjnote: temporary hfloat

        long  xe = (d.exp()/r)*r;
//        xe = 0;
        PR( cout << "\n  d.exp()=" << d.exp() << "  xe=" << xe << endl; );
        d.exp( d.exp()-xe );  // remove (big) exponent for calculations
        x.exp( x.exp()+xe/r );  // exponent of result for computations

        int  s = d.sign();
        d.sign(+1);

        ulong ap;  // achieved precision
        if ( startprec!=0 )  ap = startprec;
        else
        {
            approx_pow(d, x, -r);            
            ap = 3;  // assume 3 limbs are correct
            if ( ap>=rpg )  goto done;
        }
        
        ulong pr;  // calculation precision
        pr = next_pow_of_2( ap );
        pr = min(pr,pg);

        const int maxit = (int)(8+log((double)pg)/log(2.0));
        for (j=1; j<=maxit; ++j) // ------------ ITERATION ------------------
        {
            if ( 2*ap >= pr )  pr = min(2*pr,pg);
//            ulong hp = pr/2;
            t.prec(pr);
            x.prec(pr);
            d.prec( min(pr,dp0) );

            PR( print("\n ------  x= \n",x,PRP); );

//            if ( hp>16 && r==2 )
//            {
//                null(x.data()->dig()+hp, x.size()-hp);
//                null(t.data()->dig()+hp, t.size()-hp);
//                t.prec(hp);  x.prec(hp);
//                pow(x,r,t);
//                t.prec(pr);  x.prec(pr);
//                PR( print("\n  x^r= \n",t,PRP); );
//            }
//            else
            pow(x,r,t);

            mul(t,d,t);
            PR( print("\n  d*x^r= \n",t,PRP); );
            sub(1,t,t);
            PR( print("\n 1-d*x^r= \n",t,PRP); );
            div(t,r,t);
            PR( print("\n y/r:=(1-d*x^r)/r= \n",t,PRP); );

            ap += -t.exp();

            if ( (pr==pg) || (ap+63<=pr) )  // ===== THIRD ORDER TERM
            {
                // here:  t = y/r
                //  t += t*t*(1+r)/2 == (y/r)^2*(1+r)/2 == third order term

                if ( ap>=pr-4 )  // jjnote: magic: realistic loss of precision
                {
                    mul(t,x,t);
                    goto dort;
                }
                
                ulong ttp = next_pow_of_2(pr-ap);
                if ( hfverbosity::itprec )  cout << " (3rd: " << ttp << ") ";
                tt.prec( ttp );
                sqr(t, ttp, tt, ttp);

                if ( r!=1 )
                {
                    if ( (r&1)==0 )  // r even: 
                    {
                        mul(tt, (1+r), tt);
                        div2(tt,tt);
                    }
                    else          // r odd:
                    {
                        mul(tt, (1+r)/2, tt);
                    }
                }

                add(t, tt, t);
                mul(t, x, t);   // jjnote: full prec mult

                ap += ttp;
            }                        // ===== end THIRD ORDER TERM
            else  // half prec mult
            {
                mul(t, t.prec()/2, x, x.prec()/2, t, t.prec());
            }
        dort:

            PR( print("\n add= \n",t,PRP); );
            add(x, t, x);

            if ( ap>pr )  ap = pr;
            if ( hfverbosity::itprec )  cout << " (" << ap << "," << pr << ")\n";

            if ( ap>=rpg )  break;
            jjassert( ap>pr/4 );
        }
        // ---------------- end of ITERATION ---------------------

        if ( j>maxit )  jjassert2( 0, "iroot_iteration(): no convergence" );


        if ( hfloat::check_itiroot_result )
        {
            pow(x, r, t);
            mul(t, d, t);
            sub(1, t, t);
            PR( print("\n check_result: 1-d*x^r= \n",t,PRP); );
            ap = -t.exp();
            jjassert( ap>=rpg );
        }

        x.exp( x.exp()-xe/r );  // exponent of result
        x.sign( s );
    }
    
done:
    if ( hfverbosity::itbegin )  cout << "}\n";

    return j;

#ifdef  HF_PARANOIA
    jjassert( d.OK() );
    jjassert( x.OK() );
#endif
}
// =================== end ================



// Debug trace macro reserved for approx_pow(); it expands to nothing
// (and is not used in the code visible here).
#define PRA(x)

// SIMPLE_POW selects the implementation of approx_pow():
//   0: treat mantissa and exponent of d separately, which avoids
//      overflow for big arguments (default)
//   1: convert all of d to double (a #warning is emitted)
#define  SIMPLE_POW  0  // 0 or 1;  default is 0

void
approx_pow(const hfloat &d, hfloat &c, long p)
//
// Double precision approximation:  c ~= d^(1/p)   (p != 0)
// (iroot_iteration() calls this with p = -r to get d^(-1/r)).
//
// To avoid overflow for big arguments write
//   d = M*R^X    (M = mantissa, R = radix, X = exponent)
// and split  X = p*ex + em  with  ex = div(X,p), em = mod(X,p)
// (C truncating division).  Then
//   d^(1/p) == M^(1/p) * R^(X/p)
//           == M^(1/p) * R^(em/p) * R^ex
// where only the first two factors are computed in double and
// R^ex is applied exactly via the exponent of c.
//
{

#ifdef  HF_PARANOIA
    jjassert( d.OK() );
#endif

    double m;
#if ( SIMPLE_POW==1 )
#warning 'FYI: SIMPLE_POW==1 in approx_pow()'
    hfloat2d(d, m);
    m = pow(m, 1.0/(double)p);  // d^(1/p)  (previously d^p: wrong)
    d2hfloat(m, c);
#else // ( SIMPLE_POW )
    dt_mantissa_to_double(*(d.data()), m);

    m = pow(m, 1.0/(double)p);  // M^(1/p)

    const long  ex = d.exp()/p;     // div(X,p)
    const long  em = d.exp()-p*ex;  // mod(X,p)

    m *= pow((double)d.radix(), (double)em/p);  // * R^(em/p)

    d2hfloat(m, c);

    // * R^ex  (the sign was inverted before; only harmless when ex==0,
    // which is the case for the calls from iroot_iteration())
    c.exp( c.exp()+ex );

#endif // ( SIMPLE_POW )

#ifdef  HF_PARANOIA
    jjassert( c.OK() );
#endif
}
//=================== end ======================
