zoukankan      html  css  js  c++  java
  • FFT&NTT小结

    Preface

    最近几天学了一下FFT和NTT,感觉这东西理解了之后也没有那么难 其实我IDFT还不会证明

    我本来是准备写一篇特别详细的总结,结果发现了一篇和我想写的内容相近的博客 传送门

    以及一篇只需初中数学知识的零基础学习笔记 传送门

    所以我就只讲一下大致的算法过程,具体可以去看一下链接的两篇博客

    先了解复数相关的性质和运算以及单位圆后,食用效果更佳

    Problem

    题目蓝链

    给你两个多项式,求它们的卷积

    Process

    DFT

    将给定的两个多项式从系数式转换为点值式

    点值式的意思是在平面上找到(n + 1)个横坐标不同的点来确定一个(n)次函数,即(F(x))对于若干个不同(x)的取值

    直接暴力显然是不行的,所以我们要想办法优化

    我们可以考虑把当前的问题转化为子问题,然后再从子问题快速求解当前的问题

    假设我们现在要求(F(x) = sumlimits_{i = 0}^{n - 1} a_i cdot x^i)的点值式,保证(n = 2^k (k in N))

    我们可以把式子变一下形

    [F(x) = (a_0 + a_2 cdot x^2 + cdots + a_{n - 2} cdot x^{n - 2}) + (a_1 cdot x + a_3 cdot x^3 + cdots + a_{n - 1} cdot x^{n - 1}) ]

    我们令

    [G(x) = a_0 + a_2 cdot x + cdots + a_{n - 2} cdot x^{frac{n}{2} - 1} \ G'(x) = a_1 + a_3 cdot x + cdots + a_{n - 1} cdot x^{frac{n}{2} - 1} ]

    [F(x) = G(x^2) + x cdot G'(x^2) ]

    但这样好像还是没有转换为一模一样的子问题,所以我们可以考虑带一些具有(qi)某些(qi)特殊(guai)性质(guai)的数值进去


    在经过前人无数次尝试之后,发现可以代入(omega)到式子里去,这是因为(omega)有一些比较神奇的性质

    (omega)的本质是一个复数,且满足(omega^n = 1),所以显然(omega)只能在单位圆上

    于是我们记(omega_n^k (k in [0, n)))为单位根,如果我们把这些单位根看成矢量,那么它们便会(n)等分这个单位圆

    它有这样一些性质:(字母均为整数)

    • (omega_n^k = omega_n^{k + a cdot n})
    • (omega_n^{k_1} cdot omega_n^{k_2}= omega_n^{k_1 + k_2})
    • (omega_{d cdot n}^{d cdot k} = omega_n^k)
    • (omega_n^{k + frac{n}{2}} = - omega_n^{k})
    • (sum_{i = 0}^{n - 1} (omega_{n}^{i})^k = 0, k e 0)

    其中第(2)条性质可以根据所有单位根都形如(omega = (cos alpha, sin alpha)),然后用三角函数的和角公式推一下就可以了

    对于第(5)条性质,可以直接等比数列求和证明

    [sum_{i = 0}^{n - 1} (omega_{n}^{i})^k = sum_{i = 0}^{n - 1} omega_{n}^{ik} = sum_{i = 0}^{n - 1} (omega_{n}^{k})^i = frac{1 - (omega_{n}^{k})^n}{1 - omega_{n}^{k}} = frac{1 - omega_{n}^{nk}}{1 - omega_{n}^{k}} = frac{1 - omega_{n}^{0}}{1 - omega_{n}^{k}} = 0 ]

    其余的几条性质比较显然,就不证明了


    所以,(保证(k < frac{n}{2}))

    [F(omega_n^k) = G((omega_n^k)^2) + omega_n^k cdot G'((omega_n^k)^2) \ = G(omega_n^{2k}) + omega_n^k cdot G'(omega_n^{2k}) \ = G(omega_{n / 2}^k) + omega_n^k cdot G'(omega_{n / 2}^k) ]

    因为(omega_n^{k + n / 2} = - omega_n^k),所以

    [F(omega_n^{k + n / 2}) = G(omega_{n / 2}^k) - omega_n^k cdot G'(omega_{n / 2}^k) ]

    综上,(F)函数的点值均可从函数(G)(G')转移过来,复杂度(mathcal{O}(n))

    分治后总时间复杂度(mathcal{O}(n log n)) (此(n)非彼(n))

    点值式相乘

    我们直接把两个点值式乘起来即为它们卷集的点值式

    IDFT

    将答案的点值式转换为系数式

    结论:只需要令(omega_n^k = omega_n^{-k})后进行一次DFT即为IDFT

    我们设多项式为(F(x)),点值数列为(G(x)),有(G = DFT(F)),即:

    [egin{gather} G(x) = sum_{i = 0}^{n - 1} (omega_n^x)^i F(i) otag \ F(x) = frac{sum_{i = 0}^{n - 1}(omega_n^{-x})^i G(i)}{n} ag{1} end{gather} ]

    我们可以结论来往回推验证结论

    [sum_{i = 0}^{n - 1}(omega_n^{-x})^i G(i) = sum_{i = 0}^{n - 1} sum_{j = 0}^{n - 1} omega_n^{i(j-x)} F(j) = sum_{j = 0}^{n - 1} F(j) (sum_{i = 0}^{n - 1} omega_n^{i(j-x)}) ag{2} ]

    (j = x)时,后半部分(括号内的)等于(n imes omega_n^0 = n)

    (j e x)时,设(d = j - x),并且已知之前讲的第(5)条性质

    [sum_{i = 0}^{n - 1} omega_n^{i(j-x)} = sum_{i = 0}^{n - 1} (omega_n^i)^d = 0 ]

    所以对于(forall j e x),均满足(sum_{i = 0}^{n - 1} omega_n^{i(j-x)} = 0)

    综上,((2))式的值为(n imes F(x)),(注意此时的(x = j))。然后我们再把((2))回带到((1))中,得到(F(x) = F(x)),证明结论正确

    至此,我们便得到了答案多项式的系数

    二进制翻转

    由于我们在递归的时候,每次都需要把系数按照奇偶分类传给左右区间,直接搞会比较麻烦。所以我们可以事先观察一下每一个系数会到哪里去,并在(DFT)之前移动好

    我们发现每次划分都是看最低位,把(0)的分到左区间,(1)分到右区间,然后去掉最低位重复这个过程。所以我们就可以对于系数按照其下标二进制反转的值的大小排序,就可以知道每一个系数的最终位置了

    Code

    FFT
    #include <bits/stdc++.h>
    
    using namespace std;
    
    #define fst first
    #define snd second
    #define mp make_pair
    #define squ(x) ((LL)(x) * (x))
    #define debug(...) fprintf(stderr, __VA_ARGS__)
    
    typedef long long LL;
    typedef pair<int, int> pii;
    
    template<typename T> inline bool chkmax(T &a, const T &b) { return a < b ? a = b, 1 : 0; }
    template<typename T> inline bool chkmin(T &a, const T &b) { return a > b ? a = b, 1 : 0; }
    
    inline int read() {
    	int sum = 0, fg = 1; char c = getchar();
    	for (; !isdigit(c); c = getchar()) if (c == '-') fg = -1;
    	for (; isdigit(c); c = getchar()) sum = (sum << 3) + (sum << 1) + (c ^ 0x30);
    	return fg * sum;
    }
    
    namespace FFT {
    
    	const int MAX_LEN = (1 << 21) + 5;
    	const double PI = acos(-1.0);
    
    	struct com {
    		double a, b;
    		com (double _a = 0.0, double _b = 0.0): a(_a), b(_b) { }
    		com operator + (const com &t) const { return com(a + t.a, b + t.b); }
    		com operator - (const com &t) const { return com(a - t.a, b - t.b); }
    		com operator * (const com &t) const { return com(a * t.a - b * t.b, a * t.b + b * t.a); }
    	};
    
    	int len, cnt, rev[MAX_LEN];
    	com g[MAX_LEN];
    
    	inline void init(int N) {
    		for (cnt = -1, len = 1; len <= N; len <<= 1) ++cnt;
    		for (int i = 0; i < len; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << cnt);
    		g[0] = com(1.0, 0.0);
    		for (int i = 1; i <= len; i++) g[i] = com(cos(PI * 2 / len * i), sin(PI * 2 / len * i));
    	}
    
    	inline void DFT(com *x, int op) {
    		for (int i = 0; i < len; i++) if (i < rev[i]) swap(x[i], x[rev[i]]);
    		for (int k = 2; k <= len; k <<= 1)
    			for (int j = 0; j < len; j += k)
    				for (int i = 0; i < k / 2; i++) {
    					com X = x[j + i], Y = x[j + i + k / 2] * g[len / k * (~op ? i : k - i)];
    					x[j + i] = X + Y, x[j + i + k / 2] = X - Y;
    				}
    		if (op == -1) for (int i = 0; i < len; i++) x[i].a /= len, x[i].b /= len;
    	}
    
    	inline void mul(int *a, int n, int *b, int m, int *c) {
    		if (n + m == 0) { c[0] = a[0] * b[0]; return; }
    		init(n + m);
    		static com F[MAX_LEN], G[MAX_LEN], S[MAX_LEN];
    		for (int i = 0; i < len; i++) F[i] = com(i <= n ? a[i] : 0.0, 0.0);
    		for (int i = 0; i < len; i++) G[i] = com(i <= m ? b[i] : 0.0, 0.0);
    		DFT(F, 1), DFT(G, 1);
    		for (int i = 0; i < len; i++) S[i] = F[i] * G[i];
    		DFT(S, -1);
    		for (int i = 0; i <= n + m; i++) c[i] = round(S[i].a);
    	}
    
    }
    
    const int maxn = 2e6 + 10;
    
    int main() {
    #ifdef xunzhen
    	freopen("FFT.in", "r", stdin);
    	freopen("FFT.out", "w", stdout);
    #endif
    
    	int n = read(), m = read();
    
    	static int a[maxn], b[maxn], c[maxn];
    	for (int i = 0; i <= n; i++) a[i] = read();
    	for (int i = 0; i <= m; i++) b[i] = read();
    
    	FFT::mul(a, n, b, m, c);
    
    	for (int i = 0; i <= n + m; i++) printf("%d%c", c[i], i < n + m ? ' ' : '
    ');
    
    	return 0;
    }
    
    NTT
    #include <bits/stdc++.h>
    
    using namespace std;
    
    #define fst first
    #define snd second
    #define mp make_pair
    #define squ(x) ((LL)(x) * (x))
    #define debug(...) fprintf(stderr, __VA_ARGS__)
    
    typedef long long LL;
    typedef pair<int, int> pii;
    
    template<typename T> inline bool chkmax(T &a, const T &b) { return a < b ? a = b, 1 : 0; }
    template<typename T> inline bool chkmin(T &a, const T &b) { return a > b ? a = b, 1 : 0; }
    
    inline int read() {
    	int sum = 0, fg = 1; char c = getchar();
    	for (; !isdigit(c); c = getchar()) if (c == '-') fg = -1;
    	for (; isdigit(c); c = getchar()) sum = (sum << 3) + (sum << 1) + (c ^ 0x30);
    	return fg * sum;
    }
    
    namespace NTT {
    
    	const int MAX_LEN = 1 << 21;
    	const int mod = 998244353, g0 = 3;
    
    	int len, cnt, rev[MAX_LEN], g[MAX_LEN];
    
    	inline int add(int x, int y) { return (x += y) < mod ? (x >= 0 ? x : x + mod) : x - mod; }
    	inline int mul(int x, int y) { return (LL)x * y % mod; }
    	inline int Pow(int x, int y) {
    		if (y < 0) y = -1LL * y * (mod - 2) % (mod - 1);
    		int res = 1;
    		for (; y; y >>= 1, x = mul(x, x)) if (y & 1) res = mul(res, x);
    		return res;
    	}
    
    	void init(int N) {
    		for (cnt = -1, len = 1; len <= N; len <<= 1) ++cnt;
    		for (int i = 0; i < len; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << cnt);
    		g[0] = 1;
    		for (int G = Pow(g0, (mod - 1) / len), i = 1; i < len; i++) g[i] = mul(g[i - 1], G);
    	}
    
    	void DFT(int *x, int op) {
    		for (int i = 0; i < len; i++) if (i < rev[i]) swap(x[i], x[rev[i]]);
    		for (int k = 2; k <= len; k <<= 1)
    			for (int j = 0; j < len; j += k)
    				for (int i = 0; i < k / 2; i++) {
    					int X = x[j + i], Y = mul(x[j + i + k / 2], g[~op ? len / k * i : len / k * (i ? k - i : i)]);
    					x[j + i] = add(X, Y), x[j + i + k / 2] = add(X, -Y);
    				}
    		if (op == -1) for (int inv = Pow(len, -1), i = 0; i < len; i++) x[i] = mul(x[i], inv);
    	}
    
    	void mul(int *a, int n, int *b, int m, int *c) {
    		init(n + m);
    		static int F[MAX_LEN], G[MAX_LEN], S[MAX_LEN];
    		for (int i = 0; i < len; i++) F[i] = i <= n ? a[i] : 0;
    		for (int i = 0; i < len; i++) G[i] = i <= m ? b[i] : 0;
    		DFT(F, 1), DFT(G, 1);
    		for (int i = 0; i < len; i++) S[i] = mul(F[i], G[i]);
    		DFT(S, -1);
    		for (int i = 0; i <= n + m; i++) c[i] = S[i];
    	}
    
    }
    
    const int maxn = 2e6 + 10;
    
    int main() {
    #ifdef xunzhen
    	freopen("NTT.in", "r", stdin);
    	freopen("NTT.out", "w", stdout);
    #endif
    
    	int n = read(), m = read();
    
    	static int a[maxn], b[maxn], c[maxn];
    	for (int i = 0; i <= n; i++) a[i] = read();
    	for (int i = 0; i <= m; i++) b[i] = read();
    
    	NTT::mul(a, n, b, m, c);
    
    	for (int i = 0; i <= n + m; i++) printf("%d%c", c[i], i < n + m ? ' ' : '
    ');
    
    	return 0;
    }
    

    Summary

    其实NTT就是把FFT在模意义下进行,我们可以找一个原根(g)来代替(omega)

    NTT可以用来避免浮点数的缓慢运算 但好像取模运算更满(雾

    IDFT就先留个坑,等以后再来填算了((filled))

    19.2.4upd

    其实g[~op ? len / k * i : len / k * (i ? k - i : i)还可以写成g[len / k * (~op ? i : (i ? k - i : i))]

    19.2.14upd

    恩,之前的FFT板子被某位大佬说会掉精度,所以我就修改了一下板子

    19.11.26upd

    更新了一些地方,以及补充了一些东西的证明

  • 相关阅读:
    第一阶段大作业 文件上传格式
    第一阶段大作业 数据字典的修改
    设计模式 C++实现职责链模式 (顺便复习C++)
    Numpy学习
    2019版:第二章:(1)Redis 概述
    第一章:(6)Dubbo 与 SpringBoot 整合
    第一章:(5)Dubbo 监控中心
    2019版:第一章:(2)NOSQL 数据库
    2019版:第二章:(3)Redis 其他相关知识
    2019版:第一章:(1)技术发展
  • 原文地址:https://www.cnblogs.com/xunzhen/p/10350797.html
Copyright © 2011-2022 走看看