#include <math.h>
#include "ap.h"
#include "apcplx.h"

#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif

// Overloaded mathematical functions

inline size_t min (size_t a, size_t b)
{
    // Smaller of two sizes
    if (b < a)
        return b;
    return a;
}

inline size_t max (size_t a, size_t b)
{
    // Larger of two sizes
    if (b > a)
        return b;
    return a;
}


// Absolute value
apfloat abs (apcomplex z)
{
    // Modulus |z| computed as the square root of norm(z).
    // Assumes norm() (declared elsewhere) returns the real value
    // Re(z)^2 + Im(z)^2 -- TODO confirm against apcplx.h.
    return sqrt (norm (z));
}

// Positive integer power
apcomplex pow (apcomplex base, unsigned long exp)
{
    // base^exp for a non-negative integer exponent, by repeated squaring.
    // pow(base, 0) is 1.
    if (exp == 0)
        return 1;

    // Square away the trailing zero bits so the lowest set bit seeds the result
    for (; !(exp & 1); exp >>= 1)
        base *= base;

    apcomplex result = base;

    // Each remaining bit: square the running power, multiply it in if the bit is set
    for (exp >>= 1; exp; exp >>= 1)
    {
        base *= base;
        if (exp & 1)
            result *= base;
    }

    return result;
}

apcomplex pow (apcomplex base, unsigned exp)
{
    // Widen the exponent and reuse the unsigned long implementation
    unsigned long wide = exp;
    return pow (base, wide);
}

// Integer power
apcomplex pow (apcomplex base, long exp)
{
    // base^exp for any long exponent; a negative exponent yields the
    // reciprocal of the corresponding positive power.
    if (exp < 0)
    {
        // Negate in unsigned arithmetic: -exp overflows (undefined behavior)
        // for exp == LONG_MIN, while 0UL - (unsigned long) exp is the exact
        // magnitude of every negative long.
        unsigned long magnitude = 0UL - (unsigned long) exp;
        return 1 / pow (base, magnitude);
    }
    else
        return pow (base, (unsigned long) exp);
}

apcomplex pow (apcomplex base, int exp)
{
    // base^exp for any int exponent; a negative exponent yields the
    // reciprocal of the corresponding positive power.
    if (exp < 0)
    {
        // Negate in unsigned arithmetic: -exp overflows (undefined behavior)
        // for exp == INT_MIN, while 0UL - (unsigned long) exp is the exact
        // magnitude of every negative int.
        unsigned long magnitude = 0UL - (unsigned long) exp;
        return 1 / pow (base, magnitude);
    }
    else
        return pow (base, (unsigned long) exp);
}

// Square root
apcomplex sqrt (apcomplex z)
{
    // sqrt(0) is 0; handled up front so invroot() never receives a zero argument
    if (!z.re.sign () && !z.im.sign ())
        return apfloat (new apstruct);

    // sqrt(z) = z * z^(-1/2)
    return z * invroot (z, 2);
}

// Cube root
apcomplex cbrt (apcomplex z)
{
    // cbrt(0) is 0; handled up front so invroot() never receives a zero argument
    if (!z.re.sign () && !z.im.sign ())
        return apfloat (new apstruct);

    // cbrt(z) = z * (z^2)^(-1/3)
    apcomplex zsquared = z * z;
    return z * invroot (zsquared, 3);
}

