#include <windows.h>
#include <string.h>

static const DWORD WIN_INFINITE = INFINITE;        // Doubly defined in windows.h and apfloat.h
#undef INFINITE

#include "ap.h"


// Maximum transform size to calculate without parallelizing
// Note that there is a severe overhead for starting new threads
// The transforms in this file compare their total length nn against this
// value (nn <= USE_SINGLE_FNT) and then run the sub-transforms in a plain loop
const size_t USE_SINGLE_FNT = 1024;

// Highly parallel three modular multiplications using the FPU
// Defined outside this file. Judging from the commented-out reference code
// kept next to the calls below, it performs *a1 *= *b1, *a2 *= *b2 and
// *a3 *= *b3; callers pass operands that alias each other (e.g. a1 == b2),
// so presumably every input is read before any result is stored - TODO confirm
extern "C" void modmul3 (modint *a1, modint *b1, modint *a2, modint *b2, modint *a3, modint *b3);

// Partially parallel two modular multiplications using the FPU
// Two-product counterpart of modmul3 (*a1 *= *b1, *a2 *= *b2), with the
// same presumed read-before-write behaviour - TODO confirm
extern "C" void modmul2 (modint *a1, modint *b1, modint *a2, modint *b2);

// Signature of the transform routines that can be dispatched to threads
// (tablefnt and itablefnt are the ones passed in this file)
typedef void (*fntfunc) (modint data[], modint wtable[], size_t ptable[], size_t nn, int s);

// Work description handed to fntthread
typedef struct
{
    fntfunc func;       // Transform routine called once per block
    size_t count;       // Number of consecutive blocks to transform
    modint *data;       // First block; following blocks are nn elements apart
    modint *wtable;     // Passed unchanged to func as its wtable argument
    size_t *ptable;     // Passed unchanged to func as its ptable argument
    size_t nn;          // Length of one block, also passed to func
    int s;              // Passed unchanged to func as its s argument
} fntargs;

// Thread entry point: transforms args.count consecutive blocks of args.nn
// elements, starting at args.data, with the routine and tables described by
// the fntargs structure that x points to. Always returns 0.
//
// Declared WINAPI because it is handed to CreateThread, which requires the
// LPTHREAD_START_ROUTINE calling convention (__stdcall on 32-bit x86).
// Casting a function with a different convention only hides the mismatch.
// The function can still be called directly like any other function.
DWORD WINAPI fntthread (LPVOID x)
{
    size_t t;
    fntargs args;

    // Work on a private copy so that advancing args.data below
    // does not modify the structure owned by the caller
    memcpy (&args, x, sizeof (args));

    setmodulus (modint::modulus);   // Needed to initialize the FPU in this thread

    for (t = 0; t < args.count; t++, args.data += args.nn)
        args.func (args.data, args.wtable, args.ptable, args.nn, args.s);

    return 0;
}

// Bind a thread to a single processor, or release the binding.
//
// hThread  thread whose affinity mask is changed
// proc     processor index, counted upwards from the lowest processor in the
//          affinity mask of the current process, or (size_t) DEFAULT to let
//          the thread run on every processor available to the process
//
// The affinity is left untouched when the process mask cannot be queried,
// when it is empty, or when the requested processor is not in the mask.
void setthreadaffinity (HANDLE hThread, size_t proc)
{
    HANDLE hProcess;
    // Affinity masks are pointer-sized (DWORD_PTR); a DWORD is too narrow on Win64
    DWORD_PTR threadmask, procmask = 0, sysmask = 0;

    hProcess = GetCurrentProcess ();

    // Get available processors
    // An empty mask would make the search for the first processor below loop forever
    if (!GetProcessAffinityMask (hProcess, &procmask, &sysmask) || !procmask)
        return;

    if (proc == (size_t) DEFAULT)
    {
        // Set to run on all processors
        SetThreadAffinityMask (hThread, procmask);
        return;
    }

    // Shifting by the width of the mask or more is undefined,
    // and such a processor cannot be in the mask anyway
    if (proc >= sizeof (threadmask) * 8)
        return;

    // Find first valid processor for this process
    for (threadmask = 1; !(threadmask & procmask); threadmask <<= 1);

    // Set the processor from the first valid processor
    threadmask <<= proc;

    // Set processor on which to run the thread
    if (threadmask & procmask)                      // Only if the processor is valid
        SetThreadAffinityMask (hThread, threadmask);
}

