zoukankan      html  css  js  c++  java
  • 多项式做题笔记

    (早上好,笔记在注释里。)

    多项式卷积模板:

    FFT:

    #include<iostream>
    #include<cstdio>
    #include<cmath>
    using namespace std;
    const int N=3e6+10;
    double pi=acos(-1);
    int n,m;
    struct node{
        double x,y;
        node(double a=0,double b=0){
            x=a,y=b;
        }
        node operator + (node const &u) const{
            return node(x+u.x,y+u.y);
        }
        node operator - (node const &u) const{
            return node(x-u.x,y-u.y);
        }
        node operator * (node const &u) const{
            return node(x*u.x-y*u.y,y*u.x+x*u.y);
        }
    }f[N];
    int pos[N];
    void fft(node *f,bool flag){
        for(int i=0;i<n;i++){
            if(i<pos[i])swap(f[i],f[pos[i]]);
        }
        for(int p=2;p<=n;p<<=1){
            int len=p>>1;
            node fir(cos(2*pi/p),sin(2*pi/p));
            if(!flag)fir.y*=-1;
            for(int k=0;k<n;k+=p){
                node buf(1,0);
                for(int l=k;l<k+len;l++){
                    node tt=buf*f[len+l];
                    f[len+l]=f[l]-tt;
                    f[l]=f[l]+tt;
                    buf=buf*fir;
                }
            }
        }
    }
    int main()
    {
        scanf("%d%d",&n,&m);
        for(int i=0;i<=n;i++)scanf("%lf",&f[i].x);
        for(int i=0;i<=m;i++)scanf("%lf",&f[i].y);
        for(m+=n,n=1;n<=m;n<<=1);
        for(int i=0;i<n;i++){
            pos[i]=(pos[i>>1]>>1)|((i&1)?n>>1:0);
        }
        fft(f,1);
        for(int i=0;i<n;i++)f[i]=f[i]*f[i];
        fft(f,0);
        for(int i=0;i<=m;i++)printf("%d ",(int)(f[i].y/n/2+0.5));
        return 0;
    }
    //FFT比较丢精度,如果需要卷积的多项式系数值域相差太大,就会卡精度
    //三次变两次优化涉及的精度跨度上限更大,严重掉精度 
    【模板】多项式乘法(FFT)

    NTT:

    #include<iostream>
    #include<cstdio>
    using namespace std;
    const int N=3e6+10,mod=998244353,G=3;
    int n,m,pos[N];
    long long f[N],g[N],invn,invG;
    long long pw(long long x,long long k){
        long long num=1;
        while(k){
            if(k&1)num=num*x%mod;
            x=x*x%mod;
            k>>=1;
        }
        return num;
    }
    void ntt(long long *f,bool flag){
        for(int i=0;i<n;i++){
            if(i<pos[i])swap(f[i],f[pos[i]]);
        }
        for(int p=2;p<=n;p<<=1){
            int len=p>>1;
            long long fir=pw((flag?G:invG),(mod-1)/p);
            for(int i=0;i<n;i+=p){
                long long bur=1;
                for(int l=i;l<i+len;l++){
                    long long tt=bur*f[l+len]%mod;
                    f[l+len]=((f[l]-tt)%mod+mod)%mod;
                    f[l]=(f[l]+tt)%mod;
                    bur=bur*fir%mod;
                    
                }
            }
        }
    }
    int main()
    {
        scanf("%d%d",&n,&m);
        for(int i=0;i<=n;i++)scanf("%lld",&f[i]);
        for(int i=0;i<=m;i++)scanf("%lld",&g[i]);
        for(m+=n,n=1;n<=m;n<<=1);
        for(int i=0;i<n;i++){
            pos[i]=(pos[i>>1]>>1)|((i&1)?(n>>1):0);
        }
        invn=pw(n,mod-2),invG=pw(G,mod-2);
        ntt(f,1),ntt(g,1);
        for(int i=0;i<n;i++){
            f[i]=f[i]*g[i]%mod;
        }
        ntt(f,0);
        for(int i=0;i<=m;i++){
            printf("%lld ",f[i]*invn%mod);
        }
        return 0;
    } 
    //数组需要开到2的幂次以上,不是两倍
    //第一个单位根从1开始
    //mod的大小开在最大系数以上防止模掉 
    P3803 【模板】多项式乘法(NTT)

    多项式求逆:

    #include<iostream>
    #include<cstdio>
    #define ll long long
    using namespace std;
    const int N=3e5+10,mod=998244353,G=3;
    int n,pos[N];
    ll a[N],b[N],c[N],invG;
    ll pw(ll x,ll k){
        ll num=1;
        while(k){
            if(k&1)num=num*x%mod;
            x=x*x%mod;
            k>>=1;
        }
        return num;
    }
    void ntt(ll *f,int n,bool flag){
        for(int i=0;i<n;i++)if(i<pos[i])swap(f[i],f[pos[i]]);
        for(int p=2;p<=n;p<<=1){
            int len=p>>1;
            ll fir=pw((flag?G:invG),(mod-1)/p);
            for(int k=0;k<n;k+=p){
                ll bur=1;
                for(int l=k;l<len+k;l++){
                    ll tt=f[l+len]*bur%mod;
                    f[len+l]=((f[l]-tt)%mod+mod)%mod;
                    f[l]=(f[l]+tt)%mod;
                    bur=bur*fir%mod;
                }
            }
        }
    }                
    void getinv(int now,ll *a,ll*b){
        if(now==1){b[0]=pw(a[0],mod-2);return;}
        getinv((now+1)>>1,a,b);
        int goal=1;
        while(goal<(now<<1))goal<<=1;
        ll invn=pw(goal,mod-2);
        for(int i=0;i<goal;i++)pos[i]=(pos[i>>1]>>1)|((i&1)?(goal>>1):0);
        for(int i=0;i<now;i++)c[i]=a[i];
        for(int i=now;i<goal;i++)c[i]=0;
        ntt(c,goal,1),ntt(b,goal,1);
        for(int i=0;i<goal;i++)b[i]=((2ll-c[i]*b[i]%mod)%mod+mod)%mod*b[i]%mod;
        ntt(b,goal,0);
        for(int i=0;i<now;i++)b[i]=b[i]*invn%mod;
        for(int i=now;i<goal;i++)b[i]=0;
    }
    int main(){
        scanf("%d",&n);
        for(int i=0;i<n;i++)scanf("%lld",&a[i]);
        invG=pw(G,mod-2);
        getinv(n,a,b);
        for(int i=0;i<n;i++)printf("%lld ",b[i]);
        return 0;
    }
    //递归每一层本质对x^now取模,逆元数组b每次处理完要把被模掉的多余部分清空。 
    P4238 【模板】多项式乘法逆

    一些题目:

     快速傅里叶之二:

    #include<iostream>
    #include<cstdio>
    using namespace std;
    const int N=3e5+10,G=3;
    const long long mod=2281701377;
    int n,m,pos[N];
    long long a[N],b[N],invG,invn;
    long long pw(long long x,long long k){
        long long num=1;
        while(k){
            if(k&1)num=num*x%mod;
            x=x*x%mod;
            k>>=1;
        }
        return num;
    }
    void ntt(long long *f,bool flag){
        for(int i=0;i<n;i++){
            if(i<pos[i])swap(f[i],f[pos[i]]);
        }
        for(int p=2;p<=n;p<<=1){
            int len=p>>1;
            long long fir=pw((flag?G:invG),(mod-1)/p);
            for(int k=0;k<n;k+=p){
                long long bur=1;
                for(int l=k;l<k+len;l++){
                    long long tt=bur*f[l+len]%mod;
                    f[l+len]=((f[l]-tt)%mod+mod)%mod;
                    f[l]=(f[l]+tt)%mod;
                    bur=bur*fir%mod;
                }
            }
        }
    }
    int main()
    {
        scanf("%d",&n);
        for(int i=0;i<n;i++)scanf("%lld%lld",&a[i],&b[i]);
        for(int i=1;i<=n/2;i++)swap(b[i-1],b[n-i]);
        for(m=n+n,n=1;n<=m-2;n<<=1);
        invn=pw(n,mod-2),invG=pw(G,mod-2);
        for(int i=1;i<n;i++){
            pos[i]=(pos[i>>1]>>1)|((i&1)?(n>>1):0);
        }
        ntt(a,1),ntt(b,1);
        for(int i=0;i<n;i++){
            a[i]=a[i]*b[i]%mod;
        }
        ntt(a,0);
        for(int i=m/2-1;i<=m-2;i++){
            printf("%lld
    ",a[i]*invn%mod);
        }
        return 0;
    } 
    快速傅里叶之二

    P3338 [ZJOI2014]力

    #include<iostream>
    #include<cstdio>
    #include<cmath>
    using namespace std;
    const int N=3e5+10;
    double pi=acos(-1);
    int n,m,pos[N];
    struct node{
        double x,y;
        node(double xx=0,double yy=0){
            x=xx,y=yy;
        }
        node operator + (node const &u) const{
            return node(x+u.x,y+u.y);
        }
        node operator - (node const &u) const{
            return node(x-u.x,y-u.y);
        }
        node operator * (node const &u) const{
            return node(x*u.x-y*u.y,y*u.x+x*u.y);
        }
    }f[N],g[N],f0[N];
    void fft(node *f,bool flag){
        for(int i=0;i<n;i++){
            if(i<pos[i])swap(f[i],f[pos[i]]);
        }
        for(int p=2;p<=n;p<<=1){
            int len=p>>1;
            node fir=node(cos(2*pi/p),sin(2*pi/p));
            if(!flag)fir.y=-fir.y;
            for(int k=0;k<n;k+=p){
                node bur=node(1,0);
                for(int l=k;l<len+k;l++){
                    node tt=bur*f[len+l];
                    f[len+l]=f[l]-tt;
                    f[l]=f[l]+tt;
                    bur=bur*fir;
                }
            }
        }
    }
    int main()
    {
        scanf("%d",&n);
        for(int i=1;i<=n;i++){
            scanf("%lf",&f[i].x);
            f0[n-i+1].x=f[i].x;
            g[i].x=1.0/i/i;//(i*i)炸int 
        }
        for(m=n+n,n=1;n<=m;n<<=1);
        for(int i=0;i<n;i++){
            pos[i]=(pos[i>>1]>>1)|((i&1)?n>>1:0);
        }
        fft(f,1),fft(g,1),fft(f0,1);
        for(int i=0;i<n;i++){
            f[i]=f[i]*g[i];
            f0[i]=f0[i]*g[i];
        }
        fft(f,0),fft(f0,0);
        for(int i=1;i<=m/2;i++){
            printf("%.3lf
    ",(f[i].x-f0[m/2-i+1].x)/n);
        }
        return 0;
    }
    //反转是基操 
    P3338 [ZJOI2014]力

    Tyvj1953 Normal:

    #include<iostream>
    #include<cstdio>
    #define ll long long
    using namespace std;
    const int N=200010,mod=998244353,G=3;
    int nn,n,root,siz[N],vis[N],f[N],sum,dis[N],pos[N];
    int ver[2*N],head[N],Next[2*N],tot,maxx;
    ll A[N],invG,invn,ans[N];
    long double ans1;
    void add(int x,int y){
        ver[++tot]=y;
        Next[tot]=head[x];
        head[x]=tot;
    }
    ll pw(ll x,ll k){
        ll num=1;
        while(k){
            if(k&1)num=num*x%mod;
            x=x*x%mod;
            k>>=1;
        }
        return num;
    }    
    void findroot(int x,int fa){
        siz[x]=1,f[x]=0;
        for(int i=head[x];i;i=Next[i]){
            int y=ver[i];
            if(vis[y]||y==fa)continue;
            findroot(y,x);
            siz[x]+=siz[y];
            f[x]=max(f[x],siz[y]);
        }
        f[x]=max(f[x],sum-siz[x]);
        if(f[x]<f[root])root=x;
    }
    void ntt(ll *f,bool flag){
        for(int i=0;i<nn;i++)if(i<pos[i])swap(f[i],f[pos[i]]);
        for(int p=2;p<=nn;p<<=1){
            int len=p>>1;
            ll fir=pw((flag?G:invG),(mod-1)/p);
            for(int k=0;k<nn;k+=p){
                ll bur=1;
                for(int l=k;l<k+len;l++){
                    ll tt=bur*f[l+len]%mod;
                    f[l+len]=((f[l]-tt)%mod+mod)%mod;
                    f[l]=(f[l]+tt)%mod;
                    bur=bur*fir%mod;
                }
            }
        }
    }
    void getdis(int x,int fa){
        A[dis[x]]++;
        maxx=max(maxx,dis[x]);
        for(int i=head[x];i;i=Next[i]){
            int y=ver[i];
            if(vis[y]||y==fa)continue;
            dis[y]=dis[x]+1;
            getdis(y,x);
        }
    }
    void cal(int x,int lon,int val){
        dis[x]=lon;
        for(int i=0;i<nn;i++)A[i]=0,pos[i]=0;
        maxx=0;
        getdis(x,0);
        for(nn=1;nn<=(maxx*2);nn<<=1);
        for(int i=0;i<nn;i++)pos[i]=(pos[i>>1]>>1)|((i&1)?(nn>>1):0);
        invn=pw(nn,mod-2);
        ntt(A,1);
        for(int i=0;i<nn;i++)A[i]=A[i]*A[i]%mod;
        ntt(A,0);
        for(int i=0;i<nn;i++)ans[i]=(ans[i]+A[i]*invn*val%mod+mod)%mod;
    }
    void solve(int x){
        vis[x]=1;
        cal(x,0,1);
        for(int i=head[x];i;i=Next[i]){
            int y=ver[i];
            if(vis[y])continue;
            cal(y,1,-1);
            sum=siz[y];
            findroot(y,root=0);
            solve(root);
        }
    }    
    int main(){
        scanf("%d",&n);
        f[0]=mod;
        for(int i=1,x,y;i<n;i++){
            scanf("%d%d",&x,&y);
            x++,y++; 
            add(x,y),add(y,x);
        }
        invG=pw(G,mod-2);
        sum=n;
        findroot(1,0);
        solve(root);
        for(int i=0;i<n;i++){
            ans1+=(1.0/(i+1))*ans[i];
        }
        printf("%.4Lf",ans1);
        return 0;
    }
    //转化成每个点的贡献:每个点的贡献即为它在分治树上的深度
    //考虑一个点对的贡献:点x在点y计数时产生1的贡献,说明点y是x到y这条路径上被选出来的第一个点。
    //如果选了路径以外的点,对x和y的关系没有影响,它们在同一棵子树中。
    //如果选了路径上其它点,则x和y会被分到两个不同子树中,且都对选出来的点产生1的贡献。
    //x-y这条路径上每个点被选中的概率是相同的,所以(x,y)产生贡献的期望为1/(len(x,y)+1)(len+1即为路径上的点数)
    //统计所有长度的路径的数量即可,可以点分治+FFT在O(nlog^2n)的复杂度内求出
    //注意确保每次NTT之前都把边界卡在当前子树深度范围处,保证总复杂度正确
    Tyvj1953 Normal

    Triple:

    #include<iostream>
    #include<cstdio>
    using namespace std;
    const int N=150010,g=3;
    const long long mod=2281701377;
    int n,m,pos[N],cnt[N],maxx;
    long long F[N],G[N],H[N],invg,invn,ans[N],inv2,inv3;
    long long pw(long long x,long long k){
        long long num=1;
        while(k){
            if(k&1)num=num*x%mod;
            x=x*x%mod;
            k>>=1;
        }
        return num;
    }
    void ntt(long long *f,bool flag){
        for(int i=0;i<n;i++){
            if(i<pos[i])swap(f[i],f[pos[i]]);
        }
        for(int p=2;p<=n;p<<=1){
            int len=p>>1;
            long long fir=pw((flag?g:invg),(mod-1)/p);
            for(int k=0;k<n;k+=p){
                long long bur=1;
                for(int l=k;l<k+len;l++){
                    long long tt=bur*f[l+len]%mod;
                    f[l+len]=((f[l]-tt)%mod+mod)%mod;
                    f[l]=(f[l]+tt)%mod;
                    bur=bur*fir%mod;
                }
            }
        }
    }
    int main()
    {
        scanf("%d",&n);
        for(int i=1,x;i<=n;i++){
            scanf("%d",&x);
            F[x]++;
            ans[x]++;
            cnt[x]++;
            maxx=max(maxx,x);
        }
        for(m=maxx*3,n=1;n<=m;n<<=1);
        for(int i=0;i<n;i++){
            pos[i]=(pos[i>>1]>>1)|((i&1)?(n>>1):0);
        }
        invg=pw(g,mod-2),invn=pw(n,mod-2),inv2=pw(2,mod-2);
        ntt(F,1);
        for(int i=0;i<n;i++){
            H[i]=F[i]*F[i]%mod;
            G[i]=H[i];
        }
        ntt(G,0);
        for(int i=0;i<=n;i++){
            G[i]=G[i]*invn%mod;
            if(i%2==0&&cnt[i/2])G[i]=((G[i]-cnt[i/2])%mod+mod)%mod;
            G[i]=G[i]*inv2%mod;
            ans[i]+=G[i];
        }
        ntt(G,1);
        for(int i=0;i<n;i++){
            G[i]=G[i]*F[i]%mod;
        }
        ntt(G,0);
        for(int i=0;i<n;i++){
            H[i]=H[i]*F[i]%mod;
        }
        ntt(H,0);
        for(int i=0;i<=n;i++){
            H[i]=H[i]*invn%mod,G[i]=G[i]*invn%mod;
            if(i%3==0&&cnt[i/3])H[i]=((H[i]-cnt[i/3])%mod+mod)%mod;
            H[i]=H[i]*invg%mod;
            ans[i]+=G[i]-H[i];
        }
        for(int i=0;i<n;i++){
            if(ans[i]){
                printf("%d %lld
    ",i,ans[i]);
            }
        }
        return 0;
    }
    //F+
    //(F*F(=H)-F(每个原数的平方项的系数减原数的数量))/2(=G)+
    //G*F-((H*F-F(每个原数的立方项的系数减原数的数量))/3)
    //注意FFT和IDFT在函数中的差别
    //注意除法用逆元
    //注意计算使质数mod的范围大于结果的值 
    
    //正确式子: 
    //先构造Ai为x指数的生成函数A(x)
    //再构造2Ai为指数的生成函数B(x)
    //再构造3Ai为指数的生成函数C(x)
    //A(x)+(A^2(x)-B(x))/2+(A^3(x)-3*A(x)*B(x)+2*C(x))/6
    Triple

    P4199 万径人踪灭:

    //两个多项式,A的项系数为同一位置是否为a,B则为是否为b(字符串预先拓展,中间加特殊字符#
    //A*A,B*B,每个位置的sum为两个多项式中i*2的项的系数相加 
    //每一项拓展成2^a[i]-1,各项相加(这就是全部的情况了) 
    //减去不合法的数目——回文子串的数目
    //manacher!
    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #define ll long long
    using namespace std;
    const int N=6e5+10,G=3;
    const long long mod=2281701377,mod0=1e9+7;
    char s[N],s0[N];
    int lens,m,n,pos[N];
    ll A[N],B[N],invG,invn,ans[N],sum,inv2,inv;
    ll pw(ll x,ll k){
        ll num=1;
        while(k){
            if(k&1)num=num*x%mod;
            x=x*x%mod;
            k>>=1;
        }
        return num;
    }
    ll pw0(ll x,ll k){
        ll num=1;
        while(k){
            if(k&1)num=num*x%mod0;
            x=x*x%mod0;
            k>>=1;
        }
        return num;
    }
    void ntt(ll *f,bool flag){
        for(int i=0;i<n;i++){
            if(i<pos[i])swap(f[i],f[pos[i]]);
        }
        for(int p=2;p<=n;p<<=1){
            int len=p>>1;
            ll fir=pw((flag?G:invG),(mod-1)/p);
            for(int k=0;k<n;k+=p){
                ll bur=1;
                for(int l=k;l<k+len;l++){
                    ll tt=f[len+l]*bur%mod;
                    f[len+l]=((f[l]-tt)%mod+mod)%mod;
                    f[l]=(f[l]+tt)%mod;
                    bur=bur*fir%mod;
                }
            }
        }
    }
    ll lon[N];
    void manacher(){
        int right=0,pos=0;
        lon[0]=1;
        s[0]='#';
        s[m+1]='#';
        for(int i=1;i<=m;i++){
            if(right<i){
                lon[i]=1;
                while(s[i-lon[i]]==s[i+lon[i]]&&i-lon[i]>=0&&i+lon[i]<=m+1)lon[i]++;
                right=i+lon[i]-1;
                pos=i;
            }
            else{
                int j=2*pos-i;
                if(i+lon[j]-1>right)lon[i]=right-i+1;
                else if(i+lon[j]-1<right)lon[i]=lon[j];
                else{
                    lon[i]=lon[j];
                    while(s[i-lon[i]]==s[i+lon[i]]&&i-lon[i]>=0&&i+lon[i]<=m+1)lon[i]++;
                    right=i+lon[i]-1;
                    pos=i;
                }
            }
        }
    }
    int main()
    {
        scanf("%s",s0+1);
        lens=strlen(s0+1);
        for(int i=lens;i>=1;i--){
            s[i*2-1]=s0[i];
            s[i*2-2]='#';
            if(s0[i]=='a')A[i*2-1]=1;
            else B[i*2-1]=1;
        }
        for(m=lens*2-1,n=1;n<=2*m;n<<=1);
        for(int i=0;i<n;i++){
            pos[i]=(pos[i>>1]>>1)|((i&1)?(n>>1):0);
        }
        invn=pw(n,mod-2),invG=pw(G,mod-2);
        ntt(A,1),ntt(B,1);
        for(int i=0;i<n;i++){
            A[i]=A[i]*A[i]%mod;
            B[i]=B[i]*B[i]%mod;
        }
        ntt(A,0),ntt(B,0);
        for(int i=0;i<n;i++){
            A[i]=A[i]*invn%mod;
            B[i]=B[i]*invn%mod;
        }
        for(int i=1;i<=m;i++){
            if(s[i]=='a')A[i*2]=(A[i*2]+1)%mod;
            else if(s[i]=='b')B[i*2]=(B[i*2]+1)%mod;
        }
        inv=pw(2,mod-2);
        for(int i=0;i<n;i++){
            A[i]=A[i]*inv%mod;
            B[i]=B[i]*inv%mod;
        }
        inv2=pw0(2,mod0-2);
        for(int i=0;i<=m;i++){
            ans[i]=(A[i*2]+B[i*2])%mod0;
            ans[i]=((pw0(2,ans[i])-1)%mod0+mod0)%mod0;
            sum=(sum+ans[i])%mod0;
        }
        manacher();
        for(int i=1;i<=m;i++){
            if(lon[i]%2==1)lon[i]--;
            lon[i]=lon[i]*inv2%mod0;
            sum=((sum-lon[i])%mod0+mod0)%mod0;
        }
        printf("%lld
    ",sum);
        return 0;
     } 
    P4199 万径人踪灭

    P3321 [SDOI2015]序列统计:

    #include<iostream>
    #include<cstdio>
    #include<cmath>
    #include<cstring>
    #define ll long long
    using namespace std;
    const int N=1e5+10,G=3;
    const long long mod=1004535809;
    ll n,m,x,s,g,rec[N];
    ll pri[N],cnt,mm,nn,pos[N];
    ll A[N],invn,invG,ans;
    ll pw(ll x,ll k){
        ll num=1;
        while(k){
            if(k&1)num=num*x%mod;
            x=x*x%mod;
            k>>=1;
        }
        return num;
    }
    ll pw0(ll x,ll k){//!!
        ll num=1;
        while(k){
            if(k&1)num=num*x%m;
            x=x*x%m;
            k>>=1;
        }
        return num;
    }
    int check(ll x){
        for(int i=1;i<=cnt;i++){
            if(pw0(x,pri[i])==1)return 0;
        }
        return 1;
    }
    void ntt(ll *f,bool flag){
        for(int i=0;i<nn;i++){
            if(i<pos[i])swap(f[i],f[pos[i]]);
        }
        for(int p=2;p<=nn;p<<=1){
            int len=p>>1;
            ll fir=pw((flag?G:invG),(mod-1)/p);
            for(int k=0;k<nn;k+=p){
                ll bur=1;
                for(int l=k;l<len+k;l++){
                    ll tt=bur*f[len+l]%mod;
                    f[len+l]=((f[l]-tt)%mod+mod)%mod;
                    f[l]=(f[l]+tt)%mod;
                    bur=bur*fir%mod;
                }
            }
        }
    }
    void ks(ll *A,ll k){
        ll B[N];
        memset(B,0,sizeof(B));
        for(int i=0;i<nn;i++)B[i]=A[i],A[i]=0;
        A[0]=1;
        while(k){
            if(k&1){
                ntt(A,1),ntt(B,1);
                for(int i=0;i<nn;i++){
                    A[i]=A[i]*B[i]%mod;
                }
                ntt(A,0),ntt(B,0);
                for(int i=0;i<nn;i++){
                    A[i]=A[i]*invn%mod;
                    B[i]=B[i]*invn%mod;
                    if(i>m-1){
                        A[i-(m-1)]=(A[i-(m-1)]+A[i])%mod;
                        A[i]=0;
                        B[i-(m-1)]=(B[i-(m-1)]+B[i])%mod;
                        B[i]=0;
                    }
                }
            }
            ntt(B,1);
            for(int i=0;i<nn;i++){
                B[i]=B[i]*B[i]%mod;
            }
            ntt(B,0);
            for(int i=0;i<nn;i++){
                B[i]=B[i]*invn%mod;
                if(i>m-1){
                    B[i-(m-1)]=(B[i-(m-1)]+B[i])%mod;
                    B[i]=0;
                }
            }
            k>>=1;
        }
    }
    int main()
    {
        scanf("%lld%lld%lld%lld",&n,&m,&x,&s);
        for(int i=2;i*i<=m-1;i++){
            if((m-1)%i==0){
                pri[++cnt]=i;
                if((i*i)!=(m-1))pri[++cnt]=(m-1)/i;
            }
        }
    //    pri[++cnt]=m-1;
        for(int i=2;i<=100;i++){
            if(check(i)){
                g=i;
                ll num=1;
                for(int j=1;j<=m-1;j++){
                    num=num*g%m;
                    rec[num]=j;
                }
                break;
            }
        }
        for(int i=1;i<=s;i++){
            ll xx;
            scanf("%lld",&xx);
            if(xx==0)continue;
            A[rec[xx%m]]++;
        }
        for(mm=m*2,nn=1;nn<=mm;nn<<=1);
        for(int i=0;i<nn;i++){
            pos[i]=(pos[i>>1]>>1)|((i&1)?(nn>>1):0);
        }
        invn=pw(nn,mod-2),invG=pw(G,mod-2);
        ks(A,n);
        printf("%lld
    ",A[rec[x]]);
        return 0;
    }
    //注意求原根的时候,快速幂取模不要和全局取模弄混
    //这里多项式相乘10^9之后数组是没法直接存下那么多项的
    //但是由于这题的特殊性质,第i项和第i-(m-1)项等价,于是可以把后面的加去前面 
    P3321 [SDOI2015]序列统计

    P4091 [HEOI2016/TJOI2016]求和:

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #define ll long long
    using namespace std;
    const int N=300010,G=3;
    const long long mod=998244353;
    int m,n,pos[N],mm;
    ll invn,invG,A[N],B[N],g[N],rec[N],inv[N],ans,sum;
    ll pw(ll x,ll k){
        ll num=1;
        while(k){
            if(k&1)num=num*x%mod;
            x=x*x%mod;
            k>>=1;
        }
        return num;
    }
    void work(){
        rec[0]=inv[0]=rec[1]=1;
        for(int i=2;i<=m;i++)rec[i]=rec[i-1]*i%mod;
        inv[m]=pw(rec[m],mod-2);
        for(int i=m-1;i>=1;i--)inv[i]=inv[i+1]*(i+1)%mod;
    }
    void ntt(ll *f,bool flag){
        for(int i=0;i<n;i++)if(i<pos[i])swap(f[i],f[pos[i]]);
        for(int p=2;p<=n;p<<=1){
            int len=p>>1;
            ll fir=pw((flag?G:invG),(mod-1)/p);
            for(int k=0;k<n;k+=p){
                ll bur=1;
                for(int l=k;l<len+k;l++){
                    ll tt=bur*f[l+len]%mod;
                    f[len+l]=((f[l]-tt)%mod+mod)%mod;
                    f[l]=(f[l]+tt)%mod;
                    bur=bur*fir%mod;
                }
            }
        }
    }
    int main(){
        scanf("%d",&m);
        for(mm=2*m,n=1;n<=mm;n<<=1);
        invn=pw(n,mod-2),invG=pw(G,mod-2);
        work();
        for(int i=0;i<=m;i++){
            A[i]=(((i&1)?-1:1)+mod)%mod*inv[i]%mod;
            B[i]=(pw(i,m+1)-1)*pw(i-1,mod-2)%mod*inv[i]%mod;
        }
        B[0]=1,B[1]=m+1;
        for(int i=1;i<n;i++)pos[i]=(pos[i>>1]>>1)|((i&1)?(n>>1):0);
        ntt(A,1),ntt(B,1);
        for(int i=0;i<n;i++)A[i]=A[i]*B[i]%mod;
        ntt(A,0);
        for(int i=0;i<=m;i++)g[i]=rec[i]*A[i]%mod*invn%mod;
        sum=1;
        for(int i=0;i<=m;i++)ans=(ans+sum*g[i]%mod)%mod,sum=sum*2%mod;
        printf("%lld
    ",ans);
        return 0;
    }
    //注意区分变量名
    //推公式懒得打了,复习的时候不会的话(退役吧)看题解吧
    //第二类斯特林数递推公式:S(i,j)=S(i-1,j-1)+j*S(i-1,j)
    //含义:表示i个不同球放在j个相同盒子里,盒子不允许为空的方案数。
    //如果前面的球放在j-1个盒子里就新拿一个盒子 ,如果前面的球已经放了j个盒子就随便选一个放进去
    //(相关:排列组合问题——8种情况的球和盒子)
    //第二类斯特林数容斥原理公式:S(i,j)=1/j!*Σ(-1)^k*C(j,k)*(j-k)^i,(0<=k<=j)
    //含义:先考虑盒子不同的问题,枚举至少k个盒子为空,有C(j,k)种选盒子的方式
    //球随便放在剩下的盒子里的方案是(j-k)^i(这里仍然可能出现其它盒子为空,因为球随便放了) 
    //用容斥来计算恰好0个盒子为空的方案数,容斥系数是(-1)^k
    //因为实际上盒子是相同的,所以除去盒子全排列的方案数 
    //第二类斯特林数的性质:j^i=ΣS(i,k)*C(j,k)*k!,(0<=k<=j) 
    //含义:j^i即为i个不同球放在j个不同盒子里,盒子可以为空的方案数
    //枚举有多少盒子不为空,此时有C(j,k)种选盒子的方式,存在S(i,k)代表球放在这么多相同盒子里的方案数
    //由于盒子其实是不同的,所以乘上盒子全排列的方案数
    //附加:i个不同球放在j个不同盒子里且盒子不允许为空的方案数即为上式去掉枚举k个盒子不为空 
    P4091 [HEOI2016/TJOI2016]求和

    P4491 [HAOI2018]染色:

    #include<iostream>
    #include<cstdio>
    #define ll long long
    using namespace std;
    const int N=650010,G=3;
    const long long mod=1004535809;
    int n,m,s,limit,maxn,mm,nn,pos[N];
    ll w[N],rec[10000010],inv[10000010],f[N],g[N],h[N],invn,invG,ans;
    ll pw(ll x,ll k){
        ll num=1;
        while(k){
            if(k&1)num=num*x%mod;
            x=x*x%mod;
            k>>=1;
        }
        return num;
    }
    void work(){
        maxn=max(m,n);
        rec[0]=inv[0]=rec[1]=1;
        for(int i=2;i<=maxn;i++)rec[i]=rec[i-1]*i%mod;
        inv[maxn]=pw(rec[maxn],mod-2);
        for(int i=maxn-1;i>=1;i--)inv[i]=inv[i+1]*(i+1)%mod;
    }
    void ntt(ll *f,bool flag){
        for(int i=0;i<nn;i++)if(i<pos[i])swap(f[i],f[pos[i]]);
        for(int p=2;p<=nn;p<<=1){
            int len=p>>1;
            ll fir=pw((flag?G:invG),(mod-1)/p);
            for(int k=0;k<nn;k+=p){
                ll bur=1;
                for(int l=k;l<k+len;l++){
                    ll tt=bur*f[l+len]%mod;
                    f[l+len]=((f[l]-tt)%mod+mod)%mod;
                    f[l]=(f[l]+tt)%mod;
                    bur=bur*fir%mod;
                }
            }
        }
    }
    int main(){
        scanf("%d%d%d",&n,&m,&s);
        for(int i=0;i<=m;i++){
            scanf("%lld",&w[i]);
        }
        work();
        limit=min(n/s,m);
        for(int i=0;i<=limit;i++){
            f[i]=rec[m]*inv[m-i]%mod*rec[n]%mod*inv[n-i*s]%mod;
            f[i]=f[i]*pw(inv[s],i)%mod*pw(m-i,n-i*s)%mod;
        }
        for(mm=2*limit,nn=1;nn<=mm;nn<<=1);
        for(int i=0;i<=limit;i++)g[i]=(((i&1)?-1:1)*inv[i]%mod+mod)%mod;
        for(int i=0;i<=limit;i++)h[limit-i]=g[i];
        for(int i=1;i<nn;i++)pos[i]=(pos[i>>1]>>1)|((i&1)?(nn>>1):0);
        invn=pw(nn,mod-2),invG=pw(G,mod-2);
        ntt(f,1),ntt(h,1);
        for(int i=0;i<nn;i++)f[i]=f[i]*h[i]%mod;
        ntt(f,0);
        for(int i=0;i<=limit;i++)ans=(ans+f[i+limit]*invn%mod*inv[i]%mod*w[i]%mod)%mod;
        printf("%lld",ans);
        return 0;
    }
    //先求至少i种染了s次的方案数f[i]
    //f[i]=C(m,i)*C(n,i*s)*(i*s)!/(s!)^i*(m-i)^(n-i*s)
    //含义是,在m种颜色里选择i种,在n个位置里占了哪s*i个,多重集的排列数,剩下的n-i*s个随便填剩下的颜色
    //然后求恰好i种染了s次的方案数g[i],用容斥处理
    //g[i]=Σ(-1)^(j-i)*C(j,i)*f[j] (j>=i)
    //含义是,容斥系数,加上至少i个,减去至少i+1个…每个f[j]会被多算C(j,i)次
    //把组合数拆开,化简得
    //g[i]=(Σ(-1)^(j-i)/(j-i)!*f[j]*j!)/i!
    //常见套路,设A[i]=(-1)^i/i!,B[i]=f[i]*i!,反转A数组,对A*B进行NTT,在i+n(设反转总长为n)处寻找i的答案 
    P4491 [HAOI2018]染色

    P4841 [集训队作业2013]城市规划:

    #include<iostream>
    #include<cstdio>
    #define ll long long
    using namespace std;
    const int N=500010,G=3;
    const long long mod=1004535809;
    int n,pos[N],nn;
    ll inv[N],rec[N],invG,B[N],C[N],D[N],c[N],invnn;
    ll pw(ll x,ll k){
        ll num=1;
        while(k){
            if(k&1)num=num*x%mod;
            x=x*x%mod;
            k>>=1;
        }
        return num;
    }
    void work(){
        inv[0]=rec[0]=rec[1]=1;
        for(int i=2;i<=n;i++)rec[i]=rec[i-1]*i%mod;
        inv[n]=pw(rec[n],mod-2);
        for(int i=n-1;i>=1;i--)inv[i]=inv[i+1]*(i+1)%mod;
    }
    void ntt(ll *f,int n,bool flag){
        for(int i=0;i<n;i++)if(i<pos[i])swap(f[i],f[pos[i]]);
        for(int p=2;p<=n;p<<=1){
            int len=p>>1;
            ll fir=pw((flag?G:invG),(mod-1)/p);
            for(int k=0;k<n;k+=p){
                ll bur=1;
                for(int l=k;l<k+len;l++){
                    ll tt=f[l+len]*bur%mod;
                    f[l+len]=(f[l]-tt+mod)%mod;
                    f[l]=(f[l]+tt)%mod;
                    bur=bur*fir%mod;
                }
            }
        }
    } 
    void getinv(int now,ll *a,ll *b){
        if(now==1){b[0]=pw(a[0],mod-2);return;}
        getinv((now+1)>>1,a,b);
        int goal=1;
        while(goal<(now<<1))goal<<=1;
        ll invn=pw(goal,mod-2);
        for(int i=0;i<goal;i++)pos[i]=(pos[i>>1]>>1)|((i&1)?(goal>>1):0);
        for(int i=0;i<now;i++)c[i]=a[i];
        for(int i=now;i<goal;i++)c[i]=0;
        ntt(b,goal,1),ntt(c,goal,1);
        for(int i=0;i<goal;i++)b[i]=((2ll-b[i]*c[i]%mod)+mod)%mod*b[i]%mod;
        ntt(b,goal,0);
        for(int i=0;i<now;i++)b[i]=b[i]*invn%mod;
        for(int i=now;i<goal;i++)b[i]=0;
    }
    int main(){
        scanf("%d",&n);
        invG=pw(G,mod-2);
        work();
        for(int i=0;i<=n;i++)B[i]=pw(2,1ll*i*(i-1)/2)*inv[i]%mod;
        for(int i=0;i<=n;i++)C[i]=pw(2,1ll*i*(i-1)/2)*inv[i-1]%mod;
        getinv(n,B,D);
        for(nn=1;nn<=(n*2);nn<<=1);
        for(int i=0;i<nn;i++)pos[i]=(pos[i>>1]>>1)|((i&1)?(nn>>1):0);
        ntt(C,nn,1),ntt(D,nn,1);
        for(int i=0;i<nn;i++)C[i]=C[i]*D[i]%mod;
        ntt(C,nn,0);
        invnn=pw(nn,mod-2);
        printf("%lld
    ",C[n]*rec[n-1]%mod*invnn%mod);
        return 0;
    }
    //考虑用总图数减去不连通的图的数量
    //总图数即2^C(n,2),任选两个点即为一条边,每条边有选或不选两种选择
    //求不连通的图的数量:枚举现在已经可以确定的联通的图的大小,其它点随意安排。
    //钦定1号点一直在联通的图中,因为包含1号点的连通块的大小一直在变化,所以枚举出来的所有子情况不重复
    //于是设f[i]为i个点满足题目条件的方案数:
    //f[i]=2^C(i,2)-ΣC(n-1,j-1)* f[j]*2^C(i-j,2),(1<=j<=i-1) 
    //转化式子,拆开不是指数的组合数,先两边同除以(i-1)!,移项,设0!=1来把枚举变成从1到i
    //然后把单项式分类成未知数形式相似的几部分,即卷积经典形式 
    //这时可以看出,设A[i]=f[i]/(i-1)!,B[i]=2^(i,2)/i!,C[i]= 2^(i,2)/(i-1)!
    //A*B=C,题目要求A
    //A=C*B^-1,对B多项式求逆元即可 
    P4841 [集训队作业2013]城市规划

    持续补完。

    对自己的记性没有太大指望。

  • 相关阅读:
    多态的理解
    成员变量与实例变量&成员方法与构造方法&构造代码块和静态代码块&静态与非静态&重写与重载的区别
    Java的基本数据类型和基本数据类型之间的转换
    如何使float精确两位小数或多位小数
    Servlet程序的生命周期
    配置axios全局拦截器
    SpringCloud-Sentinel实现服务限流、熔断、降级,整合Nacos实现持久化
    Nginx+Lua OpenResty环境搭建
    Java-线程池面试题
    Rabbitmq死信队列
  • 原文地址:https://www.cnblogs.com/chloris/p/12072717.html
Copyright © 2011-2022 走看看