#include "ap.h"

// File-scope state shared between carrycrt () and the crtblock assembly
// routine. These objects must keep external linkage and exactly these
// names: the assembly addresses them by symbol (_t0, _m12+4, _cs+8, ...).

// CRT multipliers: tk is the inverse, modulo moduli[k], of the product of
// the other two moduli (computed in carrycrt).
modint t0;
modint t1;
modint t2;

// Two-word pairwise products of the moduli, low word first:
// m01 = moduli[0] * moduli[1], m02 = moduli[0] * moduli[2],
// m12 = moduli[1] * moduli[2].
rawtype m01[2];
rawtype m02[2];
rawtype m12[2];

// mm: three-word product moduli[0] * moduli[1] * moduli[2] (low word first).
// cs: three-word scratch accumulator written and read by crtblock.
// cc: running carry; cleared by carrycrt, updated by crtblock per element.
rawtype mm[3];
rawtype cs[3];
rawtype cc[3];

// Implemented in assembly below: combines the residues buf1[i], buf2[i],
// buf3[i] for i = l-1 .. 0, propagating cc, and writes the digits to buf1.
extern "C" void crtblock (size_t l, modint *buf1, modint *buf2, modint *buf3);

// crtblock (l, buf1, buf2, buf3): i386 AT&T-syntax routine. Arguments are
// read from the stack; external symbols carry a leading underscore
// (_crtblock, _t0, _moduli, _Base, ...).
//
// For i = l-1 down to 0 it combines the residues buf1[i], buf2[i], buf3[i]
// into one three-word value x with the Chinese Remainder Theorem:
//
//     y0 = t0 * buf1[i] mod moduli[0]
//     y1 = t1 * buf2[i] mod moduli[1]
//     y2 = t2 * buf3[i] mod moduli[2]
//     x  = (m12 * y0 + m02 * y1 + m01 * y2) mod mm
//
// It then adds the running carry cc, stores (x + cc) mod Base into buf1[i]
// and leaves (x + cc) / Base in cc for the next element.
//
// Each yk is a remainder of a division by moduli[k], so every product term
// is below mm = moduli[0] * moduli[1] * moduli[2]. Each partial sum is
// therefore below 2 * mm, and one conditional subtraction of mm reduces it.
//
// All general registers except %eax are saved and restored. Flags and the
// globals cs and cc are modified.
//
// The assembly is written as adjacent string literals rather than one
// multi-line literal, which GCC 3.3 and later reject. __asm__ is used
// because plain asm is not a keyword under strict -std modes.
__asm__ (
    "    .globl _crtblock\n"
    "\n"
    ".align 4\n"
    "_crtblock:\n"
    "    pushl %ebx\n"
    "    pushl %ecx\n"
    "    pushl %edx\n"
    "    pushl %esi\n"
    "    pushl %edi\n"
    "    pushl %ebp\n"
    "\n"
    // Six saved registers + return address = 28 bytes, so:
    // 28(%esp) = l, 32(%esp) = buf1, 36(%esp) = buf2, 40(%esp) = buf3.
    "    movl 28(%esp), %ecx\n"
    "    andl %ecx, %ecx\n"
    "    jz crtblockend\n"
    "\n"
    "    crtloop:\n"
    "\n"
    // Walk the block from index l-1 down to index 0.
    "    decl %ecx\n"
    // y0 = t0 * buf1[i] mod moduli[0], kept in %ebx.
    "    movl 32(%esp), %ebx\n"
    "    movl _t0, %eax\n"
    "    mull (%ebx, %ecx, 4)\n"
    "    divl _moduli\n"
    "    movl %edx, %ebx\n"
    "\n"
    // cs = m12 * y0 (two-word by one-word product, three-word result).
    "    movl _m12, %eax\n"
    "    mull %ebx\n"
    "    movl %eax, _cs\n"
    "    movl %edx, %esi\n"
    "    movl _m12+4, %eax\n"
    "    mull %ebx\n"
    "    add %esi, %eax\n"
    "    adc $0, %edx\n"
    "    movl %eax, _cs+4\n"
    "    movl %edx, _cs+8\n"
    "\n"
    // y1 = t1 * buf2[i] mod moduli[1].
    "    movl 36(%esp), %ebx\n"
    "    movl _t1, %eax\n"
    "    mull (%ebx, %ecx, 4)\n"
    "    divl _moduli+4\n"
    "    movl %edx, %ebx\n"
    "\n"
    // %edx:%eax:%edi = m02 * y1.
    "    movl _m02, %eax\n"
    "    mull %ebx\n"
    "    movl %eax, %edi\n"
    "    movl %edx, %esi\n"
    "    movl _m02+4, %eax\n"
    "    mull %ebx\n"
    "    add %esi, %eax\n"
    "    adc $0, %edx\n"
    "\n"
    // %edi:%ebp:%ebx = cs + m02 * y1;  %edx:%eax:%esi = mm.
    "    movl _cs, %ebx\n"
    "    movl _cs+4, %ebp\n"
    "    addl %edi, %ebx\n"
    "    adcl %eax, %ebp\n"
    "    movl _cs+8, %edi\n"
    "    adcl %edx, %edi\n"
    "    movl _mm, %esi\n"
    "    movl _mm+4, %eax\n"
    "    movl _mm+8, %edx\n"
    "\n"
    // Three-word unsigned compare against mm, most significant word first.
    // Only a strictly greater (JA) or strictly smaller (JB) word decides.
    // Equal words fall through to the next lower word. With JAE the lower
    // compares would be unreachable, and mm would be subtracted whenever the
    // high words are equal, even if the sum is already below mm.
    "    cmpl %edx, %edi\n"
    "    ja crtsub1\n"
    "    movl %ebx, _cs\n"          // low word for the no-subtract path; MOV keeps flags
    "    jb crtnosub1\n"
    "    cmpl %eax, %ebp\n"
    "    ja crtsub1\n"
    "    jb crtnosub1\n"
    "    cmpl %esi, %ebx\n"
    "    jb crtnosub1\n"
    "\n"
    "    crtsub1:\n"
    // cs = sum - mm.
    "    subl %esi, %ebx\n"
    "    sbbl %eax, %ebp\n"
    "    movl %ebx, _cs\n"
    "    sbbl %edx, %edi\n"
    "\n"
    "    crtnosub1:\n"
    "    movl %ebp, _cs+4\n"
    "    movl %edi, _cs+8\n"
    "\n"
    // y2 = t2 * buf3[i] mod moduli[2].
    "    movl 40(%esp), %ebx\n"
    "    movl _t2, %eax\n"
    "    mull (%ebx, %ecx, 4)\n"
    "    divl _moduli+8\n"
    "    movl %edx, %ebx\n"
    "\n"
    // %edx:%eax:%edi = m01 * y2.
    "    movl _m01, %eax\n"
    "    mull %ebx\n"
    "    movl %eax, %edi\n"
    "    movl %edx, %esi\n"
    "    movl _m01+4, %eax\n"
    "    mull %ebx\n"
    "    add %esi, %eax\n"
    "    adc $0, %edx\n"
    "\n"
    // %edi:%ebp:%ebx = cs + m01 * y2;  %edx:%eax:%esi = mm.
    "    movl _cs, %ebx\n"
    "    movl _cs+4, %ebp\n"
    "    addl %edi, %ebx\n"
    "    adcl %eax, %ebp\n"
    "    movl _cs+8, %edi\n"
    "    adcl %edx, %edi\n"
    "    movl _mm, %esi\n"
    "    movl _mm+4, %eax\n"
    "    movl _mm+8, %edx\n"
    "\n"
    // Same three-word compare as above. The reduced x stays in
    // %edi:%ebp:%ebx.
    "    cmpl %edx, %edi\n"
    "    ja crtsub2\n"
    "    jb crtnosub2\n"
    "    cmpl %eax, %ebp\n"
    "    ja crtsub2\n"
    "    jb crtnosub2\n"
    "    cmpl %esi, %ebx\n"
    "    jb crtnosub2\n"
    "\n"
    "    crtsub2:\n"
    "    subl %esi, %ebx\n"
    "    sbbl %eax, %ebp\n"
    "    sbbl %edx, %edi\n"
    "\n"
    "    crtnosub2:\n"
    "\n"
    // %edx:%eax:%esi = x + cc.
    "    movl _cc, %esi\n"
    "    movl _cc+4, %eax\n"
    "    addl %ebx, %esi\n"
    "    adcl %ebp, %eax\n"
    "    movl _cc+8, %edx\n"
    "    adcl %edi, %edx\n"
    "\n"
    // Divide by Base in two steps:
    //  1. The high two words are divided first. The quotient goes to cc[1]
    //     and cc[2] is cleared.
    //  2. The remainder and the low word are divided next. The quotient goes
    //     to cc[0]; the remainder is the output digit stored in buf1[i].
    // DIVL faults unless the top word is below Base. This is assumed to hold
    // for these operands (TODO confirm).
    "    divl _Base\n"
    "    movl $0, _cc+8\n"
    "    movl %eax, _cc+4\n"
    "    movl %esi, %eax\n"
    "    movl 32(%esp), %ebx\n"
    "    divl _Base\n"
    "    movl %eax, _cc\n"
    "    movl %edx, (%ebx, %ecx, 4)\n"
    "\n"
    "    andl %ecx, %ecx\n"
    "    jnz crtloop\n"
    "\n"
    "    crtblockend:\n"
    "\n"
    "    popl %ebp\n"
    "    popl %edi\n"
    "    popl %esi\n"
    "    popl %edx\n"
    "    popl %ecx\n"
    "    popl %ebx\n"
    "\n"
    "    ret\n"
);