// Inverse positive integer root
// Computes u^(-1/n) for a nonzero complex u, at the precision of u.
// Returns 1 when n == 0.
// Method: build a double-precision initial guess, then refine it with
// Newton's iteration  z <- z + z * (1 - u * z^n) / n,  doubling the
// working precision on each pass.
apcomplex invroot (apcomplex u, unsigned n)
{
    size_t prec, minprec, maxprec, destprec = u.prec (), doubledigits, fprec;
    int k, f;
    long p, r1, r2;
    apfloat one = 1, d = n;
    apfloat x, y;
    apcomplex z;
    double r, i, m, a;

    if (!n) return 1;

    assert (u.re.sign () || u.im.sign ());      // u == 0 would give an infinite result

    // Approximate accuracy of a double
    doubledigits = (size_t) (log (9e15) / log ((double) Basedigit));

    // Initial guess accuracy
    fprec = max (doubledigits, 2 * Basedigits);

    // Calculate initial guess from u.
    // Each branch splits the exponent as p * n + remainder, converts the
    // remainder-scaled mantissa to double, takes the root in double precision,
    // and then applies the exponent -p to the result.
    if (!u.im.sign () ||
        (u.re.sign () && u.im.sign () && u.re.exp () - u.im.exp () > 2 * Basedigits))
    {
        // Re u >> Im u (or u is real)
        apcomplex t;

        p = u.re.ap->exp / (long) n;
        r1 = u.re.ap->exp - p * (long) n;

        x = u.re;
        x.prec (fprec);
        y = u.im;
        y.prec (fprec);
        t = apcomplex (0, y / (n * x));         // i * Im u / (n * Re u), a small quantity

        x.exp (Basedigits * r1);        // Allow exponents in excess of doubles'

        if ((m = ap2double (x.ap)) >= 0.0)
        {
            r = pow (m, -1.0 / (double) n);
            i = 0.0;
        }
        else
        {
            // Negative real part: root of the magnitude, rotated by -pi/n
            m = pow (-m, -1.0 / (double) n);
            a = -M_PI / (double) n;
            r = m * cos (a);
            i = m * sin (a);
        }

        x = apfloat (r, fprec);
        y = apfloat (i, fprec);
        x.exp (x.exp () - p * Basedigits);
        y.exp (y.exp () - p * Basedigits);
        z = apcomplex (x, y);
        // First-order correction for Im u:
        // (x + iy)^(-1/n) ~ x^(-1/n) * (1 - iy / (n x)) = z * (1 - t)
        z -= z * t;                     // Must not be real
    }
    else if (!u.re.sign () ||
             (u.re.sign () && u.im.sign () && u.im.exp () - u.re.exp () > 2 * Basedigits))
    {
        // Im u >> Re u (or u is purely imaginary)

        p = u.im.ap->exp / (long) n;
        r1 = u.im.ap->exp - p * (long) n;

        y = u.im;
        y.prec (fprec);
        y.exp (Basedigits * r1);        // Allow exponents in excess of doubles'

        // (+-i |y|)^(-1/n) = |y|^(-1/n) * exp(-+i pi / (2n))
        if ((m = ap2double (y.ap)) >= 0.0)
        {
            m = pow (m, -1.0 / (double) n);
            a = -M_PI / (double) (2 * n);
        }
        else
        {
            m = pow (-m, -1.0 / (double) n);
            a = M_PI / (double) (2 * n);
        }

        r = m * cos (a);
        i = m * sin (a);

        x = apfloat (r, fprec);
        y = apfloat (i, fprec);
        x.exp (x.exp () - p * Basedigits);
        y.exp (y.exp () - p * Basedigits);
        z = apcomplex (x, y);
    }
    else
    {
        // Im u and Re u approximately the same magnitude:
        // scale both parts by the same power, then use polar form in doubles

        p = u.re.ap->exp / (long) n;
        r1 = u.re.ap->exp - p * (long) n;
        r2 = u.im.ap->exp - p * (long) n;

        x = u.re;
        x.prec (fprec);
        x.exp (Basedigits * r1);        // Allow exponents in excess of doubles'

        y = u.im;
        y.prec (fprec);
        y.exp (Basedigits * r2);        // Allow exponents in excess of doubles'

        r = ap2double (x.ap);
        i = ap2double (y.ap);
        m = pow (r * r + i * i, -1.0 / (double) (2 * n));      // |u'|^(-1/n)
        a = -atan2 (i, r) / (double) n;                         // -arg(u') / n

        r = m * cos (a);
        i = m * sin (a);

        x = apfloat (r, fprec);
        y = apfloat (i, fprec);
        x.exp (x.exp () - p * Basedigits);
        y.exp (y.exp () - p * Basedigits);
        z = apcomplex (x, y);
    }

    // Precision the initial guess is trusted to
    prec = min (doubledigits, Basedigits);

    // Check if a factor of 3 should be used in length.
    // rnd23up presumably rounds up to 2^k or 3*2^k; if the result is not a
    // power of two, the iteration is seeded at 3 * Basedigits.
    maxprec = rnd23up (destprec / Basedigits);
    if (maxprec != (maxprec & -maxprec))
        minprec = 3 * Basedigits;
    else
        minprec = Basedigits;

    // Highly inefficient unless precision is 2^n * Basedigits (or 3*2^n)
    if (prec < minprec)
    {
        z.re.prec (minprec + 3 * Basedigits);
        z.im.prec (minprec + 3 * Basedigits);
        while (prec < minprec)
        {
            z += z * (one - u * pow (z, n)) / d;
            prec *= 2;
        }
        prec = minprec;
    }

    // Check where the precising iteration should be done.
    // k: number of precision doublings needed to reach destprec.
    // f: value of the countdown k at which one extra ("precising") iteration
    //    is performed at full working precision.
    for (k = 0, maxprec = prec; maxprec < destprec; k++, maxprec <<= 1);
    for (f = k, minprec = prec; f; f--, minprec <<= 1)
        if (minprec >= 2 * Basedigits && (minprec - 2 * Basedigits) << f >= destprec)
            break;

    // Newton's iteration
    while (k--)
    {
        apcomplex t;

        prec *= 2;
        z.re.prec (min (prec, destprec));
        z.im.prec (min (prec, destprec));

        t = one - u * pow (z, n);
        // After the precising step (k < f) the correction term is truncated
        // to half the working precision (presumably enough, since Newton's
        // correction only contributes the lower half of the digits)
        if (k < f)
        {
            t.re.prec (prec / 2);
            t.im.prec (prec / 2);
        }

        if (n > 1)
            z += z * t / d;
        else
            z += z * t;

        // Precising iteration
        if (k == f)
        {
            if (n > 1)
                z += z * (one - u * pow (z, n)) / d;
            else
                z += z * (one - u * pow (z, n));
        }
    }

    z.re.prec (destprec);
    z.im.prec (destprec);

    return z;
}

