#include "ap.h"


// Highly parallel three modular multiplications using the FPU
//
// Replaces a1, a2 and a3 with values derived from the products a1 * b1,
// a2 * b2 and a3 * b3; b1, b2 and b3 are only read. The three computations
// are interleaved instruction by instruction (the fxch instructions rotate
// which of the three values is on top of the x87 stack).
//
// What the asm does, in order:
//   1. Pushes a1, b1, a2, b2, a3, b3 on the CPU stack as 32-bit words and
//      loads them with fildl, multiplying pairwise so that the x87 stack
//      holds the three products. Three of the six stack slots are released
//      (addl $12) and the remaining three are reused for the results.
//   2. Duplicates each product and multiplies the copy by a value that lies
//      deeper on the x87 stack than anything loaded here (%st(4)..%st(6) at
//      the point of use). That value is NOT loaded by this function, so the
//      caller must already have it on the FPU stack -- presumably the
//      reciprocal of the current modulus; TODO confirm against setmodulus().
//   3. Adds and then subtracts the single-precision constant chopper64
//      (presumably a rounding trick that leaves an integer quotient; confirm
//      against its definition).
//   4. Multiplies each quotient by the double dmodulus and subtracts the
//      result from the corresponding product.
//   5. Stores the three differences as 32-bit integers (fistpl) and pops
//      them into a1, a2 and a3.
// The function pushes and pops the same number of x87 stack entries, so the
// caller-supplied value from step 2 is left in place.
//
// chopper64 and dmodulus are referenced by symbol name from inside the asm
// text; they are not passed as operands.
//
// NOTE(review): the asm moves %esp with pushl/addl/popl while the operands
// are allowed to be memory ("rm"). If the compiler ever addresses such an
// operand relative to %esp, the pushes and pops would use the wrong
// location -- verify with the compiler and flags actually in use.
// NOTE(review): the asm template is a string literal that spans source lines
// without continuation characters, which appears to rely on an old GCC
// extension; confirm the targeted compiler accepts it.
inline void modmul3 (modint &a1, modint &b1, modint &a2, modint &b2, modint &a3, modint &b3)
{
    // Operand map: %0/%3 = a1, %1/%5 = a2, %2/%7 = a3 (outputs tied to inputs),
    //              %4 = b1, %6 = b2, %8 = b3. Only the condition codes are
    //              declared as clobbered.
    asm ("pushl %3; pushl %4;
          pushl %5; pushl %6;
          pushl %7; pushl %8;
          fildl 20(%%esp); fildl 16(%%esp);
          fildl 12(%%esp); fxch %%st(2);
          fmulp %%st, %%st(1); fxch %%st(1);
          fildl 8(%%esp); fildl 4(%%esp); fxch %%st(2);
          fmulp %%st, %%st(1); fxch %%st(1);
          fildl (%%esp); addl $12, %%esp;
          fmulp %%st, %%st(1);
          fld %%st(2); fmul %%st(4), %%st;
          fld %%st(2); fmul %%st(5), %%st;
          fld %%st(2); fmul %%st(6), %%st; fxch %%st(2);
          fadds chopper64; fxch %%st(1);
          fadds chopper64; fxch %%st(2);
          fadds chopper64; fxch %%st(1);
          fsubs chopper64; fxch %%st(2);
          fsubs chopper64; fxch %%st(1);
          fsubs chopper64; fxch %%st(2);
          fmull dmodulus; fxch %%st(1);
          fmull dmodulus; fxch %%st(1);
          fsubrp %%st, %%st(5); fxch %%st(1);
          fmull dmodulus; fxch %%st(1);
          fsubrp %%st, %%st(3);
          fsubrp %%st, %%st(1); fxch %%st(2);
          fistpl (%%esp);
          fistpl 4(%%esp);
          fistpl 8(%%esp);
          popl %0;
          popl %1;
          popl %2"
                            : "=rm" (a1), "=rm" (a2), "=rm" (a3)
                            : "0" (a1), "rm" (b1), "1" (a2), "rm" (b2), "2" (a3), "rm" (b3)
                            : "cc");
}

