zoukankan      html  css  js  c++  java
  • luogu P3803 【模板】多项式乘法(FFT)|NTT

    题目背景

    这是一道 FFT 模板题

    题目描述

    给定一个 n 次多项式 F(x),和一个 m 次多项式 G(x)。

    请求出 F(x) 和 G(x) 的卷积。

    输入格式

    第一行 2 个正整数 n,m。

    接下来一行 n+1 个数字,从低到高表示 F(x) 的系数。

    接下来一行 m+1 个数字,从低到高表示 G(x) 的系数。

    输出格式

    一行 n+m+1 个数字,从低到高表示 F(x)*G(x) 的系数。


    #include<cmath>
    #include<cstdio>
    #include<cstring>
    #include<iostream>
    #include<algorithm>
    using namespace std;
    const double Pi=acos(-1);
    #define db double
    #define maxn 1350000
    inline int read(){
      register char ch=0;
      while(ch<48||ch>57)ch=getchar();
      return ch-'0';
    }
    int n,m;
    struct CP{
    	CP (db xx=0,db yy=0){x=xx;y=yy;}
    	db x,y;
    	CP operator + (CP const &B)const
    	{return CP(x+B.x,y+B.y);}
    	CP operator - (CP const &B)const
    	{return CP(x-B.x,y-B.y);}
    	CP operator * (CP const &B)const
    	{return CP(x*B.x-y*B.y,x*B.y+y*B.x);}
    }f[maxn<<1];
    int tr[maxn<<1];
    inline void fft(CP *f,bool flag){
    	for(int i=0;i<n;i++)if(i<tr[i])swap(f[i],f[tr[i]]);
    	for(int p=2;p<=n;p<<=1){
    		int len=p>>1;
    		CP tG(cos(2*Pi/p),sin(2*Pi/p));
    		if(!flag)tG.y*=-1;
    		for(int k=0;k<n;k+=p){
    			CP buf(1,0);
    			for(int l=k;l<k+len;l++){
    				CP tt=buf*f[len+l];
    				f[len+l]=f[l]-tt;
    				f[l]=f[l]+tt;
    				buf=buf*tG;
    			}
    		}
    	}
    }
    signed main(){
    	scanf("%d%d",&n,&m);
    	for(int i=0;i<=n;i++)f[i].x=read();
    	for(int i=0;i<=m;i++)f[i].y=read();
    	for(m+=n,n=1;n<=m;n<<=1);
    	for(int i=0;i<n;i++)
    	tr[i]=(tr[i>>1]>>1)|((i&1)?n>>1:0);
    	fft(f,1);
    	for(int i=0;i<n;i++)f[i]=f[i]*f[i];
    	fft(f,0);
    	for(int i=0;i<=m;i++)
    	printf("%d ",(int)(f[i].y/n/2+0.49));
    	return 0;
    }
    

    hzwer的代码:

    #include<bits/stdc++.h>
    #define N 262145
    #define pi acos(-1)
    using namespace std;
    typedef complex<double> E;
    int n,m,L;
    int R[N];
    E a[N],b[N];
    void fft(E *a,int f)
    {
    	for(int i=0;i<n;i++)if(i<R[i])swap(a[i],a[R[i]]);
    	for(int i=1;i<n;i<<=1)
    	{
    		E wn(cos(pi/i),f*sin(pi/i));
    		for(int p=i<<1,j=0;j<n;j+=p)
    		{
    			E w(1,0);
    			for(int k=0;k<i;k++,w*=wn)
    			{
    				E x=a[j+k],y=w*a[j+k+i];
    				a[j+k]=x+y;a[j+k+i]=x-y;
    			}
    		}
    	}
    }
    int main()
    {
    	scanf("%d%d",&n,&m);
    	for(int i=0,x;i<=n;i++)scanf("%d",&x),a[i]=x;
    	for(int i=0,x;i<=m;i++)scanf("%d",&x),b[i]=x;
    	m=n+m;for(n=1;n<=m;n<<=1)L++;
    	for(int i=0;i<n;i++)R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
    	fft(a,1);fft(b,1);
    	for(int i=0;i<=n;i++)a[i]=a[i]*b[i];
    	fft(a,-1);
    	for(int i=0;i<=m;i++)
    		printf("%d ",(int)(a[i].real()/n+0.5));
    	return 0;
    }
    

    NTT写法:

    #include <bits/stdc++.h>
    using namespace std;
    typedef long long ll;
    inline ll ty() {
        char ch = getchar(); ll x = 0, f = 1;
        while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
        while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); }
        return x * f;
    }
    const int _=4e6+10;
    const ll P=998244353,G=3,Gx=332748118;
    int N,M,r[_];
    ll A[_],B[_];
    inline ll ksm(ll a,ll b) {
        ll ret=1;
        for (;b;b>>=1){
            if (b & 1)ret=ret*a%P;
            a=a*a%P;
        }
        return ret;
    }
    inline void ntt(int lim, ll *a, int op) {
        for (int i=0;i<lim;++i)if(i<r[i])swap(a[i],a[r[i]]);
        for (int len=2;len<=lim;len<<=1){
            int mid=len >> 1;
            ll Wn=ksm(op==1?G:Gx,(P-1)/len);
            for (int i=0;i<lim;i+=len) {
                ll w=1;
                for (int j=0;j<mid;++j,w=(w*Wn)%P){
                    ll x=a[i+j],y=w*a[i+j+mid]%P;
                    a[i+j]=(x+y)%P;
                    a[i+j+mid]=(x-y+P)%P;
                }
            }
        }
    }
    int main() {
        N=ty(),M=ty();
        for(int i = 0; i <= N; ++i) A[i]=(ty() + P) % P;
        for(int i = 0; i <= M; ++i) B[i]=(ty() + P) % P;
        int lim = 1, k = 0;
        while (lim <= N + M) lim <<= 1, ++k;
        for (int i = 0; i < lim ; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (k - 1));
        ntt(lim, A, 1);
        ntt(lim, B, 1);
        for (int i = 0; i < lim; ++i) A[i] = (A[i] * B[i]) % P;
        ntt(lim, A, -1);
        ll invx = ksm(lim, P - 2);
        for (int i = 0; i <= N + M; ++i)
            printf("%lld ", (A[i] * invx) % P);
        return 0;
    }
    
  • 相关阅读:
    rapidjson 使用
    【设计模式】模板方法模式
    【设计模式】策略模式
    【设计模式】建造者模式
    【设计模式】享元模式
    /dev/sda1 contains a file system with errors,check forced.
    如何编写高效的Python的代码
    VsCode 调试 Python 代码
    Python 使用 pyinstaller 打包 代码
    初次使用git上传代码到github远程仓库
  • 原文地址:https://www.cnblogs.com/naruto-mzx/p/12093198.html
Copyright © 2011-2022 走看看