Recently, I came across this Codeforces problem: https://codeforces.com/problemset/problem/1257/G, so I started to figure out how FFT works.

Iterative FFT

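Let $a = (a_0, a_1, \dots, a_{n-1})$ be the original sequence and $b = (b_0, b_1, \dots, b_{n-1})$ be the sequence after divide and conquer (recursively splitting into even- and odd-indexed elements until only single elements remain).

Then we can observe that $b_i = a_{\operatorname{rev}(i)}$, where $\operatorname{rev}(i)$ reverses the $\log_2 n$ bits of $i$. For example, with $n = 8$, $b = (a_0, a_4, a_2, a_6, a_1, a_5, a_3, a_7)$.

Let's initialize index $i = 0$ for $a$ and index $j = \operatorname{rev}(0) = 0$ for $b$. Note that we can accelerate the procedure by updating $j$ incrementally and swapping elements in place, that is, moving from $j = \operatorname{rev}(i)$ to $\operatorname{rev}(i + 1)$ at every step instead of reversing each index from scratch.

Given the fact that $\operatorname{rev}$ is its own inverse, so $i = \operatorname{rev}(j)$ and

$\operatorname{rev}(i + 1) = \operatorname{rev}(\operatorname{rev}(j) + 1)$,

we can actually exploit the previous step's result.

Let us rewrite it as: add 1 to $j$ with the bit order mirrored, i.e., the carry starts at the most significant bit and propagates toward the least significant bit. Then we can plug this into the update step.

Though it seems crazy, it can actually be implemented efficiently with bitwise operations, as in reverse_add below.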

Hooray! Let's see the code:

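#include <cmath>
#include <complex>
#include <utility>

typedef std::complex<double> complex_t;

// Given x == rev(i) with bit_length bits, returns rev(i + 1):
// adds 1 to x in bit-reversed order.
int reverse_add(int x, int bit_length)
{
    // Flip bits from the most significant one downward: each 1 becomes 0 and the
    // "carry" moves to the next lower bit, until a 0 is flipped to 1.
    // (Starting at 1 << bit_length would flip a bit outside the index range.)
    for(int l = 1 << (bit_length - 1); (x ^= l) < l; l >>= 1);
    return x;
}
// Permutes x[0..n) into bit-reversed order in place; n must be a power of two (n >= 2).
void bit_reverse(int n, complex_t *x)
{
    int bit_length = (int) std::log2(n);
    for(int i = 0, j = 0; i != n; ++i)
    {
        // j == rev(i); swap each pair only once
        if(i > j) std::swap(x[i], x[j]);
        j = reverse_add(j, bit_length);
    }
}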
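To double-check the incremental update, here is a small sketch that compares it against a naive per-index bit reversal. The helper names (naive_reverse, reverse_add_sketch, reversed_indices) are hypothetical, and reverse_add_sketch assumes the loop starts at the top bit 1 << (bit_length - 1) of a bit_length-bit index.

```cpp
#include <cassert>
#include <vector>

// Naive rev(i): reverse the lowest `bit_length` bits of i, one bit at a time.
// O(log n) per index; only meant as a reference for the incremental version.
int naive_reverse(int i, int bit_length)
{
    int r = 0;
    for (int b = 0; b < bit_length; ++b)
    {
        r = (r << 1) | (i & 1);  // append the current lowest bit of i to r
        i >>= 1;
    }
    return r;
}

// Incremental update: given x == rev(i), returns rev(i + 1).
int reverse_add_sketch(int x, int bit_length)
{
    for (int l = 1 << (bit_length - 1); (x ^= l) < l; l >>= 1);
    return x;
}

// Returns rev(0), rev(1), ..., rev(2^bit_length - 1), computed incrementally.
std::vector<int> reversed_indices(int bit_length)
{
    assert(bit_length >= 1);
    int n = 1 << bit_length;
    std::vector<int> rev(n);
    for (int i = 0, j = 0; i < n; ++i)
    {
        rev[i] = j;  // invariant: j == rev(i)
        j = reverse_add_sketch(j, bit_length);
    }
    return rev;
}
```

For bit_length = 3, reversed_indices returns 0, 4, 2, 6, 1, 5, 3, 7, which is the leaf order of the divide and conquer for $n = 8$.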

Let us try this example:

where .

And a polynomial (I use the vector notation introduced in the MIT OpenCourseWare lectures):

Let’s see the computational graph:

[Figure: computational graph of the iterative FFT]

Code:

// Iterative FFT. w[i] must hold w_n^i for 0 <= i < n (n is a power of two).
void transform(int n, complex_t *x, complex_t *w)
{
    // Prepare the leaf nodes ( you can see this as a bottom-up approach )
    bit_reverse(n, x);
    for(int stride = 2; stride <= n; stride <<= 1)
    {
        int half_stride = stride >> 1;
        // the start index of a stride region
        for(int start = 0; start < n; start += stride)
        {
            // this loop finishes the [start, start + stride) region
            for(int k = 0; k < half_stride; ++k)
            {
                // w_stride^k == w_n^{(n / stride) * k}
                complex_t z = x[start + half_stride + k] * w[n / stride * k];
                // Use the phase property (see below)
                x[start + half_stride + k] = x[start + k] - z;
                x[start + k] = x[start + k] + z;
            }
        }
    }
}

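Note, in the code, w[n / stride * k] $= w_n^{(n/\text{stride}) \cdot k} = w_{\text{stride}}^{k}$,

and the phase property: $w_{\text{stride}}^{k + \text{stride}/2} = -w_{\text{stride}}^{k}$, which is why both outputs of a butterfly reuse the same product z.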
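Putting the pieces together, the sketch below multiplies two integer polynomials with the iterative transform. To stay self-contained it re-declares complex_t and uses a common compact form of the bit-reversal loop instead of reverse_add. The names fft_transform, bit_reverse_permute and multiply are hypothetical, $w_n = e^{2\pi i/n}$ is an assumed convention, and the inverse is done by passing the conjugate roots $w_n^{-i}$ and dividing by $n$.

```cpp
#include <cassert>
#include <cmath>
#include <complex>
#include <utility>
#include <vector>

typedef std::complex<double> complex_t;

// In-place bit-reversal permutation: j is incremented in reversed bit order.
void bit_reverse_permute(std::vector<complex_t> &x)
{
    int n = (int) x.size();
    for (int i = 1, j = 0; i < n; ++i)
    {
        int bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;  // clear leading 1s (the carry)
        j ^= bit;                             // set the first 0 bit
        if (i < j) std::swap(x[i], x[j]);
    }
}

// Iterative FFT; w[i] must hold the i-th power of a primitive n-th root of unity.
void fft_transform(std::vector<complex_t> &x, const std::vector<complex_t> &w)
{
    int n = (int) x.size();
    bit_reverse_permute(x);
    for (int stride = 2; stride <= n; stride <<= 1)
    {
        int half = stride >> 1;
        for (int start = 0; start < n; start += stride)
            for (int k = 0; k < half; ++k)
            {
                complex_t z = x[start + half + k] * w[n / stride * k];
                x[start + half + k] = x[start + k] - z;  // phase property
                x[start + k] += z;
            }
    }
}

// Product of two non-empty integer polynomials (coefficients by increasing degree).
std::vector<long long> multiply(const std::vector<long long> &a, const std::vector<long long> &b)
{
    assert(!a.empty() && !b.empty());
    int need = (int) (a.size() + b.size()) - 1;
    int n = 1;
    while (n < need) n <<= 1;
    const double PI = std::acos(-1.0);
    std::vector<complex_t> w(n), w_inv(n);
    for (int i = 0; i < n; ++i)
    {
        w[i] = std::polar(1.0, 2 * PI * i / n);  // w_n^i
        w_inv[i] = std::conj(w[i]);              // w_n^{-i}
    }
    std::vector<complex_t> fa(a.begin(), a.end()), fb(b.begin(), b.end());
    fa.resize(n);
    fb.resize(n);
    fft_transform(fa, w);
    fft_transform(fb, w);
    for (int i = 0; i < n; ++i) fa[i] *= fb[i];  // pointwise product of values
    fft_transform(fa, w_inv);                    // inverse DFT without the 1/n
    std::vector<long long> c(need);
    for (int i = 0; i < need; ++i) c[i] = std::llround(fa[i].real() / n);
    return c;
}
```

For example, multiply({1, 2}, {3, 4}) should give {3, 10, 8}, i.e. $(1 + 2x)(3 + 4x) = 3 + 10x + 8x^2$. Rounding with std::llround is only safe while the coefficients stay well within double precision.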

Recursive FFT

Here is the pseudocode:

fft(<a0, a1, a2, a3, ..., an-1>){
    if(n == 1) return <a0>          // A(x) = a0 is a constant
    result = [None] * n
    even = fft(<a0, a2, ..., an-2>) // even[i] == A_even(w^{i}_{n/2})
    assert(len(even) == n/2)
    odd = fft(<a1, a3, ..., an-1>)  // odd[i] == A_odd(w^{i}_{n/2})
    assert(len(odd) == n/2)

    for(int i = 0; i < n; i++){
        // A(w^i_n) = A_even(w^{2i}_n) + w^i_n * A_odd(w^{2i}_n)
        // and w^{2i}_n == w^{i mod (n/2)}_{n/2}. See the fact below.
        result[i] = even[i % (n/2)] + w^i_n * odd[i % (n/2)]
    }

    return result
}
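For comparison, here is a direct C++ translation of the pseudocode, using std::vector copies instead of an in-place buffer. It is only a sketch: recursive_fft is a hypothetical name, and $w_n = e^{2\pi i/n}$ is an assumed convention.

```cpp
#include <cassert>
#include <cmath>
#include <complex>
#include <vector>

typedef std::complex<double> complex_t;

// Returns A(w_n^0), ..., A(w_n^{n-1}) for the coefficient vector a,
// where w_n = e^{2 pi i / n} and n = a.size() is a power of two.
std::vector<complex_t> recursive_fft(const std::vector<complex_t> &a)
{
    int n = (int) a.size();
    assert(n >= 1 && (n & (n - 1)) == 0);
    if (n == 1) return a;  // A(x) = a_0 is a constant
    std::vector<complex_t> a_even(n / 2), a_odd(n / 2);
    for (int i = 0; i < n / 2; ++i)
    {
        a_even[i] = a[2 * i];
        a_odd[i] = a[2 * i + 1];
    }
    std::vector<complex_t> even = recursive_fft(a_even);  // A_even(w_{n/2}^i)
    std::vector<complex_t> odd = recursive_fft(a_odd);    // A_odd(w_{n/2}^i)
    const double PI = std::acos(-1.0);
    std::vector<complex_t> result(n);
    for (int i = 0; i < n; ++i)
    {
        complex_t w_i = std::polar(1.0, 2 * PI * i / n);  // w_n^i
        // A(w_n^i) = A_even(w_{n/2}^{i mod n/2}) + w_n^i * A_odd(w_{n/2}^{i mod n/2})
        result[i] = even[i % (n / 2)] + w_i * odd[i % (n / 2)];
    }
    return result;
}
```

For a = (1, 2, 3, 4) it returns (10, -2 - 2i, -2, -2 + 2i), i.e. $A(1), A(i), A(-1), A(-i)$.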

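This is the code written by miskoo (see References). The original snippet uses a scratch array temp that is not declared in it; a local vector is used for it here:

#include <complex>
#include <vector>

// The current subproblem is buffer[offset + i * step] for 0 <= i < n;
// epsilon[k] must hold w_N^k for the top-level size N (so epsilon[k * step] == w_n^k).
void fft(int n, std::complex<double>* buffer, int offset, int step, std::complex<double>* epsilon)
{
    if(n == 1) return;
    int m = n >> 1;
    fft(m, buffer, offset, step << 1, epsilon);         // even-indexed elements
    fft(m, buffer, offset + step, step << 1, epsilon);  // odd-indexed elements
    std::vector<std::complex<double>> temp(n);          // scratch buffer
    for(int k = 0; k != m; ++k)
    {
        int pos = 2 * step * k;
        temp[k] = buffer[pos + offset] + epsilon[k * step] * buffer[pos + offset + step];
        temp[k + m] = buffer[pos + offset] - epsilon[k * step] * buffer[pos + offset + step];
    }

    for(int i = 0; i != n; ++i)
        buffer[i * step + offset] = temp[i];
}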
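One way to see what the (offset, step) pair does is to trace only the indices. The hypothetical helper below mimics the two recursive calls of fft and records which buffer slot each size-1 call owns; no arithmetic is performed.

```cpp
#include <cassert>
#include <vector>

// Appends, in visiting order, the buffer slot owned by each leaf (n == 1) call
// of the (offset, step) recursion.
void trace_leaves(int n, int offset, int step, std::vector<int> &order)
{
    assert(n >= 1);
    if (n == 1)
    {
        order.push_back(offset);  // a size-1 subproblem is the single slot buffer[offset]
        return;
    }
    trace_leaves(n >> 1, offset, step << 1, order);         // even-indexed half
    trace_leaves(n >> 1, offset + step, step << 1, order);  // odd-indexed half
}
```

Calling trace_leaves(8, 0, 1, order) records 0, 4, 2, 6, 1, 5, 3, 7: the leaves are reached in bit-reversed order, which is exactly the order bit_reverse prepares for the iterative version.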

This code can be hard to understand at first glance, so let us visualize it:

[Figure: visualization of the recursive FFT]

Note that for simplicity, I denote:

Number Theoretic Transform (NTT)

Procedure of NTT

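  • Let $n$ (must be a power of 2) be the transform size of an input vector.
  • Choose a prime number $p$ of the form $p = c \cdot n + 1$, where $c$ is a positive integer. By Dirichlet's theorem on arithmetic progressions, you are guaranteed to find a $c$ such that $p$ is prime.
  • Let $g$ be a primitive root modulo $p$, i.e., a generator of the multiplicative group $(\mathbb{Z}/p\mathbb{Z})^{\times}$. (Such a $g$ always exists because the multiplicative group of a finite field is cyclic.) Note that $g^0, g^1, \dots, g^{p-2}$ under modulo $p$ are unique.
  • Define $w_n = g^{c} = g^{(p-1)/n} \bmod p$. By Euler's theorem, $w_n^n = g^{p-1} \equiv 1 \pmod p$. Consider $w_n^k$ for $0 \le k < n$: because $ck < p - 1$, by the property of primitive roots modulo $p$, we can make sure that $w_n^0, w_n^1, \dots, w_n^{n-1}$ under modulo $p$ are unique. (You can prove that $w_n^{n/2} \equiv -1 \pmod p$, because $(w_n^{n/2})^2 \equiv 1$ while $w_n^{n/2} \not\equiv 1$, as $w_n^0, \dots, w_n^{n-1}$ are unique. Therefore $w_n^{k + n/2} \equiv -w_n^k$, and every time you square the set $\{w_n^k\}$, its size is divided by 2.)

Note: If $x^2 \equiv 1 \pmod p$, then $x \equiv \pm 1 \pmod p$, because $p$ is prime and $p \mid (x - 1)(x + 1)$.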
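The procedure above can be checked numerically. This sketch computes $w_n = g^{(p-1)/n} \bmod p$ by fast exponentiation; pow_mod and root_of_unity are hypothetical names, and it assumes $p < 2^{32}$ so that every product fits in a 64-bit unsigned integer. The example values $p = 998244353$, $g = 3$ come from the table at the end of this post.

```cpp
#include <cassert>
#include <cstdint>

typedef std::uint64_t u64;

// base^exp mod p by repeated squaring; assumes p < 2^32.
u64 pow_mod(u64 base, u64 exp, u64 p)
{
    u64 result = 1;
    base %= p;
    while (exp > 0)
    {
        if (exp & 1) result = result * base % p;
        base = base * base % p;
        exp >>= 1;
    }
    return result;
}

// w_n = g^{(p - 1) / n} mod p, for a primitive root g modulo p and n dividing p - 1.
u64 root_of_unity(u64 n, u64 p, u64 g)
{
    assert((p - 1) % n == 0);
    return pow_mod(g, (p - 1) / n, p);
}
```

For $p = 17$, $g = 3$ and $n = 4$ it gives $w_4 = 13$, and indeed $13^2 = 169 \equiv 16 \equiv -1 \pmod{17}$.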

How to compute the inverse of a finite field element

The last missing part is: how can we efficiently compute $w_n^{-1}$ (that is, the inverse of $w_n$ modulo $p$)? In fact, we can use the extended Euclidean algorithm. Remember that you can compute $\gcd(a, b)$ by

$r_{i+1} = r_{i-1} \bmod r_i$

with initial values $r_0 = a$ and $r_1 = b$ (from Wikipedia).

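Let us first see why we can always find the inverse of a number. Assume $\gcd(a, b) = 1$; by Bézout's identity, we are guaranteed to find integers $x$, $y$ such that:

$ax + by = 1$

Therefore $ax \equiv 1 \pmod b$, so if we can calculate $x$, it will be the multiplicative inverse of $a$ modulo $b$. In our case, $a$ is $w_n$ and $b$ is $p$ (which are coprime: $\gcd(w_n, p) = 1$, since $p$ is prime and $0 < w_n < p$).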

Extended Euclidean algorithm

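(Mainly copied from Wikipedia, with some comments of my own)

As the name suggests, this is an extension of the Euclidean algorithm you learned in high school:

$q_i = \lfloor r_{i-1} / r_i \rfloor, \quad r_{i+1} = r_{i-1} - q_i r_i, \quad s_{i+1} = s_{i-1} - q_i s_i, \quad t_{i+1} = t_{i-1} - q_i t_i$

with $r_0 = a$, $r_1 = b$, $s_0 = 1$, $s_1 = 0$, $t_0 = 0$, $t_1 = 1$. Terminate at step $k$ when $r_{k+1} = 0$; then $x = s_k$, $y = t_k$ will satisfy $ax + by = \gcd(a, b) = r_k$.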
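An iterative sketch of these recurrences ($r_{i+1} = r_{i-1} - q_i r_i$, with the same update for $s$ and $t$), keeping only the last two values of each sequence. The names ext_gcd and mod_inverse are hypothetical, and signed 64-bit integers are assumed.

```cpp
#include <cassert>

// Returns gcd(a, b) and sets x, y so that a * x + b * y == gcd(a, b).
long long ext_gcd(long long a, long long b, long long &x, long long &y)
{
    long long r0 = a, r1 = b;  // r_{i-1}, r_i
    long long s0 = 1, s1 = 0;  // s_{i-1}, s_i
    long long t0 = 0, t1 = 1;  // t_{i-1}, t_i
    while (r1 != 0)            // stop once r_{k+1} == 0
    {
        long long q = r0 / r1;
        long long r2 = r0 - q * r1, s2 = s0 - q * s1, t2 = t0 - q * t1;
        r0 = r1; r1 = r2;
        s0 = s1; s1 = s2;
        t0 = t1; t1 = t2;
    }
    x = s0;  // s_k
    y = t0;  // t_k
    return r0;  // r_k == gcd(a, b)
}

// Multiplicative inverse of a modulo m, assuming gcd(a, m) == 1; result in [0, m).
long long mod_inverse(long long a, long long m)
{
    long long x, y;
    long long g = ext_gcd(a, m, x, y);
    assert(g == 1);
    return ((x % m) + m) % m;  // x may be negative
}
```

With these, ext_gcd(240, 46, x, y) returns 2 with x = -9 and y = 47, and mod_inverse(3, 17) returns 6 because $3 \cdot 6 = 18 \equiv 1 \pmod{17}$.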

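Proof: We claim that for each step $i$,

$a s_i + b t_i = r_i$

Base case:

$s_0 = 1, \quad t_0 = 0, \quad s_1 = 0, \quad t_1 = 1$

Therefore,

$a s_0 + b t_0 = a = r_0, \quad a s_1 + b t_1 = b = r_1$

Assume our claim is correct for $i - 1$ and $i$.

Induction step: by the $i$-th step of the Euclidean algorithm:

$r_{i+1} = r_{i-1} - q_i r_i = (a s_{i-1} + b t_{i-1}) - q_i (a s_i + b t_i) = a (s_{i-1} - q_i s_i) + b (t_{i-1} - q_i t_i) = a s_{i+1} + b t_{i+1}$

Thus, at the $k$-th step (where $r_{k+1} = 0$), we will get

$a s_k + b t_k = r_k = \gcd(a, b)$

by the original definition of the Euclidean algorithm and our claim above.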
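As a worked example with $a = 240$ and $b = 46$, every row below satisfies the invariant $a s_i + b t_i = r_i$, where $q_i = \lfloor r_{i-1} / r_i \rfloor$:

```latex
\begin{array}{c|r|r|r|r|l}
i & q_i & r_i & s_i & t_i & 240 s_i + 46 t_i \\ \hline
0 &     & 240 &  1 &   0 & 240 \\
1 & 5   &  46 &  0 &   1 & 46 \\
2 & 4   &  10 &  1 &  -5 & 240 - 230 = 10 \\
3 & 1   &   6 & -4 &  21 & -960 + 966 = 6 \\
4 & 1   &   4 &  5 & -26 & 1200 - 1196 = 4 \\
5 & 2   &   2 & -9 &  47 & -2160 + 2162 = 2 \\
6 &     &   0 &    &     & \text{stop}
\end{array}
```

The algorithm stops at $k = 5$ because $r_6 = 0$, giving $\gcd(240, 46) = 2 = 240 \cdot (-9) + 46 \cdot 47$.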

Alternative way to compute the multiplicative inverse

In our case, our goal is to compute $w_n^{-1} \bmod p$.

By Euler's theorem, $w_n^n \equiv 1 \pmod p$, so $w_n^{-1} \equiv w_n^{n-1} \pmod p$.
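Both identities translate directly into fast exponentiation. A sketch, assuming $p < 2^{32}$; pow_mod, inverse_fermat and inverse_root are hypothetical names. inverse_fermat uses Fermat's little theorem ($a^{p-2} \equiv a^{-1}$ for prime $p$), while inverse_root uses $w_n^n \equiv 1$.

```cpp
#include <cassert>
#include <cstdint>

typedef std::uint64_t u64;

// base^exp mod p by repeated squaring; assumes p < 2^32.
u64 pow_mod(u64 base, u64 exp, u64 p)
{
    u64 result = 1;
    base %= p;
    for (; exp > 0; exp >>= 1, base = base * base % p)
        if (exp & 1) result = result * base % p;
    return result;
}

// a^{-1} mod prime p, since a^{p-1} == 1 implies a * a^{p-2} == 1.
u64 inverse_fermat(u64 a, u64 p)
{
    assert(a % p != 0);
    return pow_mod(a, p - 2, p);
}

// w^{-1} for a root w with w^n == 1 (mod p): w^{-1} == w^{n-1}.
u64 inverse_root(u64 w, u64 n, u64 p)
{
    return pow_mod(w, n - 1, p);
}
```

For $w = 3^{(p-1)/8} \bmod 998244353$, both functions return the same value, and inverse_root only needs $\log_2 n$ squarings.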

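Nevertheless, in practice, we can just swap the results, because if the input polynomial is $A = \langle a_0, a_1, \dots, a_{n-1} \rangle$ and:

$y_k = A(w_n^k) = \sum_{j=0}^{n-1} a_j w_n^{jk}$

Our goal is to get (by the definition of the inverse FFT):

$a_j = \frac{1}{n} \sum_{k=0}^{n-1} y_k w_n^{-jk}$

Recall that $w_n^n \equiv 1$, so:

$w_n^{-jk} = w_n^{(n-j)k}$

Observe that, if $Y$ denotes the forward transform of $y$ (computed with $w_n$):

$\sum_{k=0}^{n-1} y_k w_n^{(n-j)k} = Y_{(n-j) \bmod n}$

Therefore,

$a_j = \frac{1}{n} Y_{(n-j) \bmod n}$, which can be implemented efficiently by std::swap (swap $Y_j$ with $Y_{n-j}$ for $1 \le j < n/2$, then multiply by $n^{-1}$), with no need to explicitly compute $w_n^{-1}$!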
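Here is a sketch of the whole NTT pipeline with the swap trick, multiplying polynomials modulo $p = 998244353 = 119 \cdot 2^{23} + 1$ with $g = 3$ (values taken from the table below). The names ntt, inverse_ntt and multiply_mod are hypothetical, inputs are assumed to already lie in $[0, p)$, and the factor $n^{-1}$ is still computed, here via Fermat's little theorem.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

typedef std::uint64_t u64;
const u64 P = 998244353, G = 3;  // P = 119 * 2^23 + 1, primitive root 3

u64 pow_mod(u64 b, u64 e)
{
    u64 r = 1;
    for (b %= P; e > 0; e >>= 1, b = b * b % P)
        if (e & 1) r = r * b % P;
    return r;
}

// Forward NTT: the same butterflies as the complex FFT, with w_n = G^{(P-1)/n}.
void ntt(std::vector<u64> &x)
{
    int n = (int) x.size();
    for (int i = 1, j = 0; i < n; ++i)  // bit-reversal permutation
    {
        int bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) std::swap(x[i], x[j]);
    }
    for (int stride = 2; stride <= n; stride <<= 1)
    {
        u64 w_s = pow_mod(G, (P - 1) / stride);  // primitive stride-th root of unity
        int half = stride >> 1;
        for (int start = 0; start < n; start += stride)
        {
            u64 w = 1;
            for (int k = 0; k < half; ++k, w = w * w_s % P)
            {
                u64 z = x[start + half + k] * w % P;
                x[start + half + k] = (x[start + k] + P - z) % P;  // w^{k+half} == -w^k
                x[start + k] = (x[start + k] + z) % P;
            }
        }
    }
}

// Inverse NTT without w_n^{-1}: forward transform, swap Y_j <-> Y_{n-j}, times n^{-1}.
void inverse_ntt(std::vector<u64> &y)
{
    int n = (int) y.size();
    ntt(y);
    for (int j = 1; j < n - j; ++j) std::swap(y[j], y[n - j]);
    u64 n_inv = pow_mod(n, P - 2);
    for (u64 &v : y) v = v * n_inv % P;
}

// Product of two non-empty polynomials with coefficients in [0, P), modulo P.
std::vector<u64> multiply_mod(std::vector<u64> a, std::vector<u64> b)
{
    assert(!a.empty() && !b.empty());
    int need = (int) (a.size() + b.size()) - 1, n = 1;
    while (n < need) n <<= 1;
    assert(n <= (1 << 23));  // n must divide P - 1
    a.resize(n);
    b.resize(n);
    ntt(a);
    ntt(b);
    for (int i = 0; i < n; ++i) a[i] = a[i] * b[i] % P;
    inverse_ntt(a);
    a.resize(need);
    return a;
}
```

For example, multiply_mod({1, 2}, {3, 4}) gives {3, 10, 8}. With 998244352 standing for -1, multiply_mod({1, 1, 1}, {1, 998244352}) gives {1, 0, 0, 998244352}, i.e. $1 - x^3$.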

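How to handle negative numbers

(Thanks to yao11617's contribution!) In NTT, we actually perform addition, subtraction and multiplication over $\mathbb{Z}_p$ (a finite field, where $p$ is a prime number). We can leverage the idea of two's complement. That is, we can interpret numbers larger than $(p-1)/2$ as negative numbers.

Precisely, a field element $x \in \{0, 1, \dots, p-1\}$ represents the integer

$\begin{cases} x & \text{if } x \le (p-1)/2 \\ x - p & \text{if } x > (p-1)/2 \end{cases}$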

Let us use several examples to make sure our idea is feasible:

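Practically, you can map a negative number $x$ to $x + p$ and then restore it after performing NTT by subtracting $p$ (if the output coefficient is larger than $(p-1)/2$)! This works as long as every true coefficient of the result lies in $[-(p-1)/2, (p-1)/2]$.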
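A sketch of the mapping, with $p = 998244353$ as an assumed modulus; to_field and from_field are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>

const std::int64_t P = 998244353;  // assumed NTT prime from the table

// Map a signed integer into [0, P): a negative x ends up as x + P (after reduction).
std::int64_t to_field(std::int64_t x)
{
    x %= P;  // in C++, x % P has the sign of x
    return x < 0 ? x + P : x;
}

// Map a field element back to a signed value: anything above (P - 1) / 2 is
// read as negative. Only correct if the true value lies in [-(P-1)/2, (P-1)/2].
std::int64_t from_field(std::int64_t y)
{
    assert(0 <= y && y < P);
    return y > (P - 1) / 2 ? y - P : y;
}
```

For example, from_field(to_field(-3) * to_field(5) % P) gives -15: the product is computed entirely in $[0, p)$ and only reinterpreted at the end.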

Isn’t it interesting?

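Practical Suggestions

  • You can find a useful table of primes and primitive roots in FFT用到的各種素數 ("Various primes used in FFT"). Each prime has the form $p = r \cdot 2^k + 1$, so it supports NTT sizes up to $2^k$, and $g$ is a primitive root modulo $p$. In case that website goes down, I have backed up the table here:
p (prime number)   r   k   g (primitive root)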
3 1 1 2
5 1 2 2
17 1 4 3
97 3 5 5
193 3 6 5
257 1 8 3
7681 15 9 17
12289 3 12 11
40961 5 13 3
65537 1 16 3
786433 3 18 10
5767169 11 19 3
7340033 7 20 3
23068673 11 21 3
104857601 25 22 3
167772161 5 25 3
469762049 7 26 3
998244353 119 23 3
1004535809 479 21 3
2013265921 15 27 31
2281701377 17 27 3
3221225473 3 30 5
75161927681 35 31 3
77309411329 9 33 7
206158430209 3 36 22
2061584302081 15 37 7
2748779069441 5 39 3
6597069766657 3 41 5
39582418599937 9 42 5
79164837199873 9 43 5
263882790666241 15 44 7
1231453023109121 35 45 3
1337006139375617 19 46 3
3799912185593857 27 47 5
4222124650659841 15 48 19
7881299347898369 7 50 6
31525197391593473 7 52 3
180143985094819841 5 55 6
1945555039024054273 27 56 5
4179340454199820289 29 57 3
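If you want to double-check a row of the table, the standard test is that $g$ is a primitive root modulo $p$ iff $g^{(p-1)/q} \not\equiv 1 \pmod p$ for every prime factor $q$ of $p - 1$. The sketch below assumes $p < 2^{32}$ (so it only covers rows up to 3221225473), and is_primitive_root is a hypothetical name.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

typedef std::uint64_t u64;

// base^exp mod p by repeated squaring; assumes p < 2^32.
u64 pow_mod(u64 base, u64 exp, u64 p)
{
    u64 result = 1;
    for (base %= p; exp > 0; exp >>= 1, base = base * base % p)
        if (exp & 1) result = result * base % p;
    return result;
}

// True iff g generates the multiplicative group modulo the prime p.
bool is_primitive_root(u64 g, u64 p)
{
    assert(p >= 3 && p < (1ULL << 32));
    u64 m = p - 1;
    std::vector<u64> primes;  // distinct prime factors of p - 1
    for (u64 q = 2; q * q <= m; ++q)  // cheap here: p - 1 = r * 2^k with small r
    {
        if (m % q == 0)
        {
            primes.push_back(q);
            while (m % q == 0) m /= q;
        }
    }
    if (m > 1) primes.push_back(m);
    for (u64 q : primes)
        if (pow_mod(g, (p - 1) / q, p) == 1) return false;  // order of g is too small
    return true;
}
```

It accepts 3 for 998244353 and for 65537, and rejects 2 for 65537, since $2^{16} \equiv -1$ and hence $2^{32} \equiv 1 \pmod{65537}$.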

My C++ code

References