NTT w/ Montgomery Multiplication

1

For the last few days I've been trying to help Mr. Spektre, who, due to compatibility issues, had to write his own Number Theoretic Transform for FFT multiplication (see "Modular arithmetics and NTT (finite field DFT) optimizations").

He has one that works just fine, but he had been wondering if there were any ways to speed it up. One idea that came to mind was to use Montgomery Multiplication in order to avoid the excessive divisions. I've used it in the past, but for some reason I can't get it to work here, and I'm not sure if it's an issue with the Montgomery Multiplication or the NTT.

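It uses a 32-bit word size, so the Montgomery radix is $R = 2^{32}$, and the prime modulus is $p = 3221225473$ (0xC0000001). Using the Extended Euclidean Algorithm, I found the inverses from
$$2^{32} \cdot 2415919104 = 3221225473 \cdot 3221225471 + 1,$$
i.e. $R^{-1} \equiv 2415919104 \pmod{p}$ and $-p^{-1} \equiv 3221225471 \pmod{2^{32}}$.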

Below is the code I'm working on, with the main function that calls it.
NOTE: I'm not worrying about the inverse transform at this time, as there's no point if the forward transform doesn't work at all.

#include <string.h>

#include <cstdint>
#include <cstdio>
#include <vector>


// Exact-width types: modadd/modsub/modmul rely on arithmetic wrapping modulo 2^32
// (uint32) and on products < 2^64 (uint64). `unsigned long int` is 64-bit on LP64
// platforms (e.g. Linux x86-64), which silently breaks those tricks.
#ifndef uint32
#define uint32 std::uint32_t
#endif

#ifndef uint64
#define uint64 std::uint64_t
#endif


// Number theoretic transform over p = 0xC0000001 = 3*2^30 + 1.
// All transform arithmetic is done in Montgomery form with R = 2^32:
// convert inputs with montgom_in/montgom_in_arr, convert results back
// with montgom_out/montgom_out_arr.
class montgom_ntt
{
public:
    montgom_ntt()
    {
        r = 0; L = 0;
        W = 0; N = 0;
    }

    // main interface
    // dst[n] = NTT(src[n]); values are expected in Montgomery form.
    // n == 0 reuses the size configured by the previous call.
    void NTT(uint32 *dst, uint32 *src, uint32 n = 0);

    // Montgomery form conversion
    uint32 montgom_in(uint32 n);                                   // n   -> n*R mod p
    uint32 montgom_out(uint32 n);                                  // n*R -> n   mod p
    void montgom_in_arr(uint32* dst, const uint32* src, uint32 n);
    void montgom_out_arr(uint32* dst, const uint32* src, uint32 n);

private:
    bool init(uint32 n);                                           // init r, L, N, W
    void NTT_calc(uint32 *dst, uint32 *src, uint32 n, uint32 w);  // recursive radix-2 kernel (clobbers src)
    void NTT_fast(uint32 *dst, uint32 *src, uint32 n, uint32 w);  // O(n log n) NTT
    void NTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w);  // O(n^2) reference DFT

    // modular arithmetics (operands expected in [0, p))
    inline uint32 modadd(uint32 a, uint32 b);
    inline uint32 modsub(uint32 a, uint32 b);
    inline uint32 modmul(uint32 a, uint32 b);   // Montgomery product: a*b*R^-1 mod p
    inline uint32 modpow(uint32 a, uint32 b);   // plain a^b mod p (not Montgomery form)

    uint32 r, L, N, W;   // root, exponent step, transform size, Wn (Montgomery form)

    // Shared constants: static so the class stays copy-assignable.
    static constexpr uint32 p = 0xC0000001;
    static constexpr uint64 px = 0xC0000001;
};

//---------------------------------------------------------------------------
bool montgom_ntt::init(uint32 n)
{
    // (max(src[])^2)*n < p else NTT overflow can occur !!!
    r = 2;

    // n must be a power of two: NTT_calc splits n into two halves of n>>1, and
    // L = 0x30000000 / n (0x30000000 = 3*2^28) is only exact when n divides it.
    if ((n < 2) || (n > 0x10000000) || ((n & (n - 1)) != 0))
    {
        r = 0; L = 0; W = 0; N = 0;
        return false;
    }
    L = 0x30000000 / n; // 32:30 bit best for unsigned 32 bit

    N = n;                          // size of vectors [uint32s]
    // Wn for NTT: computed with plain arithmetic, then moved into Montgomery form.
    // NOTE(review): relies on r = 2 having multiplicative order 0x30000000 mod p
    // (inherited from the original non-Montgomery implementation) so Wn has order n.
    W = montgom_in(modpow(r, L));

    return true;
}

//---------------------------------------------------------------------------
void montgom_ntt::NTT(uint32 *dst, uint32 *src, uint32 n)
{
    // A non-zero n (re)configures size and root via init(); n == 0 keeps
    // whatever the previous call configured. If init() rejects n it resets
    // N to 0, and the transform below then does nothing.
    const bool reconfigure = (n > 0);
    if (reconfigure) init(n);

    NTT_fast(dst, src, N, W);
}


//---------------------------------------------------------------------------
void montgom_ntt::NTT_fast(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
    // dst[n] = NTT(src[n]) with root w (Montgomery form). n == 0 is a no-op.
    if (n == 0)
    {
        return;
    }
    if (n == 1)
    {
        dst[0] = src[0];
        return;
    }
    // NTT_calc uses its src argument as scratch space and overwrites it.
    // Always give it a private copy. This keeps the caller's input intact and
    // also makes the in-place (dst == src) and overlapping cases safe.
    std::vector<uint32> scratch(src, src + n);
    NTT_calc(dst, scratch.data(), n, w);
}


void montgom_ntt::NTT_calc(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
    // Recursive radix-2 decimation-in-time NTT:
    //   dst[k] = sum_i src[i] * w^(i*k) (mod p), all values in Montgomery form.
    // Requires n to be a power of two and w a primitive n-th root of unity
    // (Montgomery form). src is overwritten: it holds the half-size sub-transforms.
    if (n < 2)
    {
        return;
    }
    const uint32 n2 = n >> 1;
    const uint32 w2 = modmul(w, w);   // root for the half-size transforms (Montgomery product keeps the form)

    // reorder: even-indexed inputs to the first half, odd-indexed to the second
    for (uint32 i = 0; i < n2; i++)
    {
        dst[i]      = src[2 * i];
        dst[n2 + i] = src[2 * i + 1];
    }

    // recursion: transform each half from dst back into src
    if (n2 > 1)
    {
        NTT_calc(src,      dst,      n2, w2);  // even
        NTT_calc(src + n2, dst + n2, n2, w2);  // odd
    }
    else
    {
        src[0] = dst[0];   // length-1 transforms are the identity
        src[1] = dst[1];
    }

    // butterflies: dst[k] = E[k] + w^k*O[k], dst[k + n2] = E[k] - w^k*O[k]
    // k = 0: the twiddle is 1, so no multiplication is needed.
    dst[0]  = modadd(src[0], src[n2]);
    dst[n2] = modsub(src[0], src[n2]);

    // The running twiddle w^k must be in Montgomery form. Starting it from a plain 1
    // (the original bug) is wrong because modmul() divides by R: plain 1 stands for
    // R^-1, so every twiddle for k >= 1 came out scaled by R^-1. Starting from w itself
    // (= w^1, already Montgomery) avoids needing the Montgomery one (R mod p) at all.
    uint32 wk = w;
    for (uint32 k = 1; k < n2; k++)
    {
        const uint32 a0 = src[k];
        const uint32 a1 = modmul(src[n2 + k], wk);
        dst[k]      = modadd(a0, a1);
        dst[n2 + k] = modsub(a0, a1);
        wk = modmul(wk, w);
    }
}

//---------------------------------------------------------------------------
void montgom_ntt::NTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
    // O(n^2) reference transform: dst[j] = sum_i src[i] * w^(i*j) (mod p),
    // all values in Montgomery form. dst must not alias src (src is read
    // again for every output element).
    const uint32 one = montgom_in(1);   // Montgomery form of 1 (R mod p); plain 1 would mean R^-1
    uint32 wj = one;                    // w^j
    for (uint32 j = 0; j < n; j++)
    {
        uint32 acc = 0;                 // 0 is 0 in Montgomery form as well
        uint32 wij = one;               // (w^j)^i
        for (uint32 i = 0; i < n; i++)
        {
            acc = modadd(acc, modmul(wij, src[i]));
            wij = modmul(wij, wj);
        }
        dst[j] = acc;
        wj = modmul(wj, w);
    }
}


//---------------------------------------------------------------------------
uint32 montgom_ntt::montgom_in(uint32 n)
{
    // Enter Montgomery form: n -> n * R mod p, with R = 2^32.
    // The shift is done in 64 bits so no high bits are lost before reducing.
    const uint64 shifted = static_cast<uint64>(n) << 32;
    return static_cast<uint32>(shifted % px);
}

//---------------------------------------------------------------------------
uint32 montgom_ntt::montgom_out(uint32 n)
{
    // Leave Montgomery form: n*R -> n, i.e. multiply by R^-1 mod p.
    // R^-1 = 0x90000000 because 2^32 * 0x90000000 = p * 0xBFFFFFFF + 1.
    const uint64 r_inverse = 0x90000000;
    const uint64 product = static_cast<uint64>(n) * r_inverse;   // < 2^64
    return static_cast<uint32>(product % px);
}

//---------------------------------------------------------------------------
void montgom_ntt::montgom_in_arr(uint32* dst, const uint32* src, uint32 n)
{
    // dst[i] = montgom_in(src[i]) for i in [0, n). Element-wise, so dst == src is fine.
    // A for loop (not do/while) so that n == 0 touches nothing.
    for (uint32 i = 0; i < n; i++)
    {
        dst[i] = montgom_in(src[i]);
    }
}

//---------------------------------------------------------------------------
void montgom_ntt::montgom_out_arr(uint32* dst, const uint32* src, uint32 n)
{
    // dst[i] = montgom_out(src[i]) for i in [0, n). Element-wise, so dst == src is fine.
    // A for loop (not do/while) so that n == 0 touches nothing.
    for (uint32 i = 0; i < n; i++)
    {
        dst[i] = montgom_out(src[i]);
    }
}

//---------------------------------------------------------------------------
uint32 montgom_ntt::modadd(uint32 a, uint32 b)
{
    // (a + b) mod p for a, b in [0, p).
    // If the sum wrapped (sum < a), the true sum is sum + 2^32; subtracting p
    // wraps again and lands on the correct residue (32-bit uint32).
    const uint32 sum = a + b;
    const bool reduce = (sum < a) || (sum >= p);
    return reduce ? sum - p : sum;
}

//---------------------------------------------------------------------------
uint32 montgom_ntt::modsub(uint32 a, uint32 b)
{
    // (a - b) mod p for a, b in [0, p).
    // For a < b, add p back. With a 32-bit uint32, (a + p) may wrap, but the
    // following "- b" wraps back, so the result is still a - b + p, which is in [0, p).
    // (The original also did `d += p;` here, a dead store that was immediately overwritten.)
    if (a < b)
    {
        return (a + p) - b;
    }
    return a - b;
}

//---------------------------------------------------------------------------
uint32 montgom_ntt::modmul(uint32 a, uint32 b)
{
    // Montgomery product (REDC) with R = 2^32: returns a*b*R^-1 mod p.
    // Requires a, b in [0, p).
    // NINV = -p^-1 mod 2^32, taken from 2^32 * 0x90000000 = p * 0xBFFFFFFF + 1.
    // All intermediates are explicit 64-bit values, so the result does not depend
    // on how wide uint32 happens to be. The original `R -= p` overflow fix-up
    // only worked with an exactly 32-bit R.
    const uint64 NINV = 0xBFFFFFFF;
    const uint64 MASK = 0xFFFFFFFF;

    const uint64 t = static_cast<uint64>(a) * b;    // < p^2 < 2^64
    const uint64 m = ((t & MASK) * NINV) & MASK;    // makes t + m*p divisible by 2^32
    const uint64 u = t + m * px;                    // true value < 2^32 * 2p: may wrap past 2^64
    uint64 q = u >> 32;
    if (u < t)
    {
        q += static_cast<uint64>(1) << 32;          // restore the carry lost by the 64-bit addition
    }
    if (q >= px)                                    // q < 2p, so one subtraction suffices
    {
        q -= px;
    }
    return static_cast<uint32>(q);
}

//---------------------------------------------------------------------------
uint32 montgom_ntt::modpow(uint32 a, uint32 b)
{
    // Plain (non-Montgomery) a^b mod p by right-to-left binary exponentiation.
    // init() converts its result with montgom_in() afterwards.
    // (The original referenced an undeclared `P`, and returned an unreduced `a` for
    // b == 1 when a >= p.)
    uint64 base = a % px;     // reduce first so base*base always fits in 64 bits
    uint64 result = 1;
    while (b != 0)
    {
        if ((b & 1) != 0)
        {
            result = (result * base) % px;
        }
        base = (base * base) % px;
        b >>= 1;
    }
    return static_cast<uint32>(result);
}

And here's main:

void main()
{
    montgom_ntt F;

    uint32 Tran[8];
    uint32 Arr[8] = 
    {
        0x2923, 0xbe84,
        0xe16c, 0xd6ae,
        0, 0, 0, 0
    };

    F.montgom_in_arr(Arr1, Arr1, Len);
    F.NTT(Tran, Arr, Len);
    F.montgom_out_arr(Tran, Tran, Len);

}

I get the feeling it's something really simple, but I can't figure out what it is. Thanks for any help you guys can provide!

[Edit] To rule out the multiplication, I modified the modmul function so that it converts its inputs from Montgomery form to regular form, performs the standard (A * B) % p, and then converts the result back into Montgomery form — and I still got the same wrong answer. This makes me think the problem is the conversion to and from Montgomery form, but I have no idea what I did wrong.

uint32 montgom_ntt::modmul(uint32 a, uint32 b)
{
    // Debug reference version of the Montgomery product: leave Montgomery form,
    // multiply with an ordinary 64-bit modulo, then re-enter Montgomery form.
    // The result equals a*b*R^-1 mod p, like the REDC version, but is much slower.
    // Use it only to cross-check the fast Montgomery-reduction implementation.
    const uint64 plain_a = montgom_out(a);
    const uint64 plain_b = montgom_out(b);
    const uint64 plain_product = (plain_a * plain_b) % px;
    return montgom_in(static_cast<uint32>(plain_product));
}
c++
algorithm
multiplication
modular-arithmetic
asked on Stack Overflow Jun 18, 2014 by Mandalf The Beige • edited May 23, 2017 by Community

0 Answers

Nobody has answered this question yet.


User contributions licensed under CC BY-SA 3.0