// Arithmetic-geometric mean
// Won't work if precision is less than 2 * Basedigits
apcomplex agm (apcomplex a, apcomplex b)
{
    // Arithmetic-geometric mean of a and b at precision
    // min(a.prec (), b.prec ()); returns zero if either argument is zero.
    // The working precision must exceed Basedigits (asserted below).
    apcomplex t;
    size_t prec = 0, destprec = min (a.prec (), b.prec ());

    if ((!a.re.sign () && !a.im.sign ()) || (!b.re.sign () && !b.im.sign ()))   // Would not converge quadratically
        return apfloat (new apstruct);      // Zero

    assert (destprec > Basedigits);

    // First check convergence: iterate until a and b agree (per apeq on both
    // the real and imaginary parts) to Basedigits * Blocksize digits, or to
    // half the target precision
    while (prec < Basedigits * Blocksize && 2 * prec < destprec)
    {
        t = (a + b) / 2;
        b = sqrt (a * b);
        a = t;

        // If a + b cancelled exactly (e.g. agm(1, -1)), the geometric mean
        // becomes zero while the arithmetic mean only halves, so the two
        // never agree and this loop would not terminate. The limit is zero.
        if (!b.re.sign () && !b.im.sign ())
            return apfloat (new apstruct);  // Zero

        prec = Basedigits * min (apeq (a.re.ap, b.re.ap), apeq (a.im.ap, b.im.ap));
    }

    // Now we know quadratic convergence: each step doubles the agreeing digits
    while (2 * prec <= destprec)
    {
        t = (a + b) / 2;
        b = sqrt (a * b);
        a = t;

        prec *= 2;
    }

    return (a + b) / 2;
}

