Introduction
快速傅里叶变换(Fast Fourier Transform,FFT)
是一种可在 (O(n log n)) 时间内完成的离散傅里叶变换 (Discrete Fourier Transform,DFT)
的算法,用来实现将信号从原始域(通常是时间或空间)到频域的互相转化。
FFT 在算法竞赛中主要用来加速多项式乘法(循环卷积)。
多项式
形如
的式子称为 (x) 的 (n - 1) 次多项式,其中 (a_0, a_1, dots, a_{n - 1}) 称为多项式系数,(n-1) 称为多项式的次数,记为 (deg A(x)) 或 (deg A)。
点值
(n - 1) 次多项式 (A(x)) 在 (x = m) 处的点值
多项式乘法
记 (A(x) imes B(x)) 表示多项式 (A(x), B(x)) 做多项式乘法,可以简写为 (A(x)cdot B(x)) 或 (A(x)B(x))。
多项式乘法
用系数关系可以表示为
其中 (deg C = deg A + deg B)。
易证它们的点值满足如下关系
循环卷积
记 (operatorname{conv}(A, B, n)) 表示多项式 (A(x), B(x)) 做长度为 (n) 的循环卷积。
循环卷积
系数关系表示为
其中 (deg C = n - 1)。
容易发现,当 (n > deg A + deg B) 时,该运算等价于多项式乘法。
DFT
离散傅里叶变换(Discrete Fourier Transform, DFT)
将多项式 (A(x)=sum_{k=0}^{n-1}a_kx^k) 转换为一些特殊的点值。
记 (n) 次单位复根
(DFT(A)) 就是要计算点值 (A(omega_n^k), k = 0, 1, 2, dots, n-1)。
单位根自带的循环特性使得循环卷积 (C(x) = operatorname{conv}(A, B, n)) 的点值也满足:
IDFT
IDFT 是 DFT 的逆变换。
首先,用等比数列求和易证:
考虑循环卷积 (C(x) = operatorname{conv}(A, B, n)) 的系数表示
设多项式
只要计算 (DFT(C')) 即可得到 (C(x)) 的系数,于是我们用 DFT 完成了逆变换 IDFT。
用两次 DFT 和一次 IDFT就可以计算 (operatorname{conv}(A, B, n))。
暴力的复杂度是 (O(n^2)),此处不赘述。
FFT
现在尝试将 DFT 问题分解以优化时间复杂度。
本部分认为 (n = deg A + 1) 为 (2) 的整数次幂。对于更一般的情况,暂不考虑。
DIF
将序列 (a_i) 分成左右两半。
进一步,将 (A(omega_{n}^r)) 按奇偶分类:
设
我们只需要求出 (P(omega_{n/2}^r)) 和 (Q(omega_{n/2}^r)) ,即求解规模为原来一半的两个子问题 (DFT(P), DFT(Q)),就能在 (O(n)) 时间内计算出 (DFT(A))。
DIT
在算法竞赛中这种方法更常见。
注意到在 DIF
中我们最后将 (A(omega_n^r)) 奇偶分类求解,那不妨思考将序列 (a_k) 按奇偶分类。
设
则
所以
将 (A(omega_n^k)) 再细分为左右两半,这里运用了等式 (omega_{n/2}^k = omega_{n/2}^{k + n/2}) 和 (omega_n^k+omega_n{k+n/2} = 0) :
我们只需要求出 (A_0(omega_{n/2}^k)) 和 (A_1(omega_{n/2}^k)) ,即求解规模为原来一半的两个子问题 (DFT(A_0), DFT(A_1)),就能在 (O(n)) 时间内计算出 (DFT(A))。
Complexity
设次数为 (n - 1) 的多项式做 DFT 的时间复杂度为 (T(n)),则
根据主定理
Implementation
上述两种计算方式均可以使用递归实现,这里直接给出代码,不再赘述。
DIF
const double PI = acos(-1.0);
void dft(std::vector<Complex> &a) {
int n = a.size(), m = n >> 1;
if (n == 1) return;
std::vector<Complex> p(m), q(m);
for (int i = 0; i < m; i++) {
p[i] = a[i] + a[i + m];
q[i] = (a[i] - a[i + m]) * Complex(cos(2 * PI * i / n), sin(2 * PI * i / n));
}
dft(p), dft(q);
for (int i = 0; i < m; i++)
a[i << 1] = p[i], a[i << 1 | 1] = q[i];
}
void idft(std::vector<Complex> &a) {
dft(a);
for (auto &v: a) v.a /= a.size(), v.b /= a.size();
std::reverse(a.begin() + 1, a.end());
}
DIT
const double PI = acos(-1.0);
void dft(std::vector<Complex> &a) {
int n = a.size(), m = n >> 1;
if (n == 1) return;
std::vector<Complex> p(m), q(m);
for (int i = 0; i < m; i++) {
p[i] = a[i << 1];
q[i] = a[i << 1 | 1];
}
dft(p), dft(q);
for (int i = 0; i < m; i++) {
Complex &u = p[i], v = Complex(cos(2 * PI * i / n), sin(2 * PI * i / n)) * q[i];
a[i] = u + v, a[i + m] = u - v;
}
}
void idft(std::vector<Complex> &a) {
dft(a);
for (auto &v: a) v.a /= a.size(), v.b /= a.size();
std::reverse(a.begin() + 1, a.end());
}
下面探讨以 非递归方式 实现 DIF
与 DIT
。
DIT
由于 DIT
更易于理解(其实只是资料多),先讲这个。
考虑递归的过程:
发现 (0 ightarrow 3) 只是在重新安排数据位置,并没有修改数据,如果我们能把映射关系找到,那就可以一步到位,直接从 (3) 开始。
设一个数在第 (i) 个阶段 ((0 leq i leq log_2n)) 的位置为 (p_i),相对位置为 (p_i')(相对位置
指它在括号里的位置,例如上面第 (1) 阶段 (a_1) 的相对位置为 (0))。
容易发现
如果将 (p_0) 写成二进制 (overline{b_4b_3b_2b_1b_0})(这里以 (n = 32) 为例),那么单次变化的过程相当于把二进制的后几位向右 rotate
一位,总的变化过程可以描述为:
可以发现,整个过程实际上是在做 reverse
操作!
至此我们找到了映射关系,成功把前面的步骤都砍掉了,只剩回溯,可以改成循环。
void dft(std::vector<Complex> &a) {
int n = a.size();
for (int i = 0, j = 0; i < n; i++) {
if (i > j) std::swap(a[i], a[j]);
for (int k = n >> 1; (j ^= k) < k; k >>= 1);
}
for (int k = 1; k < n; k <<= 1) {
for (int i = 0; i < n; i += k << 1) {
for (int j = 0; j < k; j++) {
auto t = a[i + j + k] * Complex(cos(PI * j / k), sin(PI * j / k));
a[i + j + k] = a[i + j] - t;
a[i + j] = a[i + j] + t;
}
}
}
}
DIF
这里先将递归版DIF的过程简单复述:
第一步:将序列 (a) 对半分
第二步:递归计算 (DFT(p), DFT(q))
第三步:重新安排数据位置
发现回溯的过程(即第三步)实际上也只是在重新安排数据存储的位置,而且是上面 (DIT) 第一步的逆过程,所以就是位翻转的逆过程,所以还是位翻转。
所以最后安排数据位置可以一步搞定,只剩递归压栈的过程,可以改成循环。
void dft(vector<Complex> &a) {
int n = a.size();
for (int k = n >> 1; k; k >>= 1) {
for (int i = 0; i < n; i += k << 1) {
for (int j = 0; j < k; j++) {
auto t = a[i + j + k];
a[i + j + k] = (a[i + j] - t) * Complex(cos(PI * j / k), sin(PI * j / k));
a[i + j] = a[i + j] + t;
}
}
}
for (int i = 0, j = 0; i < n; i++) {
if (i > j) std::swap(a[i], a[j]);
for (int k = n >> 1; (j ^= k) < k; k >>= 1);
}
}
Combination
发现 (DIF) 的最后一步和 (DIT) 的第一步都是位翻转,所以先 (DIF) 再 (DIT),就可以省略位翻转。
完整代码
#include <bits/stdc++.h>
template <class T>
inline void readInt(T &w) {
char c, p = 0;
while (!isdigit(c = getchar())) p = c == '-';
for (w = c & 15; isdigit(c = getchar());) w = w * 10 + (c & 15);
if (p) w = -w;
}
struct Complex {
double a, b; // a + bi
Complex(double a = 0, double b = 0): a(a), b(b) {}
};
inline Complex operator+(const Complex &p, const Complex &q) {
return Complex(p.a + q.a, p.b + q.b);
}
inline Complex operator-(const Complex &p, const Complex &q) {
return Complex(p.a - q.a, p.b - q.b);
}
inline Complex operator*(const Complex &p, const Complex &q) {
return Complex(p.a * q.a - p.b * q.b, p.a * q.b + p.b * q.a);
}
const double PI = acos(-1.0);
void dft(std::vector<Complex> &a) {
int n = a.size();
for (int k = n >> 1; k; k >>= 1) {
for (int i = 0; i < n; i += k << 1) {
for (int j = 0; j < k; j++) {
auto t = a[i + j + k];
a[i + j + k] = (a[i + j] - t) * Complex(cos(PI * j / k), sin(PI * j / k));
a[i + j] = a[i + j] + t;
}
}
}
}
void idft(std::vector<Complex> &a) {
int n = a.size();
for (int k = 1; k < n; k <<= 1) {
for (int i = 0; i < n; i += k << 1) {
for (int j = 0; j < k; j++) {
auto t = a[i + j + k] * Complex(cos(PI * j / k), sin(PI * j / k));
a[i + j + k] = a[i + j] - t;
a[i + j] = a[i + j] + t;
}
}
}
for (auto &v: a) v.a /= a.size(), v.b /= a.size();
std::reverse(a.begin() + 1, a.end());
}
int main() {
int n, m, k;
readInt(n), readInt(m);
k = 1 << std::__lg(n + m) + 1;
std::vector<Complex> a(k), b(k), c(k);
for (int i = 0; i <= n; i++) readInt(a[i].a);
for (int i = 0; i <= m; i++) readInt(b[i].a);
dft(a), dft(b);
for (int i = 0; i < k; i++) c[i] = a[i] * b[i];
idft(c);
for (int i = 0; i <= n + m; i++) printf("%d ", (int)(c[i].a + 0.5));
return 0;
}
Number Theory Transform
如果一个质数存在 (2^n) 次单位根(其中 (n) 最大时的单位根称为原根),那么在这个质数的剩余系下上面的结论依旧成立,可以使用FFT,多称这种FFT为快速数论变换(Number Theory Transform, NTT)
。
常见的质数是 (P = 998244353),它的原根 (g = 3)。
代码(预处理单位根,较好地平衡了代码复杂度和常数且有一定的封装度):
#include <bits/stdc++.h>
template <class T>
inline void readInt(T &w) {
char c, p = 0;
while (!isdigit(c = getchar())) p = c == '-';
for (w = c & 15; isdigit(c = getchar());) w = w * 10 + (c & 15);
if (p) w = -w;
}
template <class T, class... U>
inline void readInt(T &w, U &... a) { readInt(w), readInt(a...); }
constexpr int P(998244353), G(3);
inline void inc(int &x, int y) { (x += y) >= P ? x -= P : 0; }
inline int sum(int x, int y) { return x + y >= P ? x + y - P : x + y; }
inline int sub(int x, int y) { return x - y < 0 ? x - y + P : x - y; }
inline int fpow(int x, int k = P - 2) {
int r = 1;
for (; k; k >>= 1, x = 1LL * x * x % P)
if (k & 1) r = 1LL * r * x % P;
return r;
}
namespace Polynomial {
using Polynom = std::vector<int>;
int n;
std::vector<int> w;
void getOmega(int k) {
w.resize(k);
w[0] = 1;
int base = fpow(G, (P - 1) / (k << 1));
for (int i = 1; i < k; i++) w[i] = 1LL * w[i - 1] * base % P;
}
void dft(Polynom &a) {
for (int k = n >> 1; k; k >>= 1) {
getOmega(k);
for (int i = 0; i < n; i += k << 1) {
for (int j = 0; j < k; j++) {
int y = a[i + j + k];
a[i + j + k] = (1LL * a[i + j] - y + P) * w[j] % P;
inc(a[i + j], y);
}
}
}
}
void idft(Polynom &a) {
for (int k = 1; k < n; k <<= 1) {
getOmega(k);
for (int i = 0; i < n; i += k << 1) {
for (int j = 0; j < k; j++) {
int x = a[i + j], y = 1LL * a[i + j + k] * w[j] % P;
a[i + j] = sum(x, y);
a[i + j + k] = sub(x, y);
}
}
}
int inv = fpow(n);
for (int i = 0; i < n; i++) a[i] = 1LL * a[i] * inv % P;
std::reverse(a.begin() + 1, a.end());
}
} // namespace Polynom
using Polynomial::dft;
using Polynomial::idft;
void poly_multiply(unsigned *A, int n, unsigned *B, int m, unsigned *C) {
int k = Polynomial::n = 1 << std::__lg(n + m) + 1;
std::vector<int> a(k), b(k);
for (int i = 0; i <= n; i++) a[i] = A[i];
for (int i = 0; i <= m; i++) b[i] = B[i];
dft(a), dft(b);
for (int i = 0; i < k; i++) a[i] = 1LL * a[i] * b[i] % P;
idft(a);
for (int i = 0; i <= n + m; i++) C[i] = a[i];
}
int main() {
int n, m, k;
readInt(n, m);
Polynomial::n = k = 1 << std::__lg(n + m) + 1;
std::vector<int> a(k), b(k);
for (int i = 0; i <= n; i++) readInt(a[i]);
for (int i = 0; i <= m; i++) readInt(b[i]);
dft(a), dft(b);
for (int i = 0; i < k; i++) a[i] = 1LL * a[i] * b[i] % P;
idft(a);
for (int i = 0; i <= n + m; i++) printf("%d ", a[i]);
return 0;
}
Bluestein’s Algorithm
上面提到的 FFT 算法虽然限制了 (n) 为 2 的次幂,但在大多数情况下已经足够解决问题。
对于更一般的 (n) 需要用到 Bluestein’s Algorithm
,可以参考2016年国家集训队论文《再探快速傅里叶变换——毛啸》。
后续可能会填这个坑。