zoukankan      html  css  js  c++  java
  • FFT/NTT

    完全抄袭自 OI-wiki


    基本

    通俗地说, 系数表达 → 点值表达, 称为 DFT, 点值表达 → 系数表达, 称为 IDFT。

    FFT 通过取某些特殊的 x 的点值来加速 DFT 和 IDFT。

    考虑点值表示下的多项式乘法:

    [f(x) = (x_0,f(x_0)),(x_1,f(x_1)),cdots,(x_n,f(x_n)) \ g(x) = (x_0,g(x_0)),(x_1,g(x_1)),cdots,(x_n,g(x_n)) \ (fcdot g)(x) = f(x)cdot g(x) \ (fcdot g)(x) = (x_0,f(x_0)g(x_0)), (x_1,f(x_1)g(x_1)), cdots,(x_n,f(x_n)g(x_n)) ]

    明显是 O(n) 的。

    如此,通过 FFT, 可以实现快速的多项式乘法。


    分治结构

    [egin{align} f(x) &= a_0 + a_1x+a_2x^2+a_3x^3+a_4x^4+a_5x^5+a_6x^6+a_7x^7 \ &= (a_0+a_2x^2+a_4x^4+a_6x^6) + xcdot(a_1+a_3x^2+a_5x^4+a_7x^6) end{align} ]

    设个 (g(x) = a_0 + a_2x + a_4x^2 + a_6x^3), 再设个 (h(x) = a_1 + a_3x + a_5x^2 + a_7x^3), 就有:

    [f(x) = g(x^2) + xcdot h(x^2) ]

    接下来是精彩的地方, 前面说的特殊点值要发挥作用了。

    带入 n 次的某个单位根, 首先有:

    [egin{align} f(omega_n^k) &= g(omega_n^{2k}) + omega_n^k cdot h(omega_n^{2k}) \ &= g(omega_{n/2}^k) + omega_n^kcdot h(omega_{n/2}^k) end{align} ]

    然后有:

    [egin{align} f(omega_n^{k + n/2}) &= g(omega_{n}^{2k+n}) + omega_{n}^{k+n/2}cdot h(omega_n^{2k+n}) \ &= g(omega_{n/2}^k) - omega_n^kcdot h(omega_{n/2}^k) end{align} ]

    这个分治的结构就清晰可见了, 虽然本质什么的还不是很清楚, 但可以窥见一丝构造的痕迹。


    IDFT

    带入单位根的共轭复数 DFT 一下, 再把得到的东西除以 n 就行了。


    代码

    抄的学长的实现, 是有点优化的写法。目前没考虑到封装。

    不用算 rev 的 FFT 真的那么 dio 吗?

    #include<bits/stdc++.h>
    
    using namespace std;
    
    int rd() {
    	int x = 0;
    	char c = getchar();
    	while(c<'0' || c>'9') c=getchar();
    	while(c>='0' && c<='9') x=x*10+c-'0', c=getchar();
    	return x;
    }
    
    const int N = (1<<21)+ 233;
    const double pi = acos(-1);
    
    struct com {
    	double x, y;
    	com(double a, double b) : x(a), y(b) {
    	}
    	com() {
    		x=y=0;
    	}
    	const com operator+(const com rhs) const{
    		return com(x+rhs.x, y+rhs.y);
    	}
    	const com operator-(const com rhs) const{
    		return com(x-rhs.x, y-rhs.y);
    	}
    	const com operator*(const com rhs) const{
    		return com(x*rhs.x - y*rhs.y, x*rhs.y + y*rhs.x);
    	}
    };
    
    int n, m, rv[N];
    com a[N], b[N];
    
    void fft(com *a, int n, int type) {
    	for(int i=0; i<n; ++i) if(i<=rv[i]) swap(a[i], a[rv[i]]);
    	for(int m=2; m<=n; m<<=1) {
    		com w(cos(2 * pi / m), type * sin(2 * pi / m));
    		for(int i=0; i<n; i += m) {
    			com tmp = com(1, 0);
    			for(int j=0; j<(m>>1); ++j) {
    				com p = a[i+j], q = tmp * a[i+j+(m>>1)];
    				a[i+j] = p + q,
    				a[i+j+(m>>1)] = p - q;
    				tmp = tmp * w;
    			}
    		}
    	}
    }
    
    int main() {
    	n = rd()+1, m = rd() + 1;
    	for(int i=0; i<n; ++i) a[i].x = rd();
    	for(int i=0; i<m; ++i) b[i].x = rd();
    	for(m=n+m-1, n=1; n<m; n=n<<1);
    	for(int i=0; i<n; ++i) rv[i] = (rv[i>>1]>>1)|(i&1?(n>>1):0);
    	fft(a, n, 1);
    	fft(b, n, 1);
    	for(int i=0; i<n; ++i) a[i] = a[i] * b[i];
    	fft(a, n, -1);
    	for(int i=0; i<m; ++i) printf("%d ", (int)(a[i].x/n+0.5));
    	return 0;
    }
    

    现在正要加速学习多项式, 所以 NTT 就先背个版吧。

    #include<bits/stdc++.h>
    typedef long long LL;
    using namespace std;
    
    const int N = 3e6 + 23, mo = 998244353, g = 3;
    
    int read() { 
        char c = getchar(); int x = 0;
        while(c < '0' || c > '9') c = getchar();
        while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
        return x;
    }
    
    int ksm(int a, int b) {
    	int res = 1;
    	for(; b; b=b>>1, a=((LL)a*a) % mo)
    		if(b & 1) res = (LL)res * a % mo;
    	return res;
    }
    
    const int ig = ksm(g, mo-2);
    
    int n, m, a[N], b[N], rv[N];
    
    void ntt(int *a, int n, int type) {
    	for(int i=0; i<n; ++i) if(i<rv[i]) swap(a[i], a[rv[i]]);
    	for(int m=2; m<=n; m<<=1) {
    		int w = ksm(type == 1 ? g : ig, (mo-1)/m);
    		for(int i = 0; i < n; i += m) {
    			int tmp = 1;
    			for(int j = 0; j < (m>>1); ++j) {
    				int p = a[i+j], q = (LL)tmp * a[i+j+(m>>1)] % mo;
    				a[i + j] = (p + q) % mo, a[i + j + (m>>1)] = (p - q + mo) % mo;
    				tmp = (LL)tmp * w % mo;
    			}
    		}
    	}
    }
    
    int main() {
    	n = read()+1, m = read()+1;
    	for(int i=0;i<n;++i) a[i]=read();
    	for(int i=0;i<m;++i) b[i]=read();
    	for(m=n+m-1,n=1;n<m;n<<=1);
    	for(int i=0;i<n;++i) rv[i] = (rv[i>>1]>>1)|((i&1)?(n>>1):0);
    	ntt(a, n, 1), ntt(b, n, 1);
    	for(int i = 0; i < n; ++i) a[i] = (LL)a[i] * b[i] % mo;
    	ntt(a, n, -1);
    	int inv = ksm(n, mo-2);
    	for(int i=0; i<m; ++i) cout << (LL)a[i] * inv % mo << ' ';
    	return 0;
    }
    
  • 相关阅读:
    java学习阶段一 方法和文档注释
    java学习阶段一 二维数组
    java学习阶段一 一维数组
    java学习阶段一 循环结构
    java学习阶段一 选择结构
    java学习阶段一 运算符
    oracle学习笔记:修改表空间文件位置
    oracle学习笔记:重建临时表空间
    oracle等待事件1:Failed Logon delay等待事件
    oracle数据库删除归档日志
  • 原文地址:https://www.cnblogs.com/tztqwq/p/14323397.html
Copyright © 2011-2022 走看看