// Raw logarithm via the AGM, applied to z as-is (no range reduction of z)
// Converges badly for really big z, but is faster when used directly on small numbers
apcomplex rawlog (apcomplex z)
{
    // Natural logarithm of nonzero z at z's precision, via the AGM identity
    //     log(z) = (pi / 2) * (1 / agm(1, e) - 1 / agm(1, e z))
    // which follows from pi / (2 agm(1, s)) ~ log(4 / s) for small s.
    // e is 1 with its exponent lowered by n, and z is scaled the same way.
    size_t destprec = z.prec ();
    long n = destprec / 2 + 2 * Basedigits;     // Rough estimate of the shift needed for e to be "small"
    apfloat e, agme;
    apcomplex agmez;

    assert (z.re.sign () || z.im.sign ());              // log(0) is infinite

    e = apfloat (1, destprec);
    e.exp (e.exp () - n);
    z.re.exp (z.re.exp () - n);
    z.im.exp (z.im.exp () - n);

    agme = agm (1, e);
    agmez = agm (apcomplex (1), z);

    checkpi (destprec);

    // (pi/2) * (1/agme - 1/agmez), written over a common denominator
    return Readypi * (agmez - agme) / (2 * agme * agmez);
}

// Calculate the log using 1 / Base <= |z| < 1 and the log addition formula
// because the agm converges badly for really big z
apcomplex log (apcomplex z)
{
    // log(z) = rawlog(z * Base^-s) + s * log(Base), where s is the exponent of
    // the larger nonzero part of z, so rawlog only sees a moderately sized argument.
    size_t destprec = z.prec ();
    long shift;

    if (!z.re.sign ())
        shift = z.im.exp ();
    else if (!z.im.sign ())
        shift = z.re.exp ();
    else
        shift = (z.im.exp () > z.re.exp ()) ? z.im.exp () : z.re.exp ();

    checklogconst (destprec);

    z.re.exp (z.re.exp () - shift);
    z.im.exp (z.im.exp () - shift);

    apfloat logbase = Logbase;
    logbase.prec (destprec + Basedigits);

    return rawlog (z) + shift * logbase;
}

