zoukankan      html  css  js  c++  java
  • FFT & NTT

    前言

    又到了愉快的数学时间了!

    由于蒟蒻作者数学特别烂,贴个友链然后直接上代码吧

    本博客由某谷搬运过来,目的只是存板子。

    (update 2021.2.2) 更新了部分代码,已经基本学懂,懒得更新讲解了。u1s1,到了高中后有了一定数学基础,就是比初中傻白甜的时候学得快。

    (update 2021.2.3) 学懂了蝴蝶变换之后又更新了一波板子,不得不说,迭代版本真的快得离谱。

    (update 2021.7.27) 折叠了代码,增加文章可读性。

    FFT

    友链

    Aaplloo

    OneInDark

    练习

    板题(UOJ)

    板题(洛谷)

    力(洛谷)

    代码

    板题代码

    $Mine\text{(递归)}$
    //12252024832524
    #include <cmath>
    #include <cstdio>
    #include <cstring>
    #include <algorithm>
    #define TT template<typename T>
    using namespace std; 
    
    typedef long long LL;              // 64-bit integer type returned by Read()
    const int MAXN = (1 << 21) | 5;    // array capacity: transform length up to 2^21, plus slack
    const double PI = std::acos(-1.0); // pi, for the twiddle factors
    int lena,lenb;                     // degrees of the two input polynomials
    
    // Fast input: reads one signed decimal integer from stdin. Non-digit characters
    // before the number are skipped, and a '-' among them makes the result negative.
    // c is an int (not char) so that EOF can be recognised. With a char, EOF never
    // passed the digit test, so the skip loop spun forever on truncated input.
    // Returns 0 if EOF is reached before any digit.
    LL Read()
    {
    	LL x = 0,f = 1;int c = getchar();
    	while(c != EOF && (c > '9' || c < '0')){if(c == '-')f = -1;c = getchar();}
    	while(c >= '0' && c <= '9'){x = (x*10) + (c^48);c = getchar();}
    	return x * f;
    }
    // Writes the decimal digits of a non-negative integer x, most significant first.
    TT void Put1(T x)
    {
    	if(x > 9) Put1(x/10);
    	putchar(x%10^48); // (x % 10) ^ 48 == '0' + digit
    }
    // Writes integer x in decimal, followed by the character c when c >= 0.
    // c is an int: as a plain char, the -1 "no separator" sentinel becomes 255 where
    // char is unsigned, and a stray byte was printed after every number.
    // The magnitude is computed in unsigned long long so that negating the most
    // negative value of T cannot overflow.
    TT void Put(T x,int c = -1)
    {
    	unsigned long long u = static_cast<unsigned long long>(x);
    	if(x < 0) putchar('-'),u = 0ull - u;
    	Put1(u);
    	if(c >= 0) putchar(c);
    }
    // Small generic helpers: maximum, minimum and absolute value.
    TT T Max(T x,T y){return x > y ? x : y;}
    TT T Min(T x,T y){return x < y ? x : y;}
    TT T Abs(T x){return x < 0 ? -x : x;}
    
    // Minimal complex number: x is the real part, y the imaginary part.
    // The default constructor leaves both fields uninitialized; the global arrays
    // a and b declared with the struct have static storage and so start at zero.
    // main() stores the two input polynomials' coefficients in a and b.
    struct cp
    {
    	double x,y;
    	cp(){}
    	cp(double re,double im) : x(re),y(im) {}
    	cp operator + (const cp &rhs) const
    	{
    		return cp(x+rhs.x,y+rhs.y);
    	}
    	cp operator - (const cp &rhs) const
    	{
    		return cp(x-rhs.x,y-rhs.y);
    	}
    	cp operator * (const cp &rhs) const
    	{
    		return cp(x*rhs.x-y*rhs.y,x*rhs.y+y*rhs.x);
    	}
    }a[MAXN],b[MAXN];
    
    void FFT(int len,cp * a,int f)
    {
    	if(len == 1) return;
    	cp a1[len >> 1],a2[len >> 1];
    	for(int i = 0;i < len;i += 2) a1[i >> 1] = a[i],a2[i >> 1] = a[i+1];
    	FFT(len>>1,a1,f);
    	FFT(len>>1,a2,f);
    	cp w = cp(cos(2*PI/len),f*sin(2*PI/len)),k = cp(1,0);
    	len >>= 1;
    	for(int i = 0;i < len;++ i,k = k * w)
    	{
    		a[i] = a1[i] + k * a2[i];
    		a[i+len] = a1[i] - k * a2[i];
    	}
    }
    
    // Reads two polynomials (degrees, then coefficients from degree 0 upward) and
    // prints the coefficients of their product, space-separated.
    int main()
    {
    //	freopen(".in","r",stdin);
    //	freopen(".out","w",stdout);
    	lena = Read(); lenb = Read();
    	for(int i = 0;i <= lena;++ i) a[i].x = Read();
    	for(int i = 0;i <= lenb;++ i) b[i].x = Read();
    	// Smallest power of two strictly greater than the product's degree, so the
    	// cyclic convolution computed by the FFT does not wrap around.
    	int len = 1;
    	while(len <= lena + lenb) len <<= 1;
    	FFT(len,a,1);
    	FFT(len,b,1);
    	// Pointwise product of the spectra; only indices [0, len) are transform data.
    	for(int i = 0;i < len;++ i) a[i] = a[i] * b[i];
    	FFT(len,a,-1);
    	// The inverse transform is unscaled: divide by len, then round to nearest.
    	// llround also rounds negative coefficients correctly, which (int)(v + 0.5) did not.
    	for(int i = 0;i <= lena+lenb;++ i) Put(llround(a[i].x/len),' ');
    	return 0;
    }
    
    $Mine\text{(迭代)}$
    //12252024832524
    #include <cmath>
    #include <cstdio>
    #include <cstring>
    #include <algorithm>
    #define TT template<typename T>
    using namespace std; 
    
    typedef long long LL;              // 64-bit integer type returned by Read()
    const int MAXN = (1 << 21) | 5;    // array capacity: transform length up to 2^21, plus slack
    const double PI = std::acos(-1.0); // pi, for the twiddle factors
    int lena,lenb;                     // degrees of the two input polynomials
    int len = 1;                       // transform length (power of two), grown in main()
    int l = -1;                        // log2(len) - 1 once main() has sized len
    int rev[MAXN];                     // bit-reversal permutation table used by FFT()
    
    // Fast input: reads one signed decimal integer from stdin. Non-digit characters
    // before the number are skipped, and a '-' among them makes the result negative.
    // c is an int (not char) so that EOF can be recognised. With a char, EOF never
    // passed the digit test, so the skip loop spun forever on truncated input.
    // Returns 0 if EOF is reached before any digit.
    LL Read()
    {
    	LL x = 0,f = 1;int c = getchar();
    	while(c != EOF && (c > '9' || c < '0')){if(c == '-')f = -1;c = getchar();}
    	while(c >= '0' && c <= '9'){x = (x*10) + (c^48);c = getchar();}
    	return x * f;
    }
    // Writes the decimal digits of a non-negative integer x, most significant first.
    TT void Put1(T x)
    {
    	if(x > 9) Put1(x/10);
    	putchar(x%10^48); // (x % 10) ^ 48 == '0' + digit
    }
    // Writes integer x in decimal, followed by the character c when c >= 0.
    // c is an int: as a plain char, the -1 "no separator" sentinel becomes 255 where
    // char is unsigned, and a stray byte was printed after every number.
    // The magnitude is computed in unsigned long long so that negating the most
    // negative value of T cannot overflow.
    TT void Put(T x,int c = -1)
    {
    	unsigned long long u = static_cast<unsigned long long>(x);
    	if(x < 0) putchar('-'),u = 0ull - u;
    	Put1(u);
    	if(c >= 0) putchar(c);
    }
    // Small generic helpers: maximum, minimum and absolute value.
    TT T Max(T x,T y){return x > y ? x : y;}
    TT T Min(T x,T y){return x < y ? x : y;}
    TT T Abs(T x){return x < 0 ? -x : x;}
    
    // Minimal complex number: x is the real part, y the imaginary part.
    // The default constructor leaves both fields uninitialized; the global arrays
    // a and b declared with the struct have static storage and so start at zero.
    // main() stores the two input polynomials' coefficients in a and b.
    struct cp
    {
    	double x,y;
    	cp(){}
    	cp(double re,double im) : x(re),y(im) {}
    	cp operator + (const cp &rhs) const
    	{
    		return cp(x+rhs.x,y+rhs.y);
    	}
    	cp operator - (const cp &rhs) const
    	{
    		return cp(x-rhs.x,y-rhs.y);
    	}
    	cp operator * (const cp &rhs) const
    	{
    		return cp(x*rhs.x-y*rhs.y,x*rhs.y+y*rhs.x);
    	}
    }a[MAXN],b[MAXN];
    
    // Iterative in-place FFT of a[0, len) using the global bit-reversal table rev.
    // opt = 1 and opt = -1 use conjugate twiddle factors. For opt = -1 only the real
    // part is divided by len afterwards; the imaginary part is left unscaled.
    void FFT(cp *a,int opt)
    {
    	// Put the samples into bit-reversed index order.
    	for(int idx = 0;idx < len;++ idx)
    		if(idx < rev[idx]) swap(a[idx],a[rev[idx]]);
    	// Merge blocks of size half into blocks of size 2 * half.
    	for(int half = 1;half < len;half <<= 1)
    	{
    		const cp step(cos(PI/half),opt*sin(PI/half));
    		for(cp *blk = a;blk != a + len;blk += half << 1)
    		{
    			cp tw(1,0);
    			for(int off = 0;off < half;++ off)
    			{
    				const cp lo = blk[off];
    				const cp hi = tw * blk[off + half];
    				blk[off] = lo + hi;
    				blk[off + half] = lo - hi;
    				tw = tw * step;
    			}
    		}
    	}
    	if(opt == -1)
    		for(int idx = 0;idx < len;++ idx) a[idx].x /= len;
    }
    
    // Reads two polynomials (degrees, then coefficients from degree 0 upward) and
    // prints the coefficients of their product, space-separated.
    int main()
    {
    //	freopen(".in","r",stdin);
    //	freopen(".out","w",stdout);
    	lena = Read(); lenb = Read();
    	for(int i = 0;i <= lena;++ i) a[i].x = Read();
    	for(int i = 0;i <= lenb;++ i) b[i].x = Read();
    	// len: smallest power of two > lena + lenb; l ends as log2(len) - 1.
    	while(len <= lena + lenb) len <<= 1,l++;
    	// rev[i] = i with its log2(len) low bits reversed. The bit contributed by
    	// (i & 1) is len >> 1 (== 1 << l when len >= 2). Writing it this way avoids
    	// shifting by l == -1, which is undefined, when len stays 1.
    	for(int i = 0;i < len;++ i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (len >> 1) : 0);
    	FFT(a,1);
    	FFT(b,1);
    	for(int i = 0;i < len;++ i) a[i] = a[i] * b[i];
    	FFT(a,-1);
    	// FFT(a,-1) already divided by len; round to nearest.
    	// llround also handles negative coefficients, which (int)(v + 0.5) did not.
    	for(int i = 0;i <= lena+lenb;++ i) Put(llround(a[i].x),' ');
    	return 0;
    }
    
    $French's\space code\text{(迭代)}$
    #include<iostream>
    #include<algorithm>
    #include<cmath>
    using namespace std;
    // Capacity of every array below. With 16-byte pair<double,double> elements,
    // a and b take about 160 MB each of static storage, and r about 40 MB more.
    // NOTE(review): (1 << 21) | 5 would suffice when n + m < 2^21; confirm the
    // problem limits before shrinking.
    #define maxn 10000005
    // x / y stand for pair::first / pair::second (real / imaginary part) inside
    // the complex operators that follow; they are #undef'd after those operators.
    #define x first
    #define y second
    const double pi=acos(-1.0);
    // Degrees of the two input polynomials (main reads n + 1 and m + 1 coefficients).
    int n,m;
    // Transform length (a power of two), grown by perpare().
    int limit=1;
    // Coefficients / spectra of the two polynomials; first = real, second = imaginary.
    pair<double,double> a[maxn];
    pair<double,double> b[maxn];
    // log2(limit), computed by perpare().
    int l;
    // Bit-reversal permutation table filled by perpare() and used by fft().
    int r[maxn];
    // Complex arithmetic on pair<double,double>: first is the real part, second
    // the imaginary part. The operands are taken by value.
    pair<double,double> operator + (pair<double,double> p,pair<double,double> q)
    {
    	pair<double,double> sum(p.first+q.first,p.second+q.second);
    	return sum;
    }
    pair<double,double> operator - (pair<double,double> p,pair<double,double> q)
    {
    	pair<double,double> diff(p.first-q.first,p.second-q.second);
    	return diff;
    }
    pair<double,double> operator * (pair<double,double> p,pair<double,double> q)
    {
    	return pair<double,double>(p.first*q.first-p.second*q.second,p.first*q.second+p.second*q.first);
    }
    #undef x
    #undef y
    // Iterative in-place FFT of a[0, limit) using the global bit-reversal table r.
    // t = 1 and t = -1 use conjugate twiddle factors. No scaling is applied for the
    // inverse; work() divides by limit.
    void fft(pair<double,double> *a,int t)
    {
    	typedef pair<double,double> C;
    	// Put the samples into bit-reversed index order.
    	for(int idx=0;idx<limit;idx++)
    		if(idx<r[idx]) swap(a[idx],a[r[idx]]);
    	// Merge blocks of size half into blocks of size span = 2 * half.
    	for(int half=1;half<limit;half*=2)
    	{
    		const C unit(cos(pi/half),t*sin(pi/half));
    		const int span=half*2;
    		for(int base=0;base<limit;base+=span)
    		{
    			C cur(1.0,0.0);
    			for(int off=0;off<half;off++)
    			{
    				C lo=a[base+off];
    				C hi=cur*a[base+half+off];
    				a[base+off]=lo+hi;
    				a[base+half+off]=lo-hi;
    				cur=cur*unit;
    			}
    		}
    	}
    }
    // Multiplies the polynomials in a and b: forward transforms, pointwise product,
    // inverse transform. Then prints the n + m + 1 product coefficients separated
    // by spaces.
    void work()
    {
    	fft(a,1);
    	fft(b,1);
    	// Only indices [0, limit) hold transform data (previously the loop ran to <= limit).
    	for(int i=0;i<limit;i++)
    		a[i]=a[i]*b[i];
    	fft(a,-1);
    	// fft() does not scale the inverse transform, so divide by limit here.
    	// llround rounds to nearest for negative values too, unlike (int)(v + 0.5).
    	for(int i=0;i<=n+m;i++)
    		cout<<llround(a[i].first/limit)<<' ';
    }
    // Sizes the transform: limit becomes the smallest power of two greater than
    // n + m (with l = log2(limit)), and r[] the matching bit-reversal permutation.
    // The bit for (i & 1) is limit >> 1, which equals 1 << (l - 1) when limit >= 2.
    // Writing it this way avoids a shift by -1 (undefined behavior) when
    // n + m == 0 and limit stays 1.
    void perpare()
    {
    	while(limit<=n+m)
    	{
    		limit<<=1;
    		l++;
    	}
    	for(int i=0;i<limit;i++)
    		r[i]=(r[i>>1]>>1)|((i&1)?(limit>>1):0);
    }
    // Reads n, m and the coefficients of both polynomials (degree 0 upward) into the
    // real parts of a and b, then sizes the transform and prints the product.
    int main()
    {
    	// Stop synchronizing C++ streams with C stdio: all I/O here uses cin/cout.
    	ios::sync_with_stdio(false);
    	cin>>n>>m;
    	for(int k=0;k<=n;++k) cin>>a[k].first;
    	for(int k=0;k<=m;++k) cin>>b[k].first;
    	perpare();
    	work();
    	return 0;
    }
    

    NTT

    正题

    难得一见的讲解:

    由于 $FFT$ 会有精度问题,而且不能取模,所以 $NTT$ 就诞生了。

    我们只需将 $FFT$ 中的单位根 $\omega_n$ 换成 $g^{\frac{p-1}{n}}$ 就好了,其中 $p$ 为模数,$g$ 为 $p$ 的原根(要求变换长度 $n$ 整除 $p-1$);逆变换时则用 $g$ 的逆元代替 $g$。

    如果题目不需要取模,只需要找一个比结果的最大可能值还大的模数就好了,这样取模就相当于没有取模。

    当然最后除以 $len$ 的时候改为乘 $len$ 的逆元就好了。

    练习

    板题(UOJ)

    板题(洛谷)

    其实你可以用 $NTT$ 过所有 $FFT$ 的题……好像并不是。

    代码

    板题代码

    $Mine\text{(递归)}$
    //12252024832524
    #include <cmath>
    #include <cstdio>
    #include <cstring>
    #include <algorithm>
    #define TT template<typename T>
    using namespace std; 
    
    typedef long long LL;              // 64-bit integer type returned by Read()
    const int MAXN = (1 << 21) | 5;    // array capacity: transform length up to 2^21, plus slack
    const int MOD = 998244353;         // prime modulus; MOD - 1 = 2^23 * 7 * 17
    const int PHI = MOD - 1;           // 998244352 = phi(MOD), since MOD is prime
    const int GINV = 332748118;        // inverse of G: 3 * 332748118 = MOD + 1
    const int G = 3;                   // primitive root of MOD
    int lena,lenb;                     // degrees of the two input polynomials
    int a[MAXN],b[MAXN];               // coefficients / transformed values of the polynomials
    
    // Fast input: reads one signed decimal integer from stdin. Non-digit characters
    // before the number are skipped, and a '-' among them makes the result negative.
    // c is an int (not char) so that EOF can be recognised. With a char, EOF never
    // passed the digit test, so the skip loop spun forever on truncated input.
    // Returns 0 if EOF is reached before any digit.
    LL Read()
    {
    	LL x = 0,f = 1;int c = getchar();
    	while(c != EOF && (c > '9' || c < '0')){if(c == '-')f = -1;c = getchar();}
    	while(c >= '0' && c <= '9'){x = (x*10) + (c^48);c = getchar();}
    	return x * f;
    }
    // Writes the decimal digits of a non-negative integer x, most significant first.
    TT void Put1(T x)
    {
    	if(x > 9) Put1(x/10);
    	putchar(x%10^48); // (x % 10) ^ 48 == '0' + digit
    }
    // Writes integer x in decimal, followed by the character c when c >= 0.
    // c is an int: as a plain char, the -1 "no separator" sentinel becomes 255 where
    // char is unsigned, and a stray byte was printed after every number.
    // The magnitude is computed in unsigned long long so that negating the most
    // negative value of T cannot overflow.
    TT void Put(T x,int c = -1)
    {
    	unsigned long long u = static_cast<unsigned long long>(x);
    	if(x < 0) putchar('-'),u = 0ull - u;
    	Put1(u);
    	if(c >= 0) putchar(c);
    }
    // Small generic helpers: maximum, minimum and absolute value.
    TT T Max(T x,T y){return x > y ? x : y;}
    TT T Min(T x,T y){return x < y ? x : y;}
    TT T Abs(T x){return x < 0 ? -x : x;}
    
    // Binary exponentiation: returns x^y modulo MOD. Expects y >= 0.
    int qpow(int x,int y)
    {
    	int res = 1;
    	for(int base = x;y;y >>= 1,base = 1ll * base * base % MOD)
    		if(y & 1) res = 1ll * res * base % MOD;
    	return res;
    }
    // Radix-2 NTT butterflies for a block of length len that is already in
    // bit-reversed order. After the two recursive calls, a[0, len/2) holds the
    // transform of the even-index samples and a[len/2, len) that of the odd-index
    // samples; they are then combined in place.
    // Entries may end up as negative residues in (-MOD, MOD); main() normalizes them
    // when printing.
    static void NTTMerge(int len,int *a,int f)
    {
    	if(len == 1) return;
    	int half = len >> 1;
    	NTTMerge(half,a,f);
    	NTTMerge(half,a + half,f);
    	int w = qpow(f == 1 ? G : GINV,PHI/len),k = 1;
    	for(int i = 0;i < half;++ i,k = 1ll * k * w % MOD)
    	{
    		long long X = a[i],Y = 1ll * k * a[i + half];
    		a[i] = (X + Y) % MOD;
    		a[i + half] = (X - Y) % MOD;
    	}
    }
    
    // Recursive NTT of a[0, len) modulo MOD, computed in place. len must be a power
    // of two dividing PHI. f = 1 uses root G and f = -1 uses GINV (the inverse is
    // unscaled: the caller multiplies by len^{-1}).
    // The input is permuted into bit-reversed order once, so every recursive step
    // works in place. Allocating two variable-length arrays on the stack per level
    // would be non-standard C++ and needs several MB of stack for len = 2^21.
    void NTT(int len,int *a,int f)
    {
    	for(int i = 1,j = 0;i < len;++ i)
    	{
    		// Advance j to the bit-reversal of i (increment from the top bit down).
    		int bit = len >> 1;
    		for(;j & bit;bit >>= 1) j ^= bit;
    		j ^= bit;
    		if(i < j){int t = a[i];a[i] = a[j];a[j] = t;}
    	}
    	NTTMerge(len,a,f);
    }
    
    // Reads two polynomials (degrees, then coefficients from degree 0 upward) and
    // prints the coefficients of their product modulo MOD, space-separated.
    int main()
    {
    //	freopen(".in","r",stdin);
    //	freopen(".out","w",stdout);
    	lena = Read(); lenb = Read();
    	for(int i = 0;i <= lena;++ i) a[i] = Read();
    	for(int i = 0;i <= lenb;++ i) b[i] = Read();
    	// Smallest power of two strictly greater than the product's degree.
    	int len = 1;
    	while(len <= lena + lenb) len <<= 1;
    	NTT(len,a,1);
    	NTT(len,b,1);
    	// Pointwise product; only indices [0, len) hold transform data.
    	for(int i = 0;i < len;++ i) a[i] = 1ll * a[i] * b[i] % MOD;
    	NTT(len,a,-1);
    	// The inverse NTT is unscaled: multiply by len^{-1}, and map the (possibly
    	// negative) residues into [0, MOD).
    	const int invlen = qpow(len,MOD-2);
    	for(int i = 0;i <= lena+lenb;++ i) Put((1ll * a[i] * invlen % MOD + MOD) % MOD,' ');
    	return 0;
    }
    
    $Mine\text{(迭代)}$
    //12252024832524
    #include <cmath>
    #include <cstdio>
    #include <cstring>
    #include <algorithm>
    #define TT template<typename T>
    using namespace std; 
    
    typedef long long LL;              // 64-bit integer type returned by Read()
    const int MAXN = (1 << 21) | 5;    // array capacity: transform length up to 2^21, plus slack
    const int MOD = 998244353;         // prime modulus; MOD - 1 = 2^23 * 7 * 17
    const int PHI = MOD - 1;           // 998244352 = phi(MOD), since MOD is prime
    const int GINV = 332748118;        // inverse of G: 3 * 332748118 = MOD + 1
    const int G = 3;                   // primitive root of MOD
    int lena,lenb;                     // degrees of the two input polynomials
    int len = 1;                       // transform length (power of two), grown in main()
    int l = -1;                        // log2(len) - 1 once main() has sized len
    int a[MAXN],b[MAXN];               // coefficients / transformed values of the polynomials
    int rev[MAXN];                     // bit-reversal permutation table used by NTT()
    
    // Fast input: reads one signed decimal integer from stdin. Non-digit characters
    // before the number are skipped, and a '-' among them makes the result negative.
    // c is an int (not char) so that EOF can be recognised. With a char, EOF never
    // passed the digit test, so the skip loop spun forever on truncated input.
    // Returns 0 if EOF is reached before any digit.
    LL Read()
    {
    	LL x = 0,f = 1;int c = getchar();
    	while(c != EOF && (c > '9' || c < '0')){if(c == '-')f = -1;c = getchar();}
    	while(c >= '0' && c <= '9'){x = (x*10) + (c^48);c = getchar();}
    	return x * f;
    }
    // Writes the decimal digits of a non-negative integer x, most significant first.
    TT void Put1(T x)
    {
    	if(x > 9) Put1(x/10);
    	putchar(x%10^48); // (x % 10) ^ 48 == '0' + digit
    }
    // Writes integer x in decimal, followed by the character c when c >= 0.
    // c is an int: as a plain char, the -1 "no separator" sentinel becomes 255 where
    // char is unsigned, and a stray byte was printed after every number.
    // The magnitude is computed in unsigned long long so that negating the most
    // negative value of T cannot overflow.
    TT void Put(T x,int c = -1)
    {
    	unsigned long long u = static_cast<unsigned long long>(x);
    	if(x < 0) putchar('-'),u = 0ull - u;
    	Put1(u);
    	if(c >= 0) putchar(c);
    }
    // Small generic helpers: maximum, minimum and absolute value.
    TT T Max(T x,T y){return x > y ? x : y;}
    TT T Min(T x,T y){return x < y ? x : y;}
    TT T Abs(T x){return x < 0 ? -x : x;}
    
    // Binary exponentiation: returns x^y modulo MOD. Expects y >= 0.
    int qpow(int x,int y)
    {
    	int res = 1;
    	for(int base = x;y;y >>= 1,base = 1ll * base * base % MOD)
    		if(y & 1) res = 1ll * res * base % MOD;
    	return res;
    }
    // Iterative in-place NTT of a[0, len) modulo MOD, using the global bit-reversal
    // table rev. opt = 1 is the forward transform (root G). opt = -1 is the inverse
    // (root GINV), followed by multiplication by len^{-1}.
    // Assumes entries are in [0, MOD), so that lo + hi stays below 2^31.
    void NTT(int *a,int opt)
    {
    	// Put the samples into bit-reversed index order.
    	for(int idx = 0;idx < len;++ idx)
    		if(idx < rev[idx]) swap(a[idx],a[rev[idx]]);
    	const int root = opt == 1 ? G : GINV;
    	// Merge blocks of size half into blocks of size span = 2 * half.
    	for(int half = 1,span = 2;half < len;half = span,span <<= 1)
    	{
    		const int step = qpow(root,PHI / span);
    		for(int *blk = a;blk < a + len;blk += span)
    		{
    			int tw = 1;
    			for(int off = 0;off < half;++ off)
    			{
    				const int lo = blk[off];
    				const int hi = 1ll * tw * blk[off + half] % MOD;
    				blk[off] = (lo + hi) % MOD;
    				blk[off + half] = (lo - hi + MOD) % MOD;
    				tw = 1ll * tw * step % MOD;
    			}
    		}
    	}
    	if(opt == -1)
    	{
    		const int invlen = qpow(len,MOD-2);
    		for(int idx = 0;idx < len;++ idx) a[idx] = 1ll * a[idx] * invlen % MOD;
    	}
    }
    
    // Reads two polynomials (degrees, then coefficients from degree 0 upward) and
    // prints the coefficients of their product modulo MOD, space-separated.
    int main()
    {
    //	freopen(".in","r",stdin);
    //	freopen(".out","w",stdout);
    	lena = Read(); lenb = Read();
    	for(int i = 0;i <= lena;++ i) a[i] = Read();
    	for(int i = 0;i <= lenb;++ i) b[i] = Read();
    	// len: smallest power of two > lena + lenb; l ends as log2(len) - 1.
    	while(len <= lena + lenb) len <<= 1,l++;
    	// rev[i] = i with its log2(len) low bits reversed. The bit contributed by
    	// (i & 1) is len >> 1 (== 1 << l when len >= 2). Writing it this way avoids
    	// shifting by l == -1, which is undefined, when len stays 1.
    	for(int i = 0;i < len;++ i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (len >> 1) : 0);
    	NTT(a,1);
    	NTT(b,1);
    	for(int i = 0;i < len;++ i) a[i] = 1ll * a[i] * b[i] % MOD;
    	NTT(a,-1);
    	for(int i = 0;i <= lena+lenb;++ i) Put(a[i],' ');
    	return 0;
    }
    
  • 相关阅读:
    hibernate对应关系详解(转)
    mybatis genertor两种使用方式(文件+项目)
    YII2 union 不同数据结构时 解决方案
    Yii2 分表后 使用 union all 分页实现代码
    Beyond Compare 4.2.10手动破解
    Xshell 6+Xftp 6官方下载免费版
    Navicat Premium
    yii2的Console定时任务创建
    内嵌多个youtube视频并展现该频道所有视频列表
    video.js 动态获取URL 并播放youtube视频
  • 原文地址:https://www.cnblogs.com/PPLPPL/p/14362284.html
Copyright © 2011-2022 走看看