// Carry & Chinese Remainder Theorem for fnt-multiplication (and squaring)
// Returns 1 if a right shift occurred (a final carry was inserted), else 0
// Assumes that ds1 will be kept in memory if possible

int carrycrt (apstruct *ds1, apstruct *s2, apstruct *s3, size_t rsize)  // Low to high
{
    // Start with no carry; crtblock () updates cc element by element.
    cc[2] = 0;
    cc[1] = 0;
    cc[0] = 0;

    // CRT multipliers: tk is the inverse, modulo moduli[k], of the product
    // of the other two moduli. setmodulus () selects modint::modulus for
    // the modint arithmetic that follows, so it is set before each tk.
    setmodulus (moduli[0]);
    t0 = modint (1) / (modint (moduli[1]) * moduli[2]);

    // moduli[0] is larger than moduli[1], so it is reduced into range
    // before it is turned into a modint.
    setmodulus (moduli[1]);
    rawtype red0 = moduli[0];
    while (red0 >= modint::modulus)
    {
        red0 -= modint::modulus;
    }
    t1 = modint (1) / (modint (red0) * moduli[2]);

    // Both moduli[0] and moduli[1] are larger than moduli[2]: reduce both.
    setmodulus (moduli[2]);
    red0 = moduli[0];
    while (red0 >= modint::modulus)
    {
        red0 -= modint::modulus;
    }
    rawtype red1 = moduli[1];
    while (red1 >= modint::modulus)
    {
        red1 -= modint::modulus;
    }
    t2 = modint (1) / (modint (red0) * red1);

    // Products of the moduli, low word first, built with bigmul ():
    // two-word pairwise products and the three-word product mm.
    m01[0] = moduli[0];
    m01[1] = bigmul (m01, m01, moduli[1], 1);
    m02[0] = moduli[0];
    m02[1] = bigmul (m02, m02, moduli[2], 1);
    m12[0] = moduli[1];
    m12[1] = bigmul (m12, m12, moduli[2], 1);
    mm[2] = bigmul (mm, m01, moduli[2], 2);

    // Walk the rsize elements from the end towards index 0, at most
    // Blocksize elements at a time. crtblock () writes its output digits
    // into the ds1 buffer only, so only ds1 is written back.
    for (size_t remaining = rsize; remaining != 0; )
    {
        size_t n = (remaining < Blocksize ? remaining : Blocksize);
        remaining -= n;

        modint *d1 = ds1->getdata (remaining, n);
        modint *d2 = s2->getdata (remaining, n);
        modint *d3 = s3->getdata (remaining, n);

        crtblock (n, d1, d2, d3);

        // Presumably cleardata () releases without write-back and
        // putdata () stores the modified data; verify in apstruct.
        s3->cleardata ();
        s2->cleardata ();
        ds1->putdata ();
    }

    // Only the low carry word cc[0] is examined.
    rawtype carry = cc[0];

    if (carry == 0)
    {
        return 0;
    }

    // Insert the carry at index 0. Every element of ds1 moves one position
    // up, and the element at the last index is dropped. The data are
    // handled in chunks of at most Maxblocksize elements.
    rawtype moving = carry;
    size_t pos = 0;
    for (size_t left = ds1->size; left != 0; )
    {
        size_t n = (left < Maxblocksize ? left : Maxblocksize);
        left -= n;

        modint *d = ds1->getdata (pos, n);
        pos += n;

        for (size_t i = 0; i < n; i++)
        {
            rawtype next = d[i];
            d[i] = moving;
            moving = next;
        }

        ds1->putdata ();
    }

    return 1;
}
