For the last few days I've been trying to help Mr. Spektre who, due to compatibility issues, had to write his own Number Theoretic Transform for FFT Multiplication.
Modular arithmetics and NTT (finite field DFT) optimizations
He has one that works just fine, but he had been wondering if there were any ways to speed it up. One idea that came to mind was to use Montgomery Multiplication in order to avoid the excessive divisions. I've used it in the past, but for some reason I can't get it to work here, and I'm not sure if it's an issue with the Montgomery Multiplication or the NTT.
It uses 32 bit word size, so the reduction is also 2^32 and the prime modulus is 3221225473. Using the Ext. Euclidean Algorithm, I found the inverses to be:
2^32 * 2415919104 = (3221225473 * 3221225471) + 1
Below is the code I'm working on, with the main function that calls it.
NOTE: I'm not worrying about the Inverse Transform at this time, as there's not point if the regular Transform doesn't work at all.
#include <string.h>
#ifndef uint32
#define uint32 unsigned long int
#endif
#ifndef uint64
#define uint64 unsigned long long int
#endif
class montgom_ntt // number theoretic transform
{
public:
montgom_ntt()
{
r = 0; L = 0;
W = 0, N = 0;
}
// main interface
void NTT(uint32 *dst, uint32 *src, uint32 n = 0); // uint32 dst[n] = fast NTT(uint32 src[n])
private:
bool init(uint32 n); // init r,L,p,W,iW,rN
void NTT_calc(uint32 *dst, uint32 *src, uint32 n, uint32 w); // uint32 dst[n] = fast NTT(uint32 src[n])
void NTT_fast(uint32 *dst, uint32 *src, uint32 n, uint32 w); // uint32 dst[n] = fast NTT(uint32 src[n])
void NTT_fast(uint32 *dst, const uint32 *src, uint32 n, uint32 w);
// uint32 arithmetics
public:
uint32 montgom_in(uint32 n);
uint32 montgom_out(uint32 n);
void montgom_in_arr(uint32* dst, const uint32* src, uint32 n);
void montgom_out_arr(uint32* dst, const uint32* src, uint32 n);
private:
// modular arithmetics
inline uint32 modadd(uint32 a, uint32 b);
inline uint32 modsub(uint32 a, uint32 b);
inline uint32 modmul(uint32 a, uint32 b);
inline uint32 modpow(uint32 a, uint32 b);
uint32 r, L, N, W;
const uint32 p = 0xC0000001;
const uint64 px = 0xC0000001;
};
//---------------------------------------------------------------------------
bool montgom_ntt::init(uint32 n)
{
// (max(src[])^2)*n < p else NTT overflow can ocur !!!
r = 2;
if ((n < 2) || (n > 0x10000000))
{
r = 0; L = 0; W = 0; // p = 0;
iW = 0; rN = 0; N = 0;
return false;
}
L = 0x30000000 / n; // 32:30 bit best for unsigned 32 bit
N = n; // size of vectors [uint32s]
W = modpow(r, L); // Wn for NTT
W = montgom_in(W);
return true;
}
//---------------------------------------------------------------------------
void montgom_ntt::NTT(uint32 *dst, uint32 *src, uint32 n)
{
if (n > 0)
{
init(n);
}
NTT_fast(dst, src, N, W);
}
//---------------------------------------------------------------------------
void montgom_ntt::NTT_fast(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
if (n > 1)
{
if (dst != src)
{
NTT_calc(dst, src, n, w);
}
else
{
uint32* temp = new uint32[n];
memcpy(temp, src, sizeof(uint32) * n);
NTT_calc(dst, temp, n, w);
delete[] temp;
}
}
else if (n == 1)
{
dst[0] = src[0];
}
}
void montgom_ntt::NTT_calc(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
if (n > 1)
{
uint32 i, j, a0, a1,
n2 = n >> 1,
w2 = modmul(w, w);
// reorder even,odd
for (i = 0, j = 0; i < n2; i++, j += 2)
{
dst[i] = src[j];
}
for (j = 1; i < n; i++, j += 2)
{
dst[i] = src[j];
}
// recursion
if (n2 > 1)
{
NTT_calc(src, dst, n2, w2); // even
NTT_calc(src + n2, dst + n2, n2, w2); // odd
}
else if (n2 == 1)
{
src[0] = dst[0];
src[1] = dst[1];
}
// restore results
w2 = 1, i = 0, j = n2;
a0 = src[i];
a1 = src[j];
dst[i] = modadd(a0, a1);
dst[j] = modsub(a0, a1);
while (++i < n2)
{
w2 = modmul(w2, w);
j++;
a0 = src[i];
a1 = modmul(src[j], w2);
dst[i] = modadd(a0, a1);
dst[j] = modsub(a0, a1);
}
}
}
//---------------------------------------------------------------------------
void montgom_ntt::NTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
uint32 i, j, wj, wi, a,
n2 = n >> 1;
for (wj = 1, j = 0; j < n; j++)
{
a = 0;
for (wi = 1, i = 0; i < n; i++)
{
a = modadd(a, modmul(wi, src[i]));
wi = modmul(wi, wj);
}
dst[j] = a;
wj = modmul(wj, w);
}
}
//---------------------------------------------------------------------------
uint32 montgom_ntt::montgom_in(uint32 n)
{
uint64 N = n;
N = (N << 32) % px;
return N;
}
//---------------------------------------------------------------------------
uint32 montgom_ntt::montgom_out(uint32 n)
{
const uint64 C = 0x90000000;
uint64 N = n;
N *= C;
N %= px;
return N;
}
//---------------------------------------------------------------------------
void montgom_ntt::montgom_in_arr(uint32* dst, const uint32* src, uint32 n)
{
uint32 I = 0;
do
{
dst[I] = montgom_in(src[I]);
} while (++I < n);
}
//---------------------------------------------------------------------------
void montgom_ntt::montgom_out_arr(uint32* dst, const uint32* src, uint32 n)
{
uint32 I = 0;
do
{
dst[I] = montgom_out(src[I]);
} while (++I < n);
}
//---------------------------------------------------------------------------
uint32 montgom_ntt::modadd(uint32 a, uint32 b)
{
uint32 n = a + b;
if (n < a)
{
n -= p;
}
else if (n >= p)
{
n -= p;
}
return n;
}
//---------------------------------------------------------------------------
uint32 montgom_ntt::modsub(uint32 a, uint32 b)
{
uint32 d;
d = a - b;
if(a < b)
{
d += p;
d = (a + p) - b;
}
return d;
}
//---------------------------------------------------------------------------
uint32 montgom_ntt::modmul(uint32 a, uint32 b)
{
uint64 A(a), B(b), C;
uint32 R;
A *= B;
C = A & 0xFFFFFFFF;
C *= 0xBFFFFFFF;
C = (C & 0xFFFFFFFF) * px;
C += A;
R = (C >> 32);
if(C < A)
{
R -= p;
}
if(R >= p)
{
R -= p;
}
return R;
}
//---------------------------------------------------------------------------
uint32 montgom_ntt::modpow(uint32 a, uint32 b)
{
//*
uint64 D, M, A;
P = p; A = a;
M = 0llu - (b & 1);
D = (M & A) | ((~M) & 1);
while ((b >>= 1) != 0)
{
A = (A * A) % P;
if ((b & 1) == 1)
{
D = (D * A) % P;
}
}
return (uint32)D;
}
and here's main
void main()
{
montgom_ntt F;
uint32 Tran[8];
uint32 Arr[8] =
{
0x2923, 0xbe84,
0xe16c, 0xd6ae,
0, 0, 0, 0
};
F.montgom_in_arr(Arr1, Arr1, Len);
F.NTT(Tran, Arr, Len);
F.montgom_out_arr(Tran, Tran, Len);
}
I get the feeling it's something really simple, but I can't figure out what it is. Thanks for any help you guys can provide!
[Edit] So in an effort to rule it out, I modified the modmul function so that it converted its input from Montgomery form to regular form, performed the standard (A * B) % p, and then converted it back into Montgomery form and I still got the same, wrong answer. This makes me think that the problem is the conversion to and from Montgomery form, but I have no idea what I did wrong.
uint32 montgom_ntt::modmul(uint32 a, uint32 b)
{
uint64 A, B, C;
A = montgom_out(a);
B = montgom_out(b);
C = (A * B) % px;
return montgom_in(C);
/*
uint64 A(a), B(b), C;
uint32 R;
A *= B;
C = A & 0xFFFFFFFF;
C *= 0xBFFFFFFF;
C = (C & 0xFFFFFFFF) * px;
C += A;
R = (C >> 32);
if(C < A)
{
R -= p;
}
if(R >= p)
{
R -= p;
}
return R;
*/
}
User contributions licensed under CC BY-SA 3.0