数学-FFT

问题类型

有时候我们需要将两个多项式乘起来,如果我们采用类似高精度乘法的算法,那么时间复杂度是O(n ^ 2)的。
那么我们有没有更加优秀的时间复杂度呢?
答案是有的。我们可以用FFT来完成多项式乘法的任务,时间复杂度是O(n * logn)的。

解决方案

前置技能

多项式表达法

多项式有常见的2种表达方式:

  • 系数表达法
    所谓系数表达法,就是用 F(x) = a0 x^0 + a1 x^1 + ··· + an x^n
    这时候如果我们要实现多项式乘法 F[i + j] = sigma(G[i] * H[j]) (用F[i],G[i],H[i]表示多项式的i次方系数,G和H相乘得到F)
    所以这是O(n ^ 2)的时间复杂度。

  • 点值表达法
    点值表示法,就是用n+1个点对 (xi, yi)来表示一个多项式(可以由高斯消元唯一确定一个n次多项式)
    此时如果我们相乘很显然只要O(n)的时间复杂度即可了

复数

有时我们会用到对-1开根号的结果,我们把这个结果表示成i。
我们用a+bi的形式来表示一个复数,那么这个复数可以在坐标轴上表示为(a, b),其中a是复数的实部,b为复数的虚部。
(a + bi) * (c + di) = (ac - bd) + (ad + bc)i,我们可以发现两个虚数相乘,就是辐角相加,模长相乘。
image1
我们称wn为n次单位根,它的n次方即为1。用w[n][k]表示wn的k次方。显然我们可以在一个单位圆上找到w[n][1],···,w[n][k]。
单位根有一些很显然的性质:w[n][k] = w[2n][2k], w[2n][k] = -w[2n][k + n]
image2

FFT

而FFT的核心内容就是DFT(将系数表达法转化为点值表达法)和IDFT(将点值表达法转化回系数表达法)

DFT

在DFT的过程中,我们已知多项式的系数,要根据复数的单位根求一些点值。
我们发现 F(w[n][k]) = a0 * w[n][k] ^ 0 + a1 * w[n][k] ^ 1 + ··· + an-1 * w[n][k] ^ n-1
F(w[n][k]) = L(w[n][k]) + w[n][k] * R(w[n][k]) (L表示以偶次幂的系数组成的多项式,R表示以奇次幂的系数组成的多项式,而R这个多项式可以提取公因子w[n][k])
而L(w[n][k])可以看成是一个变量是w[n][k] ^ 2的多项式G(w[n][k] ^ 2),同理R(w[n][k])也可以。
所以我们得到 F(w[n][k]) = G(w[n / 2][k]) + w[n][k] * H(w[n / 2][k])
而G和H这两个多项式都是n / 2次的,所以我们可以通过递归来处理。

有时候我们觉得递归版的常数太大了,我们就要用到非递归版的。
我们每次在往下递归的时候,把系数分成了偶次幂和奇次幂两块。
所以在第i层往下递归时,放在左边的就是二进制上第i位为0的,放在右边的是二进制上第i位为1的。
所以每个位置递归到最后的值,就是原来坐标把二进制翻转以后得到的坐标上的值。
image3
所以我们可以预处理出Rev数组,表示一个数二进制翻转以后得到的数。
然后我们只要枚举层数,枚举起始位置,枚举要合并的位置即可。
总时间复杂度还是O(n * logn)的。

IDFT

IDFT的过程,就是把点值重新转化为系数表达式。
有人证出来插值只要将所有w[n][k]换成w[n][k + n / 2],也就是所有的虚部取相反数,再将最终结果除以len就行了。
所以这部分的过程和DFT的过程相类似。

代码演示

递归版

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
void FFT(Complex *T, int n, int op)
{
if (n == 1) return;
Complex L[n >> 1], R[n >> 1];
Rep(i, 0, n >> 1) L[i] = T[i << 1], R[i] = T[i << 1 | 1];
FFT(L, n >> 1, op); FFT(R, n >> 1, op);
Complex wn = Complex(cos(2.0 * pi / n), sin(2.0 * pi / n) * op); Complex w = Complex(1, 0);
Rep(i, 0, n >> 1)
{
T[i] = L[i] + R[i] * w, T[i + (n >> 1)] = L[i] - R[i] * w;
w = w * wn;
}
}
int main()
{
int n = read(), m = read();
rep(i, 0, n) X[i] = read();
rep(i, 0, m) Y[i] = read();
int t; for (t = 1; t <= n + m; t <<= 1);
FFT(X, t, 1); FFT(Y, t, 1);
Rep(i, 0, t) X[i] = X[i] * Y[i];
FFT(X, t, -1);
rep(i, 0, n + m) printf("%d ", int(X[i].real / t + 0.5));
return 0;
}

非递归版

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
void FFT(Complex *T, int n, int op)
{
Rep(i, 0, n) if (i < Rev[i]) swap(T[i], T[Rev[i]]);
for (int Size = 1; Size < n; Size <<= 1)
{
Complex wn = Complex(cos(pi / Size), sin(pi / Size) * op);
for (int L = 0; L < n; L += (Size << 1))
{
Complex w = Complex(1, 0);
Rep(R, L, L + Size)
{
Complex x = T[R], y = w * T[R + Size];
T[R] = x + y; T[R + Size] = x - y;
w = w * wn;
}
}
}
}
int main()
{
int n = read(), m = read();
rep(i, 0, n) X[i] = read();
rep(i, 0, m) Y[i] = read();
int t, l = 0; for (t = 1; t <= n + m; t <<= 1) l++;
Rep(i, 0, t) Rev[i] = (Rev[i >> 1] >> 1) + ((i & 1) << l - 1);
FFT(X, t, 1); FFT(Y, t, 1);
Rep(i, 0, t) X[i] = X[i] * Y[i];
FFT(X, t, -1);
rep(i, 0, n + m) printf("%d ", int(X[i].real / t + 0.5));
return 0;
}
文章目录
  1. 1. 问题类型
  2. 2. 解决方案
    1. 2.1. 前置技能
      1. 2.1.1. 多项式表达法
      2. 2.1.2. 复数
    2. 2.2. FFT
      1. 2.2.1. DFT
      2. 2.2.2. IDFT
  3. 3. 代码演示
    1. 3.1. 递归版
    2. 3.2. 非递归版