// Exponent function, calculated using Newton's iteration for the inverse of log
apcomplex exp (apcomplex u)
{
    // Complex exponential at u's precision. exp(0) is 1.
    // A starting value is built separately from the real part (magnitude)
    // and the imaginary part (rotation). Newton's iteration on the inverse
    // of log,  z <- z + z * (u - log(z)),  then refines it while doubling
    // the working precision.
    size_t prec, minprec, maxprec, destprec = u.prec (), doubledigits, fprec;
    int k, f;
    apfloat x, y;
    apcomplex z;

    if (!u.re.sign () && !u.im.sign ()) return apcomplex (1);

    checklogconst (destprec);

    // Approximate accuracy of a double
    doubledigits = (size_t) (log (9e15) / log ((double) Basedigit));

    // Initial guess accuracy
    fprec = max (doubledigits, 2 * Basedigits);

    // First handle the real part
    if (u.re.exp () < -Basedigits)
    {
        // Taylor series: exp(x) = 1 + x + x^2/2 + ...
        // Only 1 + x is used; the neglected x^2/2 term is what limits the
        // accuracy to about -2 * exp digits (presumably the reason for prec below)

        x = u.re;
        x.prec (-u.re.exp () + 1);
        x += 1;
        prec = -2 * u.re.exp ();

        // Highly inefficient unless precision is 2^n * Basedigits
        // Round down to nearest power of two
        prec = Basedigits * (size_t) pow (2.0, floor (log ((double) prec / Basedigits) / log (2.0)));
        x.prec (prec);
    }
    else
    {
        // Approximate starting value for iteration
        // NOTE(review): this f shadows the outer int f for the rest of this block
        double d, i, f;

        // If u.re is too big, an overflow will occur (somewhere)
        d = ap2double (u.re.ap);
        i = floor (d);
        f = d - i;

        // e^i = Base^(i / ln Base): the integer part of that power goes into
        // the apfloat exponent, the fractional part into the double mantissa
        d = i / log ((double) Base);

        x = apfloat (exp (f) * pow ((double) Base, d - floor (d)), fprec);
        x.exp (x.exp () + Basedigits * (long) floor (d));
    }

    // Then handle the imaginary part
    if (u.im.exp () < -Basedigits)
    {
        // Taylor series: exp(z) = 1 + z + z^2/2 + ...

        y = u.im;
        y.prec (-u.im.exp () + 1);
        z = 1 + apcomplex (0, y);
        prec = -2 * u.im.exp ();

        // Highly inefficient unless precision is 2^n * Basedigits
        // Round down to nearest power of two
        prec = Basedigits * (size_t) pow (2.0, floor (log ((double) prec / Basedigits) / log (2.0)));
        z.re.prec (prec);
        z.im.prec (prec);
    }
    else
    {
        // Approximate starting value for iteration
        double d;
        long i;

        // If u.im is too big, the result will be totally inaccurate
        // i counts the whole turns (2 pi) to remove from u.im
        d = ap2double (u.im.ap) + M_PI;
        i = (long) floor (d / (2 * M_PI));
        d = fmod (d, 2 * M_PI);
        d -= M_PI;

        // Ensure that always -pi < u.im <= pi
        if (i)
        {
            // NOTE(review): Readypi is read here without a checkpi() call in
            // this function; presumably checklogconst() above has already
            // computed it to sufficient precision -- confirm.
            apfloat t = Readypi;

            t.prec (destprec + Basedigits);
            u.im -= i * 2 * t;
            d = ap2double (u.im.ap);
        }

        z = apcomplex (apfloat (cos (d), fprec), apfloat (sin (d), fprec));
    }

    // Combine magnitude and rotation into the starting value
    z *= x;

    prec = min (doubledigits, Basedigits);

    // Check if a factor of 3 should be used in length
    maxprec = rnd23up (destprec / Basedigits);
    if (maxprec != (maxprec & -maxprec))
        minprec = 3 * Basedigits;
    else
        minprec = Basedigits;

    // Highly inefficient unless precision is 2^n * Basedigits (or 3*2^n)
    if (prec < minprec)
    {
        z.re.prec (minprec + 3 * Basedigits);
        z.im.prec (minprec + 3 * Basedigits);
        while (prec < minprec)
        {
            z += z * (u - log (z));
            prec *= 2;
        }
        prec = minprec;
    }

    // Check where the precising iteration should be done
    // (k: number of doublings to reach destprec; f: countdown value of k at
    // which one extra full-precision iteration is done)
    for (k = 0, maxprec = prec; maxprec < destprec; k++, maxprec <<= 1);
    for (f = k, minprec = prec; f; f--, minprec <<= 1)
        if (minprec >= 3 * Basedigits && (minprec - 3 * Basedigits) << f >= destprec)
            break;

    // Newton's iteration
    while (k--)
    {
        apcomplex t;

        prec *= 2;
        // Complex log needs a bit extra precision for convergence
        z.re.prec (max (4 * Basedigits, min (prec, destprec)));
        z.im.prec (max (4 * Basedigits, min (prec, destprec)));

        t = u - log (z);
        // After the precising step the correction is truncated to half precision
        if (k < f)
        {
            t.re.prec (prec / 2);
            t.im.prec (prec / 2);
        }

        z += z * t;

        // Precising iteration
        if (k == f)
            z += z * (u - log (z));
    }

    z.re.prec (destprec);
    z.im.prec (destprec);

    return z;
}

// Arbitrary power, calculated using log and exp
apcomplex pow (apcomplex z, apcomplex w)
{
    // z^w = exp(w * log(z)), evaluated at the smaller of the two precisions
    size_t prec = min (z.prec (), w.prec ());

    checklogconst (prec);

    z.re.prec (prec);
    z.im.prec (prec);
    w.re.prec (prec);
    w.im.prec (prec);

    apcomplex exponent = w * log (z);
    return exp (exponent);
}

