
#include <math.h>
#include <iostream.h>
#include <assert.h>
#include <stdlib.h>
#include <limits.h>  // LONG_MAX

#include "hfloat.h"
#include "auxtempl.h"  // min(), sign()
#include "auxbit.h"  // ld()

// forward declaration: approx_exp() (defined below) supplies the
// start value of the iteration in exp_iteration()
void approx_exp(const hfloat &d, hfloat &c);


// for debug:  PR(x) discards its argument, so every PR(...) trace
// compiles to nothing; redefine it as  x  to enable the traces.
#define  PR(x)
#define  SHRT  10   // last argument of print() in the PR() traces (presumably the number of places shown)

// whether to check that the last step in iteration really
// produced the full precision (expensive!):
//   0:       exp_iteration() stops as soon as the working precision
//            has reached the precision goal;
//   nonzero: that test is skipped, and the loop only stops once the
//            measured accuracy reaches the realistic goal (or the
//            maximal number of steps is used up).
const int check_last_step = 0;


int
exp_iteration(const hfloat &di, hfloat &x, ulong startprec/*=0*/)
//
//   Compute  x = exp(di)  to the precision that x has on entry.
//
//   the equation exp(d)==x*(exp(d-log(x))) gives the
//
//   iteration for exp(d):
//
//   x <- x*(1+y+y^2/2!+y^3/3!+...) where y := d-log(x)
//   or:
//   x += x*(  y+y^2/2!+y^3/3!+...)
//
//   The series for exp(y)-1 is truncated and evaluated by ratpoly()
//   with coefficients num[k]/den[k] == 1/k!.  The working precision is
//   raised (by a factor of `order`) whenever the measured accuracy
//   allows, up to the goal.
//   Negative arguments are handled via  exp(-|d|) == 1/exp(|d|).
//
//   di:        argument; a local copy is taken, so di may share its
//              data with x
//   x:         result; its precision on entry is the precision goal
//   startprec: ignored (overwritten with 0); the start value always
//              comes from approx_exp()
//   returns:   loop index of the step after which the iteration
//              stopped (0 for di==0)
//
{
    PR( cout << __PRETTY_FUNCTION__ << endl; );
    PR( print("\n arg d=",di,SHRT); );

    if ( di.is_zero() )
    {
        x = 1;
        return 0;
    }

    hfloat d(di);  // jjnote: ugly: only needed when di.data()==x.data()

    static const int order = 20;
    static const ulong deg = order+1;
    static long num[deg+1];
    static long den[deg+1];

    // setup ratpoly for exp()-1:  num[k]/den[k] == 1/k!, constant term 0.
    // k! quickly overflows a long (21! > 2^63, 13! > 2^31) and signed
    // overflow is undefined behavior.  Terms whose factorial does not
    // fit are dropped (coefficient 0/1) instead of getting a garbage
    // value; with a 64-bit long only the k==21 term is affected.
    num[0] = 0;
    den[0] = 1;
    bool fits = true;
    for (ulong k=1; k<=deg; ++k)
    {
        const long lk = (long)k;
        fits = fits && ( den[k-1] <= LONG_MAX/lk );
        if ( fits )
        {
            num[k] = 1;
            den[k] = den[k-1] * lk;  // k!
        }
        else
        {
            num[k] = 0;  // term dropped: k! does not fit into a long
            den[k] = 1;
        }
    }

    const ulong pg = x.prec();  // precision goal
    const ulong rpg = (pg>=6?pg-6:1); // jjnote: magic: realistic precision goal

    hfloat t(pg);
    // cp grows by a factor of `order` per step, so roughly
    // log(pg)/log(order) steps are expected; the +8 is slack
    const int max_rec = (int)(8+log((double)pg)/log((double)order));

    int  xs = d.sign();  // remember the sign, iterate on |d|
    d.sign( +1 );

    startprec = 0;  // jjnote: unused  (avoid compiler warning)
    ulong ap;  // number of limbs assumed/measured to be correct
    if ( startprec!=0 )
    {
        ap = startprec;
    }
    else
    {
        approx_exp(d,x);        PR( print("\n ==  x0=",x,SHRT); );
        ap = 3;  // assume 3 limbs are correct
    }

    ulong cp;  // calculation precision
    cp = next_pow_of_2( ap );
    cp = min(cp,pg);

    int j;
    for(j=0; j<max_rec; ++j)  // ---------------- ITERATION --------------
    {
        // raise the working precision once the accuracy allows it
        if ( ap > cp/order )
        {
            cp *= order;
            cp = min(cp,pg);
        }

        t.prec(cp);
        x.prec(cp);
        d.prec(cp);

        log(x,t);     PR( print("\n log(x)=",t,SHRT); );

        sub(d,t,t);    PR( print("\n y:= d-log(x)=",t,SHRT); );

        // y must be small; its (non-positive) exponent gives the accuracy
        assert( t.exp()<=0 );
        ap = -t.exp(); PR( cout << "\n ap=" << ap << endl; );

        ratpoly(t,order+1,num,den,t);  PR( print("\n  poly=",t,SHRT); );

        mul(x,t,t);   PR( print("\n  x*poly=",t,SHRT); );

        add(x,t,x);   PR( print("\n x+x*poly=",x,SHRT); );


        if ( ap>=rpg )  break;
        if ( !check_last_step && (x.prec()>=pg) )  break;
    }
    // ---------------- end of ITERATION ---------------------

    // j==max_rec means the loop ran out without meeting a break condition.
    // (The former test  (j<=max_rec)*(int)"..."  was always true, and the
    // pointer-to-int cast is ill-formed where int is narrower than a pointer.)
    assert( (j<max_rec) && "NO CONVERGENCE" );


    if ( xs==-1 )  // exp(-|d|) == 1/exp(|d|)
    {
        inv(x,t);
        x = t;
    }

    return j;
}
//========================= end ==========================



#define  SIMPLE_EXP  0 // 0 or 1;  default is 0

void
approx_exp(const hfloat &d, hfloat &c)
//
// c = exp(d)
//
// to avoid overflow for big exponents use:
// set c := m*r^x
// then d = log(m) + x * log(r)
//
// lr = log(r)
// x  = floor(d/lr)  ==  div(d,lr)
// m  = exp(d-x*lr)  ==  exp(mod(d,lr))
//
// jjnote: hfloat d must fit into double
{
    double dd;
    hfloat2d(d,dd);

#if ( SIMPLE_EXP==1 )
#warning 'FYI: SIMPLE_EXP==1 in approx_exp()'
    dd = exp(dd);
    d2hfloat(dd,c);
#else
    int s = sign(dd);
    dd = fabs(dd);

    double lr = log( hfloat::radix() );
    double x = floor( dd/lr );
    double m = exp( dd-x*lr );

    d2hfloat(m,c);
    c.exp( c.exp()+(long)x );

    if ( s<0 )  inv(c,c);

#endif // SIMPLE_EXP
}
//=================== end ========================