// Highly parallel three modular squares using the FPU
//
// Replaces a1, a2 and a3 with values derived from their own squares. The
// instruction sequence after the squaring step is the same as the tail of
// modmul3 in this file.
//
// What the asm does, in order:
//   1. Pushes a1, a2, a3 on the CPU stack as 32-bit words, loads them with
//      fildl and squares each one in place on the x87 stack (fmul %st).
//   2. Duplicates each square and multiplies the copy by a value that lies
//      deeper on the x87 stack than anything loaded here. That value is NOT
//      loaded by this function, so the caller must already have it on the
//      FPU stack -- presumably the reciprocal of the current modulus;
//      TODO confirm against setmodulus().
//   3. Adds and then subtracts the single-precision constant chopper64
//      (presumably a rounding trick that leaves an integer quotient; confirm
//      against its definition).
//   4. Multiplies each quotient by the double dmodulus and subtracts the
//      result from the corresponding square.
//   5. Stores the three differences as 32-bit integers (fistpl) into the
//      three pushed stack slots and pops them into a1, a2 and a3.
//
// The "fxch %%st(1)" on the second asm line has no trailing semicolon; the
// line break inside the string literal is what separates it from the next
// instruction.
//
// NOTE(review): the asm moves %esp with pushl/popl while the operands are
// allowed to be memory ("rm"). If the compiler ever addresses such an
// operand relative to %esp, the pushes and pops would use the wrong
// location -- verify with the compiler and flags actually in use.
// NOTE(review): the asm template is a string literal that spans source lines
// without continuation characters, which appears to rely on an old GCC
// extension; confirm the targeted compiler accepts it.
inline void modsqr3 (modint &a1, modint &a2, modint &a3)
{
    // Operand map: %0/%3 = a1, %1/%4 = a2, %2/%5 = a3 (outputs tied to inputs).
    // Only the condition codes are declared as clobbered.
    asm ("pushl %3; pushl %4; pushl %5;
          fildl 8(%%esp); fildl 4(%%esp); fxch %%st(1)
          fmul %%st; fxch %%st(1);
          fildl (%%esp); fxch %%st(1);
          fmul %%st; fxch %%st(1);
          fmul %%st;
          fld %%st(2); fmul %%st(4), %%st;
          fld %%st(2); fmul %%st(5), %%st;
          fld %%st(2); fmul %%st(6), %%st; fxch %%st(2);
          fadds chopper64; fxch %%st(1);
          fadds chopper64; fxch %%st(2);
          fadds chopper64; fxch %%st(1);
          fsubs chopper64; fxch %%st(2);
          fsubs chopper64; fxch %%st(1);
          fsubs chopper64; fxch %%st(2);
          fmull dmodulus; fxch %%st(1);
          fmull dmodulus; fxch %%st(1);
          fsubrp %%st, %%st(5); fxch %%st(1);
          fmull dmodulus; fxch %%st(1);
          fsubrp %%st, %%st(3);
          fsubrp %%st, %%st(1); fxch %%st(2);
          fistpl (%%esp);
          fistpl 4(%%esp);
          fistpl 8(%%esp);
          popl %0;
          popl %1;
          popl %2"
                            : "=rm" (a1), "=rm" (a2), "=rm" (a3)
                            : "0" (a1), "1" (a2), "2" (a3)
                            : "cc");
}


// Linear multiplication in the number theoretic domain
// Assume that ds1 will be in memory if possible
//
// Multiplies the ds1->size elements of ds1 element by element with the
// corresponding elements of s2 and leaves the products in ds1.
// The data is processed in blocks of at most Blocksize elements; within a
// block, elements are handled three at a time with modmul3 and any one or
// two leftover elements with the modint *= operator.
void multiplyinplace (apstruct *ds1, apstruct *s2)
{
    size_t t, p = ds1->size, l, r = 0;
    modint *buf1, *buf2;

    while (p)
    {
        l = (p < Blocksize ? p : Blocksize);
        p -= l;
        buf1 = ds1->getdata (r, l);
        buf2 = s2->getdata (r, l);
        r += l;

        // The bound is written as "t + 2 < l" on purpose: l is unsigned, so
        // "l - 2" would wrap to a huge value for a one-element block and the
        // triple loop would run far past the end of both buffers.
        for (t = 0; t + 2 < l; t += 3)
            modmul3 (buf1[t], buf2[t], buf1[t + 1], buf2[t + 1], buf1[t + 2], buf2[t + 2]);

        // Remaining zero, one or two elements of the block
        for (; t < l; t++)
            buf1[t] *= buf2[t];

        // ds1 was modified and s2 was only read, hence putdata vs. cleardata
        ds1->putdata ();
        s2->cleardata ();
    }
}

// Linear squaring in the number theoretic domain
// Assume that ds will be in memory if possible
//
// Squares each of the ds->size elements of ds in place.
// The data is processed in blocks of at most Maxblocksize elements; within a
// block, elements are handled three at a time with modsqr3 and any one or
// two leftover elements with the modint *= operator.
void squareinplace (apstruct *ds)
{
    size_t t, p = ds->size, l, r = 0;
    modint *buf;

    while (p)
    {
        l = (p < Maxblocksize ? p : Maxblocksize);
        p -= l;
        buf = ds->getdata (r, l);
        r += l;

        // The bound is written as "t + 2 < l" on purpose: l is unsigned, so
        // "l - 2" would wrap to a huge value for a one-element block and the
        // triple loop would run far past the end of the buffer.
        for (t = 0; t + 2 < l; t += 3)
            modsqr3 (buf[t], buf[t + 1], buf[t + 2]);

        // Remaining zero, one or two elements of the block
        for (; t < l; t++)
            buf[t] *= buf[t];

        ds->putdata ();
    }
}