apcomplex pow (apcomplex z, apfloat y)
{
    // z^y = exp(y * log(z)), evaluated at the smaller of the two precisions
    size_t prec = min (z.prec (), y.prec ());

    checklogconst (prec);

    z.re.prec (prec);
    z.im.prec (prec);
    y.prec (prec);

    apcomplex exponent = y * log (z);
    return exp (exponent);
}

apcomplex pow (apfloat x, apcomplex w)
{
    // x^w = exp(w * log(x)), evaluated at the smaller of the two precisions
    size_t prec = min (x.prec (), w.prec ());

    checklogconst (prec);

    x.prec (prec);
    w.re.prec (prec);
    w.im.prec (prec);

    apcomplex exponent = w * log (x);
    return exp (exponent);
}


// Trigonometric and hyperbolic functions and their inverses

apcomplex acos (apcomplex z)
{
    // Complex arc cosine from log(z +- sqrt(z^2 - 1)); the sign of the
    // square root and of the result are chosen by the quadrant of z.
    apcomplex i = apcomplex (0, 1);
    apcomplex root = sqrt (z * z - 1);

    apcomplex w = (z.re.sign () >= 0) ? i * log (z + root)
                                      : -i * log (z - root);

    return (z.re.sign () * z.im.sign () >= 0) ? -w : w;
}

apcomplex acosh (apcomplex z)
{
    // Complex inverse hyperbolic cosine, log(z +- sqrt(z^2 - 1)).
    // For Re z < 0 the minus form is used; algebraically
    // log(z - sqrt(z^2 - 1)) = -log(z + sqrt(z^2 - 1)), so both forms
    // satisfy cosh(w) = z. (The unused local of the original is removed.)
    apcomplex root = sqrt (z * z - 1);

    if (z.re.sign () >= 0)
        return log (z + root);
    else
        return log (z - root);
}

apcomplex asin (apcomplex z)
{
    // asin(z) = -i log(iz + sqrt(1 - z^2)). Since
    // (sqrt(1 - z^2) - iz)(sqrt(1 - z^2) + iz) = 1, the upper-half-plane
    // branch uses the reciprocal form i log(sqrt(1 - z^2) - iz).
    apcomplex i = apcomplex (0, 1);
    apcomplex root = sqrt (1 - z * z);

    if (z.im.sign () < 0)
        return -i * log (i * z + root);

    return i * log (root - i * z);
}

apcomplex asinh (apcomplex z)
{
    // asinh(z) = log(z + sqrt(z^2 + 1)); for Re z < 0 the equivalent
    // -log(sqrt(z^2 + 1) - z) is used instead
    apcomplex root = sqrt (z * z + 1);

    if (z.re.sign () < 0)
        return -log (root - z);

    return log (root + z);
}

apcomplex atan (apcomplex z)
{
    // atan(z) = (i/2) * log((i + z) / (i - z))
    apcomplex i = apcomplex (0, 1);
    apcomplex ratio = (i + z) / (i - z);

    return log (ratio) * i / 2;
}

apcomplex atanh (apcomplex z)
{
    // atanh(z) = log((1 + z) / (1 - z)) / 2
    apcomplex ratio = (1 + z) / (1 - z);

    return log (ratio) / 2;
}

apcomplex cos (apcomplex z)
{
    // cos(z) = (e^(iz) + e^(-iz)) / 2
    apcomplex i = apcomplex (0, 1);
    apcomplex eiz = exp (i * z);
    apcomplex eminusiz = 1 / eiz;

    return (eiz + eminusiz) / 2;
}

apcomplex cosh (apcomplex z)
{
    // cosh(z) = (e^z + e^(-z)) / 2
    apcomplex ez = exp (z);
    apcomplex eminusz = 1 / ez;

    return (ez + eminusz) / 2;
}