// Dispatch a number of threads to calculate fnts in parallel
//
// func     transform routine applied to every block
// count    number of consecutive blocks of nn elements in data
// data     the blocks, transformed in place
// wtable   nn-element table for func; every worker gets its own copy
// ptable   nn-element table for func; only read (and copied) when s is nonzero
// nn       length of one block
// s        passed on to func; when zero, func receives a null ptable
//
// The blocks are divided between min (NProcessors, count) workers. The last
// share is calculated on the calling thread, the others on new threads.
// A share whose thread cannot be created is calculated on the calling
// thread instead, so all blocks are transformed when the function returns.
void dispatchfnt (fntfunc func, size_t count, modint data[], modint wtable[], size_t ptable[], size_t nn, int s = 1)
{
    size_t t, started = 0, maxthreads = min (NProcessors, count);
    DWORD dwThreadId;

    // Nothing to transform; also avoids maxthreads - 1 wrapping around below
    if (!count) return;

    // Guard the divisions by maxthreads if no processors are reported
    if (!maxthreads) maxthreads = 1;

    fntargs *args = new fntargs[maxthreads];
    HANDLE *handles = new HANDLE[maxthreads];
    modint *unalignedwtable, *alignedwtable;
    size_t *unalignedptable = 0, *alignedptable = 0, cacheburst = sizeof (modint) * Cacheburstblocksize;

    // Allocate room for one table per worker plus slack for the alignment
    unalignedwtable = new modint[maxthreads * nn + cacheburst / sizeof (modint) - 1];
    if (s) unalignedptable = new size_t[maxthreads * nn + cacheburst / sizeof (size_t) - 1];

    // Memory blocks aligned at the beginning of a cache line
    // (the mask -cacheburst requires cacheburst to be a power of two)
    alignedwtable = (modint *) (((size_t) unalignedwtable + cacheburst - 1) & -cacheburst);
    if (s) alignedptable = (size_t *) (((size_t) unalignedptable + cacheburst - 1) & -cacheburst);

    for (t = 0; t < maxthreads; t++)
    {
        HANDLE hThread = NULL;

        memcpy (alignedwtable + t * nn, wtable, sizeof (modint) * nn);
        if (s) memcpy (alignedptable + t * nn, ptable, sizeof (size_t) * nn);

        // Note that NProcessors and maxthreads may be e.g. 3 which does not divide count
        args[t].func = func;
        args[t].count = (t + 1) * count / maxthreads - t * count / maxthreads;
        args[t].data = data + t * count / maxthreads * nn;
        args[t].wtable = alignedwtable + t * nn;
        args[t].ptable = (s ? alignedptable + t * nn : 0);
        args[t].nn = nn;
        args[t].s = s;

        if (t < maxthreads - 1)
        {
            hThread = CreateThread (NULL, 0, (LPTHREAD_START_ROUTINE) fntthread, args + t, 0, &dwThreadId);

            if (hThread)
            {
                setthreadaffinity (hThread, t);
                handles[started++] = hThread;
            }
            else
            {
                // No thread available: calculate this share here rather than lose it
                fntthread (args + t);
            }
        }
        else
        {
            setthreadaffinity (GetCurrentThread (), t);
            fntthread (args + t);
            setthreadaffinity (GetCurrentThread (), (size_t) DEFAULT);
        }
    }

    // Wait for all threads to end
    // Waiting one handle at a time has no MAXIMUM_WAIT_OBJECTS limit and also
    // works with no handles at all, unlike a single WaitForMultipleObjects call;
    // the tables below must not be freed while a worker may still be running
    for (t = 0; t < started; t++)
    {
        WaitForSingleObject (handles[t], WIN_INFINITE);
        CloseHandle (handles[t]);
    }

    if (s) delete[] unalignedptable;
    delete[] unalignedwtable;
    delete[] handles;
    delete[] args;
}

// The "six-step" fnt, but doesn't transpose or scramble (for convolution only)
//
// data   nn elements, transformed in place
// pr     base from which the root w is computed with pow () below
//        (presumably a primitive root of modint::modulus - TODO confirm)
// isign  selects the exponent used for w (isign > 0 or not)
// nn     transform length; the n1 x n2 split below assumes a power of two,
//        which is what the callers in this file pass

