zoukankan      html  css  js  c++  java
  • FFT迭代加深 & NTT 多项式求逆

    NTT板子
    又重温了一遍,大佬说背锅就好
    具体看代码

    想要看懂NTT板子,先看懂FFT迭代加深模板;

    FFT迭代加深版本

    #include<iostream>
    #include<cstdio>
    #include<cmath>
    using namespace std;
    const int N=1e7+7;
    struct complex{
    	double x,y;
    	complex(double xx=0,double yy=0) {x=xx,y=yy;}
    }a[N],b[N];
    const double pi=acos(-1.0);
    complex operator +(const complex a,complex b) {return complex(a.x+b.x,a.y+b.y);}
    complex operator -(const complex a,complex b) {return complex(a.x-b.x,a.y-b.y);}
    complex operator *(const complex a,complex b) {return complex(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}
    int limit=1,n,m,l;
    int r[N];
    
    void FFT(complex *a,int f){
    	for(int i=0;i<limit;i++) if(i<r[i]) swap(a[i],a[r[i]]);
    	for(int mid=1;mid<limit;mid<<=1){//枚举要合并的区间的长度
    		complex Wn=complex(cos(pi/mid),f*sin(pi/mid));//单位根
    		for(int R=mid<<1,j=0;j<limit;j+=R){
    			complex w(1,0);
    			for(int k=0;k<mid;k++,w=w*Wn){
    				complex x=a[j+k],y=w*a[j+mid+k];
    				a[j+k]=x+y;
    				a[j+mid+k]=x-y;
    			}
    		}
    	}
    }
    
    int main(){
    	scanf("%d%d",&n,&m); 
    	for(int i=0;i<=n;i++) scanf("%lf",&a[i].x);
    	for(int i=0;i<=m;i++) scanf("%lf",&b[i].x);
    	while(limit<=n+m) limit<<=1,l++;
    	for(int i=0;i<limit;i++){
    		r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
    	}
    	
    	NTT(a,1);
    	NTT(b,1);
    	for(int i=0;i<=limit;i++) a[i]=a[i]*b[i];
    	NTT(a,-1);
    	for(int i=0;i<=n+m;i++){
    		cout<<(int)(a[i].x/(limit)+0.5)<<" ";
    	}
    }
    

    多项式求逆

    #include<iostream>
    #include<cstdio>
    using namespace std;
    #define int long long
    const int N=1e6+7;
    const int p=998244353;//原根为3
    int n;
    int a[N],b[N],c[N],r[N];
    int ksm(int a,int b){
    	int res=1;
    	for(;b;b>>=1){
    		if(b&1) res=res*a%p;
    		a=a*a%p;
    	}
    	return res;
    }
    
    void NTT(int *a,int len,int opt){
    	for(int i=0;i<len;i++) if(i<r[i]) swap(a[i],a[r[i]]);
    	for(int h=1;h<len;h<<=1){
    		int Wn=ksm(3,(p-1)/(h<<1));
    		if(opt==-1) Wn=ksm(Wn,(p-2));//NTT求原根
    		for(int j=0;j<len;j+=(h<<1)){
    			int w=1;
    			for(int k=0;k<h;k++){
    				int x=a[j+k];
    				int y=w*a[j+h+k] % p;
    				a[j+k]=(x+y)%p;
    				a[j+h+k]=(x-y+p)%p;
    				w=w*Wn%p;
    			}
    		}
    	}
    	if(opt==-1){
    		int inv=ksm(len,p-2);
    		for(int i=0;i<len;i++){
    			a[i]=a[i]*inv%p;
    		}
    	}
    }
    
    void INV(int n,int *a,int *b){
    	if(n==1){
    		b[0]=ksm(a[0],p-2);
    		return;
    	}
    	INV((n+1)>>1,a,b);//向上取整
    	int limit=1,l=0;
    	while(limit<(n<<1)) limit<<=1,l++;
    	for(int i=0;i<limit;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
    	for(int i=0;i<n;i++) c[i]=a[i];//a数组不能改变,所以赋值
    	for(int i=n;i<limit;i++) c[i]=0;//其余对答案没用;
    	NTT(c,limit,1),NTT(b,limit,1);
    	for(int i=0;i<limit;i++){
    		b[i]=(1LL*2*b[i]%p-1LL*b[i]*b[i]%p*c[i]%p+p)%p;
    	}
    	NTT(b,limit,-1);
    	for(int i=n;i<limit;i++) b[i]=0;
    }
    
    signed main(){
    	scanf("%lld",&n);
    	for(int i=0;i<n;i++) scanf("%lld",&a[i]);
    	INV(n,a,b);
    	for(int i=0;i<n;i++) cout<<(b[i]%p+p)%p<<" ";
    }
    
  • 相关阅读:
    pthread条件变量
    c++信号处理
    必杀技
    待飞日记(第四天和第五天)
    c++面试题总结(2)
    比起主流的30秒,10秒广告能获得2倍的效果
    c++面试题总结(1)
    待飞日记(第三天)
    static_cast, dynamic_cast, const_cast探讨
    c++一些问题总结
  • 原文地址:https://www.cnblogs.com/Aswert/p/14264278.html
Copyright © 2011-2022 走看看