apcomplex sin (apcomplex z)
{
    // sin(z) = i * (e^(-iz) - e^(iz)) / 2
    apcomplex i = apcomplex (0, 1);
    apcomplex eiz = exp (i * z);
    apcomplex eminusiz = 1 / eiz;

    return (eminusiz - eiz) * i / 2;
}

apcomplex sinh (apcomplex z)
{
    // sinh(z) = (e^z - e^(-z)) / 2
    apcomplex ez = exp (z);
    apcomplex eminusz = 1 / ez;

    return (ez - eminusz) / 2;
}

apcomplex tan (apcomplex z)
{
    // tan(z) = i * (1 - e^(2iz)) / (1 + e^(2iz))
    apcomplex i = apcomplex (0, 1);
    apcomplex e2iz = exp (2 * i * z);
    apcomplex numerator = i * (1 - e2iz);
    apcomplex denominator = 1 + e2iz;

    return numerator / denominator;
}

apcomplex tanh (apcomplex z)
{
    // tanh(z) = (e^(2z) - 1) / (e^(2z) + 1)
    apcomplex e2z = exp (2 * z);
    apcomplex numerator = e2z - 1;
    apcomplex denominator = e2z + 1;

    return numerator / denominator;
}


// Real trigonometric and hyperbolic functions and their inverses
// use complex functions

apfloat acos (apfloat x)
{
    // x + i sqrt(1 - x^2) lies on the unit circle at angle acos(x),
    // so its logarithm's imaginary part is the answer
    apcomplex i = apcomplex (0, 1);
    apcomplex unit = x + i * sqrt (1 - x * x);

    return imag (log (unit));
}

apfloat asin (apfloat x)
{
    // sqrt(1 - x^2) - i x lies on the unit circle at angle -asin(x)
    apcomplex i = apcomplex (0, 1);
    apcomplex unit = sqrt (1 - x * x) - i * x;

    return -imag (log (unit));
}

apfloat atan (apfloat x)
{
    // (i - x) / (i + x) = (1 + ix) / (1 - ix) has angle 2 atan(x)
    apcomplex i = apcomplex (0, 1);
    apcomplex ratio = (i - x) / (i + x);

    return imag (log (ratio)) / 2;
}

apfloat atan2 (apfloat x, apfloat y)
{
    // Argument (angle) of the complex number x + i y.
    // NOTE: x is the real part and y the imaginary part, i.e. the reverse of
    // the C library's atan2(y, x) parameter order. Both zero is asserted against.

    // Purely imaginary: +-pi/2
    if (!x.sign ())
    {
        assert (y.sign ());

        checkpi (y.prec ());

        apfloat pi = Readypi;
        pi.prec (y.prec ());

        return y.sign () * pi / 2;
    }

    // Purely real: 0 on the positive axis, pi on the negative axis
    if (!y.sign ())
    {
        if (x.sign () > 0)
            return 0;

        checkpi (x.prec ());

        apfloat pi = Readypi;
        pi.prec (x.prec ());

        return pi;
    }

    // Scale both parts by the same power so the larger one has exponent 0;
    // this does not change the angle
    long shift = (y.exp () > x.exp ()) ? y.exp () : x.exp ();

    x.exp (x.exp () - shift);
    y.exp (y.exp () - shift);

    return imag (rawlog (apcomplex (x, y)));
}

apfloat cos (apfloat x)
{
    // cos(x) = Re e^(ix)
    apcomplex eix = exp (apcomplex (0, x));
    return real (eix);
}

apfloat sin (apfloat x)
{
    // sin(x) = Im e^(ix)
    apcomplex eix = exp (apcomplex (0, x));
    return imag (eix);
}

apfloat tan (apfloat x)
{
    // tan(x) = Im e^(ix) / Re e^(ix)
    apcomplex eix = exp (apcomplex (0, x));
    apfloat sine = imag (eix);
    apfloat cosine = real (eix);

    return sine / cosine;
}