void tablesixstepfnttrans2 (modint data[], modint pr, int isign, size_t nn)
{
    size_t n1, n2, j, k;
    modint w, tmp, tmp2, tmp3, tmp4, tmp5, *p1, *p2;

    if (nn < 2) return;

    // Count the doublings needed to reach nn: n2 becomes log2 (nn)
    for (n1 = 1, n2 = 0; n1 < nn; n1 <<= 1, n2++);
    // Split the exponent: n1 gets half of it (rounded down), n2 gets the rest
    n1 = n2 >> 1;
    n2 -= n1;

    n1 = 1 << n1;
    n2 = 1 << n2;

    // n2 >= n1

    // wtable has room for n2 entries since it is reused for the length n2 transforms
    modint *wtable = new modint[n2];
    size_t *ptable = new size_t[n1];

    if (isign > 0)
        w = pow (pr, modint::modulus - 1 - (modint::modulus - 1) / nn);
    else
        w = pow (pr, (modint::modulus - 1) / nn);

    // treat the input data as a n1 x n2 matrix

    // first transpose the matrix

    transpose (data, n1, n2);

    // then do n2 transforms of length n1

    // init tables

    // wtable[k] = tmp^k for k < n1, where tmp = w^(nn / n1)
    tmp = pow (w, nn / n1);
    tmp2 = 1;
    k = 0;
    if (n1 > 2)
    {
        // Fill three entries per round; tmp5 = tmp^3 advances all three powers at once
        tmp3 = tmp;
        tmp4 = tmp * tmp;
        tmp5 = tmp * tmp4;
        for (; k < n1 - 2; k += 3)
        {
            wtable[k] = tmp2;
            wtable[k + 1] = tmp3;
            wtable[k + 2] = tmp4;

            modmul3 (&tmp2, &tmp5, &tmp3, &tmp5, &tmp4, &tmp5);
        }
    }
    // Remaining entries one at a time; tmp2 holds tmp^k at this point
    for (; k < n1; k++)
    {
        wtable[k] = tmp2;
        tmp2 *= tmp;
    }

    initscrambletable (ptable, n1);

    // Only use threads when there are several processors and the transform is large enough
    if (NProcessors <= 1 || nn <= USE_SINGLE_FNT)
        for (k = 0, p1 = data; k < n2; k++, p1 += n1)
            tablefnt (p1, wtable, ptable, n1);
    else
        dispatchfnt (tablefnt, n2, data, wtable, ptable, n1);


    // transpose the matrix

    transpose (data, n2, n1);

    // then multiply the matrix A_jk by exp(isign * -2 pi i j k / nn)

    // Row 0 is skipped (j starts at 1); tmp holds w^j for the current row j
    tmp = w;
    for (j = 1, p1 = data + n2; j < n1; j++, p1 += n2)
    {
        // Diagonal element (j, j) gets w^(j * j)
        tmp2 = pow (tmp, j);
        p1[j] *= tmp2;
        tmp2 *= tmp;
        // Elements (j, k) and (k, j) share the same factor, so both are
        // multiplied in one step; p2 walks down column j starting at row j + 1
        for (k = j + 1, p2 = p1 + n2 + j; k < n1; k++, p2 += n2)
        {
            /*
            p1[k] *= tmp2;
            *p2 *= tmp2;
            tmp2 *= tmp;
            */

            modmul3 (&tmp2, &tmp, p1 + k, &tmp2, p2, &tmp2);
        }
        // Columns n1 .. n2 - 1 (present only when n2 > n1) have no mirrored element
        for (; k < n2; k++)
        {
            /*
            p1[k] *= tmp2;
            tmp2 *= tmp;
            */

            modmul2 (&tmp2, &tmp, p1 + k, &tmp2);
        }
        tmp *= w;
    }

    // last do n1 transforms of length n2

    // init table

    if (n2 != n1)
    {
        // n2 = 2 * n1
        // The length n1 table supplies the even entries of the length n2 table;
        // spread from the top down so that no entry is overwritten before it is moved
        for (k = n1; k--;)
            wtable[2 * k] = wtable[k];
        // Odd entries: wtable[k] = tmp2^k with tmp2 = w^(nn / n2), stepping by tmp = tmp2^2
        tmp2 = pow (w, nn / n2);
        tmp = tmp2 * tmp2;
        k = 1;
        if (n2 > 2)
        {
            // Three odd entries per round (powers 1, 3, 5), advanced by tmp5 = tmp2^6
            tmp3 = tmp2 * tmp;
            tmp4 = tmp3 * tmp;
            tmp5 = tmp4 * tmp2;
            for (; k < n2 - 4; k += 6)
            {
                wtable[k] = tmp2;
                wtable[k + 2] = tmp3;
                wtable[k + 4] = tmp4;

                modmul3 (&tmp2, &tmp5, &tmp3, &tmp5, &tmp4, &tmp5);
            }
        }
        for (; k < n2; k += 2)
        {
            wtable[k] = tmp2;
            tmp2 *= tmp;
        }
    }

    // These transforms are called without a permutation table (ptable 0, s 0)
    if (NProcessors <= 1 || nn <= USE_SINGLE_FNT)
        for (k = 0, p1 = data; k < n1; k++, p1 += n2)
            tablefnt (p1, wtable, 0, n2, 0);
    else
        dispatchfnt (tablefnt, n1, data, wtable, 0, n2, 0);

    delete[] ptable;
    delete[] wtable;
}

