#include "ap.h"

// File-scope state shared between carrycrt() and the external crtblock()
// routine. carrycrt() fills these in before it starts calling crtblock().

// CRT coefficients: t<i> is the inverse, modulo moduli[i], of the product of
// the other two moduli.
modint t0, t1, t2;

// Multi-word products of the moduli (m01 = moduli[0] * moduli[1], etc.),
// mm = product of all three moduli, and cc = the running carry, which
// carrycrt() zeroes before the block loop and inspects (cc[0]) afterwards.
rawtype m01[2], m02[2], m12[2], mm[3], cc[3];

// Must be a double: the constant is tied to the 53-bit double mantissa and
// sits alongside the other double-precision globals below. Declaring it as
// float would give the symbol the wrong size and representation for any code
// that accesses it as a double.
double chopper53 = 4503599627370496.0;                   // 2^52

// Not referenced by the C++ code in this file; presumably scratch space for
// crtblock() -- TODO confirm against its implementation.
double dtmp0, dtmp1, dtmp2;

// The three moduli converted to double, and their reciprocals.
double dmodulus0, dmodulus1, dmodulus2;
double imodulus0, imodulus1, imodulus2;

// Block worker with C linkage, defined outside this file. carrycrt() calls it
// once per block with the block length l and the three data buffers for that
// block. It presumably combines the three residues in place and updates the
// global carry cc[] -- TODO confirm against the implementation.
extern "C" void crtblock (size_t l, modint *buf1, modint *buf2, modint *buf3);

// Carry propagation & Chinese Remainder Theorem step for fnt-multiplication
// (and squaring).
//
// Prepares the global CRT constants (t0..t2, m01, m02, m12, mm and the
// double-precision copies of the moduli), then runs crtblock() over the
// first rsize words of ds1, s2 and s3, one block at a time, starting from
// the highest positions and working down to position 0.
//
// If a carry word remains afterwards (cc[0] != 0), the whole of ds1 is
// shifted by one word towards higher positions, the carry word is stored at
// position 0 and the word shifted out of the last position is discarded.
//
// Returns 1 if that right shift occurred, 0 otherwise.
// Assume that ds1 will be in memory if possible

int carrycrt (apstruct *ds1, apstruct *s2, apstruct *s3, size_t rsize)  // Low to high
{
    // Start with an empty carry
    cc[2] = 0;
    cc[1] = 0;
    cc[0] = 0;

    // Each modulus as a double, together with its reciprocal
    dmodulus0 = static_cast<double> (moduli[0]);
    imodulus0 = 1.0 / static_cast<double> (moduli[0]);

    dmodulus1 = static_cast<double> (moduli[1]);
    imodulus1 = 1.0 / static_cast<double> (moduli[1]);

    dmodulus2 = static_cast<double> (moduli[2]);
    imodulus2 = 1.0 / static_cast<double> (moduli[2]);

    // t0 = 1 / (moduli[1] * moduli[2])  modulo moduli[0]
    setmodulus (moduli[0]);
    t0 = modint (1) / (modint (moduli[1]) * moduli[2]);

    // t1 = 1 / (moduli[0] * moduli[2])  modulo moduli[1]
    // moduli[0] exceeds the current modulus, so reduce it by hand first
    setmodulus (moduli[1]);
    rawtype reduced0 = moduli[0];
    while (reduced0 >= modint::modulus)
    {
        reduced0 -= modint::modulus;
    }
    t1 = modint (1) / (modint (reduced0) * moduli[2]);

    // t2 = 1 / (moduli[0] * moduli[1])  modulo moduli[2]
    // Both of the other moduli exceed the current modulus here
    setmodulus (moduli[2]);
    reduced0 = moduli[0];
    while (reduced0 >= modint::modulus)
    {
        reduced0 -= modint::modulus;
    }
    rawtype reduced1 = moduli[1];
    while (reduced1 >= modint::modulus)
    {
        reduced1 -= modint::modulus;
    }
    t2 = modint (1) / (modint (reduced0) * reduced1);

    // Pairwise products of the moduli: seed the low word, then multiply in
    // place and keep the word returned by bigmul() as the high word
    m01[0] = moduli[0];
    m02[0] = moduli[0];
    m12[0] = moduli[1];

    m01[1] = bigmul (m01, m01, moduli[1], 1);
    m02[1] = bigmul (m02, m02, moduli[2], 1);
    m12[1] = bigmul (m12, m12, moduli[2], 1);

    // Product of all three moduli
    mm[2] = bigmul (mm, m01, moduli[2], 2);

    // Run the CRT over the data, one block at a time, from the top down
    size_t remaining = rsize;
    while (remaining != 0)
    {
        const size_t count = (remaining >= Blocksize ? Blocksize : remaining);
        remaining -= count;

        modint *block1 = ds1->getdata (remaining, count);
        modint *block2 = s2->getdata (remaining, count);
        modint *block3 = s3->getdata (remaining, count);

        crtblock (count, block1, block2, block3);

        // Only ds1 receives the result; the other two blocks are released
        s3->cleardata ();
        s2->cleardata ();
        ds1->putdata ();
    }

    const rawtype carry = cc[0];

    if (carry == 0)
    {
        return 0;
    }

    // A carry word is left over: push it in at position 0 and move every
    // existing word of ds1 one position up
    rawtype incoming = carry;
    size_t offset = 0;
    size_t left = ds1->size;

    while (left != 0)
    {
        const size_t chunk = (left >= Maxblocksize ? Maxblocksize : left);
        modint *words = ds1->getdata (offset, chunk);

        for (modint *w = words; w != words + chunk; ++w)
        {
            rawtype displaced = *w;
            *w = incoming;
            incoming = displaced;
        }

        ds1->putdata ();

        left -= chunk;
        offset += chunk;
    }

    return 1;
}