// Convolution of the mantissas in s1 and s2
// Use sizes s1size and s2size correspondingly
// Result size is rsize
// Transform length is n
// *i = 1 if right shift occurred, otherwise 0
//
// The same pass is run once for each of the three moduli, in the order
// index 2, 1, 0: transform a copy of s2, transform a copy of s1, multiply
// them element by element, inverse transform the product. The three
// per-modulus products are then handed to carrycrt, and the object built
// for modulus index 0 is returned; the caller owns it.
apstruct *convolution (apstruct *s1, apstruct *s2, size_t rsize, size_t s1size, size_t s2size, size_t n, int *i)
{
    int location = (n > Maxblocksize ? DISK : MEMORY);
    apstruct *residue[3];                           // residue[m] is the product for moduli[m]
    int m;

    if (MAXTRANSFORMLENGTH)
        assert (n <= MAXTRANSFORMLENGTH);           // Otherwise it won't work

    for (m = 2; m >= 0; m--)
    {
        setmodulus (moduli[m]);

        // Forward transform of the second operand
        apstruct *second = new apstruct (*s2, s2size, location, n);

        if (location == MEMORY)
        {
            modint *data = second->getdata (0, n);
            tablesixstepfnttrans (data, primitiveroots[m], 1, n);
            second->putdata ();
            second->relocate (DEFAULT);
        }
        else
        {
            fstream &fs = second->openstream ();
            tabletwopassfnttrans (fs, primitiveroots[m], 1, n);
            second->closestream ();
        }

        // Forward transform of the first operand
        apstruct *first = new apstruct (*s1, s1size, location, n);

        if (location == MEMORY)
        {
            modint *data = first->getdata (0, n);
            tablesixstepfnttrans (data, primitiveroots[m], 1, n);
            first->putdata ();
        }
        else
        {
            fstream &fs = first->openstream ();
            tabletwopassfnttrans (fs, primitiveroots[m], 1, n);
            first->closestream ();
        }

        // Element-wise product ends up in first; second is no longer needed
        multiplyinplace (first, second);

        delete second;

        // Inverse transform of the product
        if (location == MEMORY)
        {
            modint *data = first->getdata (0, n);
            itablesixstepfnttrans (data, primitiveroots[m], -1, n);
            first->putdata ();

            // The last product (m == 0) is the return value and can remain in memory
            if (m != 0)
                first->relocate (DEFAULT);
        }
        else
        {
            fstream &fs = first->openstream ();
            itabletwopassfnttrans (fs, primitiveroots[m], -1, n);
            first->closestream ();
        }

        residue[m] = first;
    }

    *i = carrycrt (residue[0], residue[1], residue[2], rsize);

    delete residue[2];
    delete residue[1];

    return residue[0];
}

// Autoconvolution of the mantissa in s
// Use size ssize for s
// Result size is rsize
// Transform length is n
// *i = 1 if right shift occurred, otherwise 0
//
// The same pass is run once for each of the three moduli, in the order
// index 2, 1, 0: transform a copy of s, square it element by element,
// inverse transform the result. The three per-modulus squares are then
// handed to carrycrt, and the object built for modulus index 0 is returned;
// the caller owns it.
apstruct *autoconvolution (apstruct *s, size_t rsize, size_t ssize, size_t n, int *i)
{
    int location = (n > Maxblocksize ? DISK : MEMORY);
    apstruct *residue[3];                           // residue[m] is the square for moduli[m]
    int m;

    if (MAXTRANSFORMLENGTH)
        assert (n <= MAXTRANSFORMLENGTH);           // Otherwise it won't work

    for (m = 2; m >= 0; m--)
    {
        setmodulus (moduli[m]);

        apstruct *work = new apstruct (*s, ssize, location, n);

        // Forward transform
        if (location == MEMORY)
        {
            modint *data = work->getdata (0, n);
            tablesixstepfnttrans (data, primitiveroots[m], 1, n);
            work->putdata ();
        }
        else
        {
            fstream &fs = work->openstream ();
            tabletwopassfnttrans (fs, primitiveroots[m], 1, n);
            work->closestream ();
        }

        squareinplace (work);

        // Inverse transform
        if (location == MEMORY)
        {
            modint *data = work->getdata (0, n);
            itablesixstepfnttrans (data, primitiveroots[m], -1, n);
            work->putdata ();

            // The last square (m == 0) is the return value and can remain in memory
            if (m != 0)
                work->relocate (DEFAULT);
        }
        else
        {
            fstream &fs = work->openstream ();
            itabletwopassfnttrans (fs, primitiveroots[m], -1, n);
            work->closestream ();
        }

        residue[m] = work;
    }

    *i = carrycrt (residue[0], residue[1], residue[2], rsize);

    delete residue[2];
    delete residue[1];

    return residue[0];
}