// Forward transform for lengths that are a power of two or three times a
// power of two. A power-of-two length is passed straight to
// tablesixstepfnttrans2; otherwise a length 3 transform is done over the
// columns of the 3 x n2 matrix, followed by three length n2 row transforms.
//
// data   nn elements, transformed in place
// pr     base from which the root w is computed with pow () below
//        (presumably a primitive root of modint::modulus - TODO confirm)
// isign  selects the exponent used for w (isign > 0 or not)
// nn     transform length
void tablesixstepfnttrans (modint data[], modint pr, int isign, size_t nn)
{
    // nn & -nn isolates the lowest set bit: the largest power of two dividing nn
    size_t n2 = (nn & -nn), j, k, s;
    modint w, ww, w1, w2, w3, *p1, *p2, *p3, tmp, tmp2, tmp3, *d, t;

    if (nn < 2) return;

    if (nn == n2)
    {
        // Transform length is a power of two
        tablesixstepfnttrans2 (data, pr, isign, nn);
        return;
    }

    // Transform length is three times a power of two

    if (isign > 0)
        w = pow (pr, modint::modulus - 1 - (modint::modulus - 1) / nn);
    else
        w = pow (pr, (modint::modulus - 1) / nn);

    ww = w * w;

    // Constants of the length 3 transform below
    w3 = pow (w, n2);                   // 3rd root of unity
    w1 = -modint (3) / modint (2);
    w2 = w3 + modint (1) / modint (2);

    // Process the columns in blocks of s, copied into the scratch buffer d
    s = min (n2, Cachemaxblocksize / 4);
    d = new modint[3 * s];

    // tmp and tmp2 run through the powers of w and ww over all columns,
    // so they are deliberately not reset between blocks
    tmp = tmp2 = 1;
    for (k = 0; k < n2; k += s)
    {
        p1 = d;
        p2 = p1 + s;
        p3 = p2 + s;
        moveraw (p1, data + k, s);                      // Cache in
        moveraw (p2, data + k + n2, s);
        moveraw (p3, data + k + 2 * n2, s);
        for (j = 0; j < s; j++, p1++, p2++, p3++)
        {
            t = *p2 + *p3;                              // Transform columns
            *p3 = *p2 - *p3;
            *p1 += t;
            // Keep the current power of w; the next call advances tmp to the next column
            tmp3 = tmp;
            modmul3 (&t, &w1, p3, &w2, &tmp, &w);
            t += *p1;
            *p2 = t + *p3;
            *p3 = t - *p3;
            // Second row times the saved power of w, third row times the power of ww;
            // tmp2 is advanced to the next column in the same call
            modmul3 (p2, &tmp3, p3, &tmp2, &tmp2, &ww); // Multiply
        }
        p1 = d;
        p2 = p1 + s;
        p3 = p2 + s;
        moveraw (data + k, p1, s);                      // Cache out
        moveraw (data + k + n2, p2, s);
        moveraw (data + k + 2 * n2, p3, s);
    }

    delete[] d;

    tablesixstepfnttrans2 (data, pr, isign, n2);        // Transform rows
    tablesixstepfnttrans2 (data + n2, pr, isign, n2);
    tablesixstepfnttrans2 (data + 2 * n2, pr, isign, n2);
}

