FFT&NTT总结
一些概念
(DFT:)离散傅里叶变换( ightarrow O(n^2))计算多项式卷积
(FFT:)快速傅里叶变换( ightarrow O(nlogn))计算多项式卷积
(NTT:)快速数论变换( ightarrow)对(FFT)的常数优化
(MTT:)(NTT)的一些拓展
FFT
多项式&卷积
设(A(x))表示一个(n-1)次多项式
则(A(x)=sum_{i=0}^{n-1}a_ix^i)
而卷积就是两个多项式相乘
如果我们像平常那样暴力乘起来,复杂度是(O(n^2))的
点值表示法
将(n)个点代进(n-1)次多项式(A(x))
则可以确定(n)对((x,y))
而我们有(n)个点也可以确定一个(n-1)次多项式
为什么?很(bu)显(hui)然(zheng)啊。
我们后面的(FFT)的优化就是基于这个来的
复数
定义
我们把形如(z=a+bi)((a,b)均为实数)的数称为复数,其中(a)称为实部,(b)称为虚部,(i)称为虚数单位。(摘自百度百科)其中(i^2=-1)。
而在复平面中,(x)轴代表实数,(y)轴代表虚数(除原点),从原点((0,0))到((a,b))代表复数(a+bi)
模长:((0,0))到((a,b))的距离,即(sqrt {a^2+b^2})
幅角:以逆时针为正方向,(x)轴到已知向量的转角的有向角
运算法则
加减法:
和向量一样,即
((a,b)+(c,d)=(a+b,c+d))
((a,b)-(c,d)=(a-b,c-d))
乘法:
几何意义:复数相乘,模长相乘,幅角相加
代数定义:
单位根
(下文默认(n)为(2)的整数次幂)
在复平面上,以原点为圆心,(1)为半径的圆叫做单位圆。
以原点为起点,圆的(n)等分点为终点,作(n)个向量,设幅角为正且最小的复数向量为(omega _n),称为(n)次单位根。
(n)个向量为(omega_n^1,omega_n^2,omega_n^3...omega_n^{n-1},omega_n^n)((omega_n^n=omega_n^0=1))
如何计算他们的值呢,
可以用欧拉公式:
单位根的幅角为周角的(frac 1n)
代数中,若(z^n=1),我们把(z)称为(n)次单位根
性质
1、(omega_n^k=cos;(k*frac{2pi}{n})+i*sin;(k*frac{2pi}{n}))
2、(omega_n^k=omega_{2n}^{2k})
3、(omega_n^{k+frac n2}=-omega_n^k)
4、(omega_n^0=omega_n^n=1)
快速傅里叶变换
我们前面提过,一个(n-1)次多项式可以用(n)个点唯一确定,
我们可以把(0)~(n-1)次单位根依次带入
但仍然是(O(n^2))啊,因为单位根有很多优秀的性质
所以我们来推一波公式
有
按照下表奇偶性分类
设
则
将(omega_n^k(k<frac n2))代入(:A(omega_n^k)=A_1(omega_n^{2k})+omega_n^kA_2(omega_n^{2k}))
将(omega_n^{k+frac n2})代入:(A(omega_n^{k+frac n2})=A_1(omega_n^{2k})-omega_n^kA_2(omega_n^{2k}))
发现只有一个符号不一样
于是求第一个式子时,我们可以(O(1))求第二个式子
我们就将这个问题缩小了一半
递归搞下去,就可以(O(nlogn))了
快速傅里叶逆变换
真的不想写了2333
跟上面其实差不多,直接看代码吧。。。
下面代码中
FFT(a, -1);
for (int i = 0; i <= M; i++) printf("%d ", (int)(a[i].x / N + 0.5));
是快速傅里叶逆变换
代码
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <algorithm>
using namespace std;
namespace IO {
const int BUFSIZE = 1 << 20;
char ibuf[BUFSIZE], *is = ibuf, *it = ibuf;
inline char gc() {
if (is == it) it = (is = ibuf) + fread(ibuf, 1, BUFSIZE, stdin);
return *is++;
}
}
inline int gi() {
register int data = 0, w = 1;
register char ch = 0;
while (!isdigit(ch) && ch != '-') ch = IO::gc();
if (ch == '-') w = -1, ch = IO::gc();
while (isdigit(ch)) data = 10 * data + ch - '0', ch = IO::gc();
return w * data;
}
const double PI = acos(-1.0);
const int MAX_N = 3e6 + 5;
struct Complex { double x, y; } a[MAX_N], b[MAX_N];
Complex operator + (const Complex &a, const Complex &b) { return (Complex){a.x + b.x, a.y + b.y}; }
Complex operator - (const Complex &a, const Complex &b) { return (Complex){a.x - b.x, a.y - b.y}; }
Complex operator * (const Complex &a, const Complex &b) { return (Complex){a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x}; }
int N, M, P, r[MAX_N];
void FFT(Complex *p, int op) {
for (int i = 0; i < N; i++) if (i < r[i]) swap(p[i], p[r[i]]);
for (int i = 1; i < N; i <<= 1) {
Complex rot = (Complex){cos(PI / i), op * sin(PI / i)};
for (int j = 0; j < N; j += (i << 1)) {
Complex w = (Complex){1, 0};
for (int k = 0; k < i; ++k, w = w * rot) {
Complex x = p[j + k], y = w * p[j + k + i];
p[j + k] = x + y, p[j + k + i] = x - y;
}
}
}
}
int main () {
N = gi(), M = gi();
for (int i = 0; i <= N; i++) a[i].x = gi();
for (int i = 0; i <= M; i++) b[i].x = gi();
for (M += N, N = 1; N <= M; N <<= 1, ++P) ;
for (int i = 0; i < N; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (P - 1));
FFT(a, 1), FFT(b, 1);
for (int i = 0; i < N; i++) a[i] = a[i] * b[i];
FFT(a, -1);
for (int i = 0; i <= M; i++) printf("%d ", (int)(a[i].x / N + 0.5));
return 0;
}
NTT
其实和(FFT)差不多啦,
就是把单位根换为原根就行了
代码
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <algorithm>
using namespace std;
inline int gi() {
register int data = 0, w = 1;
register char ch = 0;
while (!isdigit(ch) && ch != '-') ch = getchar();
if (ch == '-') w = -1, ch = getchar();
while (isdigit(ch)) data = 10 * data + ch - '0', ch = getchar();
return w * data;
}
const int MAX_N = 3e6 + 5, Mod = 998244353, G = 3, iG = 332748118;
int fpow(int x, int y) {
int res = 1;
while (y) {
if (y & 1) res = 1ll * res * x % Mod;
x = 1ll * x * x % Mod;
y >>= 1;
}
return res;
}
int Limit = 1, r[MAX_N];
void NTT(int *p, int op) {
for (int i = 0; i < Limit; i++) if (i < r[i]) swap(p[i], p[r[i]]);
for (int i = 1; i < Limit; i <<= 1) {
int rot = fpow(op == 1 ? G : iG, (Mod - 1) / (i << 1));
for (int j = 0, pls = (i << 1); j < Limit; j += pls) {
int w = 1;
for (int k = 0; k < i; k++, w = 1ll * w * rot % Mod) {
int x = p[j + k], y = 1ll * w * p[i + k + j] % Mod;
p[j + k] = (x + y) % Mod, p[i + j + k] = (x - y + Mod) % Mod;
}
}
}
}
int N, M, a[MAX_N], b[MAX_N];
int main () {
N = gi(), M = gi();
for (int i = 0; i <= N; i++) a[i] = (gi() + Mod) % Mod;
for (int i = 0; i <= M; i++) b[i] = (gi() + Mod) % Mod;
int L = 0;
while (Limit <= N + M) Limit <<= 1, ++L;
for (int i = 0; i < Limit; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
NTT(a, 1), NTT(b, 1);
for (int i = 0; i < Limit; i++) a[i] = 1ll * a[i] * b[i] % Mod;
NTT(a, -1);
int inv = fpow(Limit, Mod - 2);
for (int i = 0; i <= N + M; i++) printf("%lld ", (1ll * a[i] * inv) % Mod);
return 0;
}