// Counterpart of tablesixstepfnttrans2 with the steps in the opposite order:
// length n2 transforms, multiplication, transpose, length n1 transforms and
// a final transpose. Every element is also multiplied by 1 / (nn * e).
//
// data   nn elements, transformed in place
// pr     base from which the root w is computed with pow () below
//        (presumably a primitive root of modint::modulus - TODO confirm)
// isign  selects the exponent used for w (isign > 0 or not)
// nn     transform length; the n1 x n2 split below assumes a power of two,
//        which is what the callers in this file pass
// e      extra factor in the divisor nn * e; itablesixstepfnttrans passes 3
//        for the row transforms. NOTE(review): it also calls this function
//        with four arguments, so a default for e must be declared elsewhere
void itablesixstepfnttrans2 (modint data[], modint pr, int isign, size_t nn, size_t e)
{
    size_t n1, n2, j, k;
    modint w, tmp, tmp2, tmp3, tmp4, tmp5, *p1, *p2, inn;

    if (nn < 2) return;

    // Count the doublings needed to reach nn: n2 becomes log2 (nn)
    for (n1 = 1, n2 = 0; n1 < nn; n1 <<= 1, n2++);
    // Split the exponent: n1 gets half of it (rounded down), n2 gets the rest
    n1 = n2 >> 1;
    n2 -= n1;

    n1 = 1 << n1;
    n2 = 1 << n2;

    // n2 >= n1

    modint *wtable = new modint[n2];
    size_t *ptable = new size_t[n1];

    if (isign > 0)
        w = pow (pr, modint::modulus - 1 - (modint::modulus - 1) / nn);
    else
        w = pow (pr, (modint::modulus - 1) / nn);

    // treat the input data as a n1 x n2 matrix

    // first do n1 transforms of length n2

    // init table

    // wtable[k] = tmp^k for k < n2, where tmp = w^(nn / n2)
    tmp = pow (w, nn / n2);
    tmp2 = 1;
    k = 0;
    if (n2 > 2)
    {
        // Fill three entries per round; tmp5 = tmp^3 advances all three powers at once
        tmp3 = tmp;
        tmp4 = tmp * tmp;
        tmp5 = tmp * tmp4;
        for (; k < n2 - 2; k += 3)
        {
            wtable[k] = tmp2;
            wtable[k + 1] = tmp3;
            wtable[k + 2] = tmp4;

            modmul3 (&tmp2, &tmp5, &tmp3, &tmp5, &tmp4, &tmp5);
        }
    }
    // Remaining entries one at a time; tmp2 holds tmp^k at this point
    for (; k < n2; k++)
    {
        wtable[k] = tmp2;
        tmp2 *= tmp;
    }

    // These transforms are called without a permutation table (ptable 0, s 0)
    if (NProcessors <= 1 || nn <= USE_SINGLE_FNT)
        for (k = 0, p1 = data; k < n1; k++, p1 += n2)
            itablefnt (p1, wtable, 0, n2, 0);
    else
        dispatchfnt (itablefnt, n1, data, wtable, 0, n2, 0);

    // then multiply the matrix A_jk by exp(isign * -2 pi i j k / nn) / nn

    // Row 0 is included here (j starts at 0 with tmp = 1) because every
    // element must also receive the scaling factor inn
    tmp = 1;
    inn = modint (1) / modint (nn * e);
    for (j = 0, p1 = data; j < n1; j++, p1 += n2)
    {
        // Diagonal element (j, j) gets w^(j * j) * inn; tmp holds w^j
        tmp2 = pow (tmp, j) * inn;
        p1[j] *= tmp2;
        tmp2 *= tmp;
        // Elements (j, k) and (k, j) share the same factor, so both are
        // multiplied in one step; p2 walks down column j starting at row j + 1
        for (k = j + 1, p2 = p1 + n2 + j; k < n1; k++, p2 += n2)
        {
            /*
            p1[k] *= tmp2;
            *p2 *= tmp2;
            tmp2 *= tmp;
            */

            modmul3 (&tmp2, &tmp, p1 + k, &tmp2, p2, &tmp2);
        }
        // Columns n1 .. n2 - 1 (present only when n2 > n1) have no mirrored element
        for (; k < n2; k++)
        {
            /*
            p1[k] *= tmp2;
            tmp2 *= tmp;
            */

            modmul2 (&tmp2, &tmp, p1 + k, &tmp2);
        }
        tmp *= w;
    }

    // transpose the matrix

    transpose (data, n1, n2);

    // then do n2 transforms of length n1

    // init table

    // The length n1 table consists of the even entries of the length n2 table;
    // compact upwards from k = 0 so that no entry is overwritten before it is read
    if (n2 != n1)
        // n2 = 2 * n1
        for (k = 0; k < n1; k++)
            wtable[k] = wtable[2 * k];

    initscrambletable (ptable, n1);

    if (NProcessors <= 1 || nn <= USE_SINGLE_FNT)
        for (k = 0, p1 = data; k < n2; k++, p1 += n1)
            itablefnt (p1, wtable, ptable, n1);
    else
        dispatchfnt (itablefnt, n2, data, wtable, ptable, n1);

    // last transpose the matrix

    transpose (data, n2, n1);

    delete[] ptable;
    delete[] wtable;
}

// Counterpart of tablesixstepfnttrans for lengths that are a power of two or
// three times a power of two. A power-of-two length is passed straight to
// itablesixstepfnttrans2; otherwise the three length n2 rows are transformed
// first (with the extra divisor 3) and the length 3 column transform follows.
//
// data   nn elements, transformed in place
// pr     base from which the root w is computed with pow () below
//        (presumably a primitive root of modint::modulus - TODO confirm)
// isign  selects the exponent used for w (isign > 0 or not)
// nn     transform length
void itablesixstepfnttrans (modint data[], modint pr, int isign, size_t nn)
{
    // nn & -nn isolates the lowest set bit: the largest power of two dividing nn
    size_t n2 = (nn & -nn), j, k, s;
    modint w, ww, w1, w2, w3, *p1, *p2, *p3, tmp, tmp2, *d, t;

    if (nn < 2) return;

    if (nn == n2)
    {
        // Transform length is a power of two
        itablesixstepfnttrans2 (data, pr, isign, nn);
        return;
    }

    // Transform length is three times a power of two

    if (isign > 0)
        w = pow (pr, modint::modulus - 1 - (modint::modulus - 1) / nn);
    else
        w = pow (pr, (modint::modulus - 1) / nn);

    // The last argument makes each row transform divide by 3 * n2, i.e. by nn
    itablesixstepfnttrans2 (data, pr, isign, n2, 3);    // Transform rows
    itablesixstepfnttrans2 (data + n2, pr, isign, n2, 3);
    itablesixstepfnttrans2 (data + 2 * n2, pr, isign, n2, 3);

    ww = w * w;

    // Constants of the length 3 transform below
    w3 = pow (w, n2);                   // 3rd root of unity
    w1 = -modint (3) / modint (2);
    w2 = w3 + modint (1) / modint (2);

    // Process the columns in blocks of s, copied into the scratch buffer d
    s = min (n2, Cachemaxblocksize / 4);
    d = new modint[3 * s];

    // tmp and tmp2 run through the powers of w and ww over all columns,
    // so they are deliberately not reset between blocks
    tmp = tmp2 = 1;
    for (k = 0; k < n2; k += s)
    {
        p1 = d;
        p2 = p1 + s;
        p3 = p2 + s;
        moveraw (p1, data + k, s);                      // Cache in
        moveraw (p2, data + k + n2, s);
        moveraw (p3, data + k + 2 * n2, s);
        for (j = 0; j < s; j++, p1++, p2++, p3++)
        {
            // Second row times the power of w, third row times the power of ww;
            // tmp is advanced here, tmp2 in the second modmul3 call below
            modmul3 (p2, &tmp, p3, &tmp2, &tmp, &w);    // Multiply
            t = *p2 + *p3;                              // Transform columns
            *p3 = *p2 - *p3;
            *p1 += t;
            modmul3 (&t, &w1, p3, &w2, &tmp2, &ww);
            t += *p1;
            *p2 = t + *p3;
            *p3 = t - *p3;
        }
        p1 = d;
        p2 = p1 + s;
        p3 = p2 + s;
        moveraw (data + k, p1, s);                      // Cache out
        moveraw (data + k + n2, p2, s);
        moveraw (data + k + 2 * n2, p3, s);
    }

    delete[] d;
}
