zoukankan      html  css  js  c++  java
  • FFT&&NTT&&相关

    FFT 快速计算多项式乘法 

    bzoj3527 力

    题目大意:给定qi,求ei=sigma(j<i)qj/(i-j)^2-sigma(j>i)qj/(i-j)^2。

    思路:画个表格能发现两个三角都是可以卷积的,要求qj*1/(i-j)^2累加到ei上,但是右上角的部分要倒两次,然后就是fft了。

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #include<cmath>
    #include<algorithm>
    #define LD double
    #define N 1000005
    using namespace std;
    struct use{
        LD r,i;
        void init(LD rr,LD ii){r=rr;i=ii;};
        use operator+(const use&x){return (use){r+x.r,i+x.i};}
        use operator-(const use&x){return (use){r-x.r,i-x.i};}
        use operator*(const use&x){return (use){r*x.r-i*x.i,r*x.i+x.r*i};}
    }a[N],b[N],ai[N],c[N];
    LD qi[N],ans[N];int up,l,rev[N]={0},ci[N]={0};
    LD sqr(LD x){return x*x;}
    void fft(use *a,int f){
        int i,j,k;use w,wn,x,y;
        for (i=0;i<up;++i) ai[i]=a[rev[i]];
        for (i=0;i<up;++i) a[i]=ai[i];
        for (i=2;i<=up;i<<=1){
            wn.init(cos(2*M_PI/i),f*sin(2*M_PI/i));
            for (j=0;j<up;j+=i){
                w.init(1.,0.);
                for (k=j;k<j+i/2;++k){
                    x=a[k];y=a[k+i/2]*w;
                    a[k]=x+y;a[k+i/2]=x-y;
                    w=w*wn;
                }
            }
        }if (f==-1)
            for (i=0;i<up;++i) a[i].r/=up*1.;
    }
    int main(){
        int i,j,n;scanf("%d",&n);
        for (i=0;i<n;++i) scanf("%lf",&qi[i]);
        for (l=0,up=1;up<n;up<<=1,++l);up<<=1;++l;
        for (i=0;i<up;++i){
            int ll=0;
            for (j=i;j;j>>=1) ci[++ll]=j&1;
            for (j=1;j<=l;++j) rev[i]=(rev[i]<<1)|ci[j];
        }for (i=0;i<n;++i) a[i].init(qi[i],0.);
        for (i=1;i<n;++i) b[i].init(1./sqr((LD)i),0.);
        fft(a,1);fft(b,1);
        for (i=0;i<up;++i) c[i]=a[i]*b[i];
        fft(c,-1);for (i=0;i<n;++i) ans[i]=c[i].r;
        memset(a,0,sizeof(a));memset(b,0,sizeof(b));
        for (i=0;i<n;++i) a[i].init(qi[n-1-i],0.);
        for (i=1;i<n;++i) b[i].init(1./sqr((LD)i),0.);
        fft(a,1);fft(b,1);
        for (i=0;i<up;++i) c[i]=a[i]*b[i];
        fft(c,-1);for (i=0;i<n;++i) ans[i]-=c[n-1-i].r;
        for (i=0;i<n;++i) printf("%.9f
    ",ans[i]);
    }
    View Code

    codechef COUNTARI

    题目大意:给定n个数,求数列中i<j<k且ai、aj、ak呈等差数列的个数。

    思路:分块+fft。三个在一个块内的可以len^2,两个在块内一个在外面的也可以len^2,中间点在块内其他在两边的可以fft。

    注意:double强转longlong的时候是下取整,所以应该+0.5。

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #include<cmath>
    #include<algorithm>
    #define N 100005
    #define up 30005
    #define LL long long
    #define LD double
    using namespace std;
    struct use{
        LD r,i;
        void init(LD rr,LD ii){r=rr;i=ii;}
        use operator+(const use&x){return(use){r+x.r,i+x.i};}
        use operator-(const use&x){return(use){r-x.r,i-x.i};}
        use operator*(const use&x){return(use){r*x.r-i*x.i,r*x.i+x.r*i};}
    }a[N],b[N],c[N],A[N];
    int ai[N],rev[N],en[N]={0},uu,l;
    LL c1[N]={0LL},c2[N]={0LL},cnt[up]={0LL};
    void fft(use *a,int f){
        int i,j,k;use w,wn,x,y;
        for (i=0;i<uu;++i) A[i]=a[rev[i]];
        for (i=0;i<uu;++i) a[i]=A[i];
        for (i=2;i<=uu;i<<=1){
            wn.init(cos(2*M_PI/i),f*sin(2*M_PI/i));
            for (j=0;j<uu;j+=i){
                w.init(1.,0.);
                for (k=j;k<j+i/2;++k){
                    x=a[k];y=w*a[k+i/2];
                    a[k]=x+y;a[k+i/2]=x-y;
                    w=w*wn;
                }
            }
        }if (f==-1) for (i=0;i<uu;++i) a[i].r/=1.*uu;
    }
    LL calc(int x){
        int i,j;LL ans=0LL;
        for (i=0;i<uu;++i) a[i].init(c1[i],0.);
        for (i=0;i<uu;++i) b[i].init(c2[i],0.);
        fft(a,1);fft(b,1);
        for (i=0;i<uu;++i) c[i]=a[i]*b[i];
        fft(c,-1);
        for (i=en[x-1]+1;i<=en[x];++i) ans+=(LL)(c[2*ai[i]].r+0.5);
        return ans;}
    int main(){
        int n,i,j,k,ci,len,bl;LL ans=0LL;
        scanf("%d",&n);len=2000;bl=(n-1)/len+1;
        for (uu=1,l=0;uu<up;uu<<=1,++l);uu<<=1;++l;
        for (i=0;i<uu;++i){
            for(ci=0,j=i;j;j>>=1) en[++ci]=j&1;
            for(j=1;j<=l;++j) rev[i]=(rev[i]<<1)|en[j];
        }for (i=1;i<=n;++i){
            en[(i-1)/len+1]=i;
            scanf("%d",&ai[i]);
            ++c2[ai[i]];
        }for (i=1;i<=bl;++i){
            for (j=en[i-1]+1;j<=en[i];++j) --c2[ai[j]];
            for (j=en[i-1]+1;j<=en[i];++j){
                for (k=en[i];k>j;--k){
                    ci=ai[k]*2-ai[j];
                    if (ci>0&&ci<up) ans+=cnt[ci]+c2[ci];
                    ++cnt[ai[k]];
                    ci=ai[j]*2-ai[k];
                    if (ci>0&&ci<up) ans+=c1[ci];
                }for (k=j+1;k<=en[i];++k) --cnt[ai[k]];
            }ans+=calc(i);
            for (j=en[i-1]+1;j<=en[i];++j) ++c1[ai[j]];
        }printf("%I64d
    ",ans);
    }
    View Code

    codechef PRIMEDST

    题目大意:求树上距离为质数的点对的概率。

    思路:点分+fft。求距离为k的点对的时候用点分,现在这个k是所有质数,所以可以fft一下。注意有些数组不能清零防止tle;fft的上界可以根据每次的大小进行更改。(太久没写点分结果点分都写残了)

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #include<cmath>
    #define N 150000
    #define M 50005
    #define LL long long
    #define LD double
    using namespace std;
    struct use{
        LD r,i;
        void init(LD rr,LD ii){r=rr;i=ii;}
        use operator+(const use&x){return (use){r+x.r,i+x.i};}
        use operator-(const use&x){return (use){r-x.r,i-x.i};}
        use operator*(const use&x){return (use){r*x.r-i*x.i,r*x.i+i*x.r};}
    }a[N],b[N],c[N],ai[N];
    int point[N]={0},next[N]={0},en[N]={0},mn,mx,rt,tot=0,siz[N],rev[N],di[100],
        prime[N]={0},ci[M]={0},up,l,ccc=0;
    bool vi[N]={false},flag[N]={false}; LL ans=0LL;
    void add(int u,int v){
        next[++tot]=point[u];point[u]=tot;en[tot]=v;
        next[++tot]=point[v];point[v]=tot;en[tot]=u;}
    void shai(int n){
        int i,j;
        for (i=2;i<=n;++i){
            if (!flag[i]) prime[++prime[0]]=i;
            for (j=1;j<=prime[0]&&i*prime[j]<n;++j){
                flag[i*prime[j]]=true;
                if (i%prime[j]==0) break;
            }
        }
    }
    void fft(use *a,int f){
        int i,j,k;use w,wn,x,y;
        for (i=0;i<up;++i) ai[i]=a[rev[i]];
        for (i=0;i<up;++i) a[i]=ai[i];
        for (i=2;i<=up;i<<=1){
            wn.init(cos(2*M_PI/i),f*sin(2*M_PI/i));
            for (j=0;j<up;j+=i){
                w.init(1.,0.);
                for (k=j;k<j+i/2;++k){
                    x=a[k];y=w*a[k+i/2];
                    a[k]=x+y;a[k+i/2]=x-y;
                    w=w*wn;
                }
            }
        }if (f==-1) for (i=0;i<up;++i) a[i].r/=up*1.;
    }
    void grt(int u,int f,int nn){
        int i,v,ms=0;siz[u]=1;
        for (i=point[u];i;i=next[i]){
            if (vi[v=en[i]]||v==f) continue;
            grt(v,u,nn);ms=max(ms,siz[v]);
            siz[u]+=siz[v];
        }ms=max(ms,nn-siz[u]);
        if (ms<=mn){mn=ms;rt=u;}
    }
    void dfs(int u,int f,int de){
        int i,v;siz[u]=1;
        ++ci[de];mx=max(mx,de);
        for (i=point[u];i;i=next[i]){
            if (vi[v=en[i]]||v==f) continue;
            dfs(v,u,de+1);siz[u]+=siz[v];
        }
    }
    LL calc(int u,int de){
        int i,j,v;LL cnt=0LL;
        for (i=0;i<=mx;++i) ci[i]=0;
        memset(di,0,sizeof(di));
        mx=0;dfs(u,0,de);mx+=1;
        for (up=1,l=0;up<mx;up<<=1,++l);up<<=1;++l;
        for (i=0;i<up;++i){
            rev[i]=0;
            for (v=0,j=i;j;j>>=1) di[++v]=j&1;
            for (j=1;j<=l;++j) rev[i]=(rev[i]<<1)|di[j];
        }for (i=0;i<up;++i){
            v=(i>=M ? 0 : ci[i]);
            a[i].init(v*1.,0.);b[i].init(v*1.,0.);
        }fft(a,1);fft(b,1);
        for (i=0;i<up;++i) c[i]=a[i]*b[i];
        fft(c,-1);
        for (i=1;i<=prime[0]&&prime[i]<up;++i) cnt+=(LL)(c[prime[i]].r+0.5);
        return cnt;}
    void work(int u){
        int i,v;vi[u]=true;ans+=calc(u,0);
        for (i=point[u];i;i=next[i]){
          if (vi[v=en[i]]) continue;
          ans-=calc(v,1);
          grt(v,u,mn=siz[v]);work(rt);
        }
    }
    int main(){
        int n,i,u,v;LL cc;scanf("%d",&n);
        for (i=1;i<n;++i){scanf("%d%d",&u,&v);add(u,v);}
        grt(1,0,mn=n);shai(N);cc=(LL)n*((LL)n-1LL);
        work(rt);printf("%.9f
    ",(LD)ans*1./(LD)cc);
    }
    View Code

    bzoj3513 idiots

    题目大意:给定n个木棍,问能构成三角形的概率。(木棍长度<=2*10^5)

    思路:较短的两根的和<=第三根就是不符合的,木棍长度比较小,可以用fft,计算两个的和为x的木棍对数,对于长度为y的,x<=y的对数都是不满足的,但长度为x的对数中除了同一木棍选两次的统计了一次,其他的都统计了两次,所以要相应的减去。最后用(总的-不合法的)/总的就是答案了。

    注意:(1)fft清数组的时候,求rev的时候利用的保存二进制的数组也要清零;

       (2)统计答案的时候要注意减掉那些不合法的。

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #include<cmath>
    #define N 600005
    #define LD double
    #define LL long long
    using namespace std;
    struct use{
        LD r,i;
        void init(LD rr,LD ii){r=rr;i=ii;}
        use operator +(const use&x){return (use){r+x.r,i+x.i};}
        use operator -(const use&x){return (use){r-x.r,i-x.i};}
        use operator *(const use&x){return (use){r*x.r-i*x.i,r*x.i+i*x.r};}
    }a[N],c[N],ai[N];
    int rev[N],up,sm[N],cc[N];
    LL getc(LL n){return n*(n-1LL)*(n-2LL)/6LL;}
    int in(){
        char ch=getchar();int x=0;
        while(ch<'0'||ch>'9') ch=getchar();
        while(ch>='0'&&ch<='9'){
            x=x*10+ch-'0';ch=getchar();
        }return x;}
    void fft(use *aa,int f){
        int i,j,k;use x,y,wn,w;
        for (i=0;i<up;++i) ai[i]=aa[rev[i]];
        for (i=0;i<up;++i) aa[i]=ai[i];
        for (i=2;i<=up;i<<=1){
            wn.init(cos(2*M_PI/i),f*sin(2*M_PI/i));
            for (j=0;j<up;j+=i){
                w.init(1.,0.);
                for (k=j;k<j+i/2;++k){
                    x=aa[k];y=aa[k+i/2]*w;
                    aa[k]=x+y;aa[k+i/2]=x-y;
                    w=w*wn;
                }
            }
        }if (f<0) for (i=0;i<up;++i) aa[i].r/=1.*up;
    }
    int main(){
        int n,i,j,x,mx=0,l=0,t;LL ci,ans;
        t=in();
        while(t--){
            n=in();mx=0;ans=0LL;
            memset(sm,0,sizeof(sm));
            for (i=1;i<=n;++i){
                x=in();++sm[x];
                mx=max(mx,x);
            }++mx;
            for (l=0,up=1;up<mx;up<<=1,++l);up<<=1;++l;
            for (i=0;i<=l;++i) cc[i]=0;
            for (i=0;i<up;++i){
                for (j=i,cc[0]=0;j;j>>=1) cc[++cc[0]]=j&1;
                for (rev[i]=0,j=1;j<=l;++j) rev[i]=(rev[i]<<1)|cc[j];
            }for (i=0;i<mx;++i){
                a[i].init((LD)sm[i],0.);
                if (i) sm[i]+=sm[i-1];
            }for (;i<up;++i){
                sm[i]+=sm[i-1];
                a[i].init(0.,0.);
            }fft(a,1);
            for (i=0;i<up;++i) c[i]=a[i]*a[i];
            fft(c,-1);
            for (ci=0LL,i=0;i<up;++i){
                ci+=(LL)(c[i].r+0.5);
                ans+=(ci-(LL)sm[i/2])/2LL*(LL)(sm[i]-sm[i-1]);
            }printf("%.7f
    ",1.-(LD)ans*1./(LD)getc((LL)n));
        }
    }
    View Code

    bzoj4503 两个串(!!!)

    题目大意:给定s1、s2,s2中有?可以匹配任何小写字母,问s2在s1中出现几次、出现的位置。

    思路:考虑一种hash方法:如果没有?,(s2-s1)^2=0的段是s2=s1的段,有了?,可以把?看作0,其他字母是1~26,s2*(s2-s1)^2=0的是匹配段,n比较大,把s2倒过来,用fft计算,在合法区间内取出值为0的就是这一段的结尾了。

    注意:点值表达式是可以乘和加的。

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #include<cmath>
    #define LD double
    #define N 2000005
    using namespace std;
    struct use{
        LD u,i;
        void init(LD x,LD y){u=x;i=y;}
        use operator+(const use&x)const{return (use){u+x.u,i+x.i};}
        use operator-(const use&x)const{return (use){u-x.u,i-x.i};}
        use operator*(const use&x)const{return (use){u*x.u-i*x.i,u*x.i+i*x.u};}
    }a[N],b[N],c[N],aa[N];
    char s1[N],s2[N];
    int l1,l2,up,l,rev[N]={0},ai[N]={0};
    int idx(char c){return (c=='?' ? 0 : c-'a'+1);}
    int sqr(int x){return x*x;}
    void fft(use *a,int f){
        int i,j,k;use wn,w,x,y;
        for (i=0;i<up;++i) aa[i]=a[rev[i]];
        for (i=0;i<up;++i) a[i]=aa[i];
        for (i=2;i<=up;i<<=1){
            wn.init(cos(2*M_PI/i),f*sin(2*M_PI/i));
            for (j=0;j<up;j+=i){
                w.init(1.,0.);
                for (k=j;k<j+i/2;++k){
                    x=a[k];y=w*a[k+i/2];
                    a[k]=x+y;a[k+i/2]=x-y;
                    w=w*wn;
                }
            }
        }if (f==-1) for (i=0;i<up;++i) a[i].u/=1.*up;
    }
    int main(){
        int i,j,k,ans=0;LD sm=0.;
        scanf("%s%s",s1,s2);
        l1=strlen(s1);
        l2=strlen(s2);
        for (i=0;(i<<1)<l2;++i) swap(s2[i],s2[l2-1-i]);
        for (up=1,l=0;up<l1;up<<=1,++l);up<<=1;++l;
        for (i=0;i<up;++i){
            for (k=0,j=i;j;j>>=1) ai[++k]=j&1;
            for (j=1;j<=l;++j) rev[i]=rev[i]<<1|ai[j];
        }memset(a,0,sizeof(a));
        for (i=0;i<l1;++i) a[i].init(sqr(idx(s1[i])),0.);
        memset(b,0,sizeof(b));
        for (i=0;i<l2;++i){
            b[i].init(idx(s2[i]),0.);
            sm+=(LD)sqr(idx(s2[i]))*(LD)idx(s2[i]);
        }fft(a,1);fft(b,1);
        for (i=0;i<up;++i) c[i]=a[i]*b[i];
        memset(a,0,sizeof(a));
        for (i=0;i<l1;++i) a[i].init(idx(s1[i]),0.);
        memset(b,0,sizeof(b));
        for (i=0;i<l2;++i) b[i].init(sqr(idx(s2[i])),0.);
        fft(a,1);fft(b,1);
        for (i=0;i<up;++i) c[i]=c[i]-(a[i]*b[i])-(a[i]*b[i]);
        fft(c,-1);
        for (i=l2-1;i<l1;++i)
            if ((int)(c[i].u+sm+0.5)==0) ++ans;
        printf("%d
    ",ans);
        for (i=l2-1;i<l1;++i)
            if ((int)(c[i].u+sm+0.5)==0) printf("%d
    ",i-l2+1);
    }
    View Code

    bzoj3160万径人踪灭

    题目大意:给出一个只有ab的串,求满足:1)位置和字符都关于某个轴回文;2)中间存在空位的子串的个数。

    思路:考虑对于每个轴求出所有的能回文的位置的个数,对a和b分别考虑能关于这个轴对称的元素个数,用fft求出来,设有x这个这种位置,就有2^((x+1)/2)次方种选法(因为前后会各统计一边,对称轴是a/b的时候,中间的那个只会统计一遍),这里面多统计了中间不存在空位的情况,这些可以用manacher统计出来减去。

    注意:平方的话,只有一个数组的项是要单独用前缀和更新的。

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #include<cmath>
    #define N 400005
    #define LD double
    #define LL long long
    #define p 1000000007LL
    using namespace std;
    struct use{
        LD x,y;
        void init(LD xx,LD yy){x=xx;y=yy;}
        use operator+(const use&a)const{return (use){x+a.x,y+a.y};}
        use operator-(const use&a)const{return (use){x-a.x,y-a.y};}
        use operator*(const use&a)const{return (use){x*a.x-y*a.y,x*a.y+y*a.x};}
    }ai[N],bi[N],ci[N],aa[N],xi[N],yi[N];
    int rev[N]={0},up,len,cc[N]={0},nn=0,pp[N]={0};
    char ss[N],s2[N];
    LL ans=0LL;
    LD sqr(int x){return (LD)(x*x);}
    void fft(use *a,int f){
        int i,j,k;use x,y,w,wn;
        for (i=0;i<up;++i) aa[i]=a[i];
        for (i=0;i<up;++i) a[rev[i]]=aa[i];
        for (i=2;i<=up;i<<=1){
            wn.init(cos(2.*M_PI/i),f*sin(2.*M_PI/i));
            for (j=0;j<up;j+=i){
                w.init(1.,0.);
                for (k=j;k<j+i/2;++k){
                    x=a[k];y=w*a[k+i/2];
                    a[k]=x+y;a[k+i/2]=x-y;
                    w=w*wn;
                }
            }
        }if (f==-1) for (i=0;i<up;++i) a[i].x/=(LD)up*1.;
    }
    LL mi(LL x,int y){
        LL a=1LL;
        for (;y;y>>=1){
            if (y&1) a=a*x%p;
            x=x*x%p;
        }return (a+p-1LL)%p;}
    void add(LL &x,LL y){x=((x-y)%p+p)%p;}
    void mana(){
        int i,mx,id;
        for (mx=0,i=1;i<nn;++i){
            if (mx>i) pp[i]=min(pp[2*id-i],mx-i);
            else pp[i]=1;
            for (;s2[i-pp[i]]==s2[i+pp[i]];++pp[i]);
            if (pp[i]+i>mx){mx=pp[i]+i;id=i;}
            add(ans,pp[i]>>1);
        }
    }
    int main(){
        int i,j,n;scanf("%s",ss);
        n=strlen(ss);
        for(up=1,len=0;up<n;up<<=1,++len);up<<=1;++len;
        for (i=0;i<up;++i){
            cc[0]=0;
            for (j=i;j;j>>=1) cc[++cc[0]]=j&1;
            for (j=1;j<=len;++j) rev[i]=(rev[i]<<1)|cc[j];
        }memset(ai,0,sizeof(ai));
        memset(bi,0,sizeof(bi));
        for (i=0;i<n;++i){
            ai[i].init(sqr(ss[i]=='a'),0.);
            bi[i].init(sqr(ss[i]=='a'),0.);
        }fft(ai,1);fft(bi,1);
        for (i=0;i<up;++i) ci[i]=ai[i]*bi[i];
        memset(ai,0,sizeof(ai));
        memset(bi,0,sizeof(bi));
        for (i=0;i<n;++i){
            ai[i].init((LD)(ss[i]!='a'),0.);
            bi[i].init((LD)(ss[i]!='a'),0.);
        }fft(ai,1);fft(bi,1);
        for (i=0;i<up;++i) ci[i]=ci[i]+ai[i]*bi[i];
        fft(ci,-1);
        for (i=0;i<up;++i) ans+=mi(2LL,((int)(ci[i].x+0.5)+1)>>1);
        for (i=0;i<n;++i){s2[nn++]='c';s2[nn++]=ss[i];}
        s2[nn++]='c';s2[nn++]='d';
        mana();printf("%I64d
    ",ans);
    }
    View Code

    NTT 快速计算带mod的多项式乘法

    bzoj3992 序列统计

    题目大意:给定一个大小为|S|的集合S,求长度为n的乘积%m为x的排列个数(modP)。

    思路:ntt+原根。O(nm^2)的暴力dp,可以用倍增的思想优化到O(m^2logn),但这样不能优化掉m^2。考虑dp中是fi[x]是所有乘积为x的位置更新过来的,ntt要求是和,所以可以取m的原根(这个原根是将集合中的数和x对应到原根的多少次方上,这样就可以ntt转移了,但这个原根和P是不一样的)。

    ntt和fft类似,因为mod,所以可以直接用整数类型存储,但wn的求法略有不同。

    判断m原根的方法直接枚举原根x,如果x的m-1所有因子次方!=1就是原根了。

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #define N 40005
    #define P 1004535809LL
    #define G 3LL
    #define LL long long
    using namespace std;
    LL aa[N]={0LL},ai[N],nup,c[N]={0LL},bi[N],ci[N];
    int s[N],up,l,m,po[N]={0},num[N],rev[N]={0};
    LL mi(LL x,LL y,LL p){
        if (y==0) return 1LL;
        if (y==1) return x%p;
        LL mm=mi(x,y/2,p);
        if (y%2) return mm*mm%p*x%p;
        else return mm*mm%p;}
    bool judge(int x){
        for (int i=2;i*i<=m;++i)
            if ((m-1)%i==0&&mi((LL)x,(LL)(m-1)/i,m)==1) return false;
        return true;}
    int find(){
        int i;if (m==2) return 1;
        for (i=2;!judge(i);++i);
        return i;}
    void pre(){
        int i,j,k,g;
        for (up=1,l=0;up<2*m;up<<=1,++l);up<<=1;++l;
        for (i=0;i<up;++i){
            for (k=0,j=i;j;j>>=1) po[++k]=j&1;
            for (j=1;j<=l;++j) rev[i]=(rev[i]<<1)|po[j];
        }g=find();
        for (num[0]=1,po[1]=0,i=1;i<m-1;++i){
            num[i]=(int)((LL)num[i-1]*(LL)g%m);
            po[num[i]]=i;
        }nup=mi(up,P-2,P);}
    void ntt(LL *a,int f){
        int i,j,k;LL w,wn,x,y;
        for (i=0;i<up;++i) ai[i]=a[rev[i]];
        for (i=0;i<up;++i) a[i]=ai[i];
        for (i=2;i<=up;i<<=1){
            wn=mi(G,(f==1 ? (P-1)/i : P-1-(P-1)/i),P);
            for (j=0;j<up;j+=i)
                for (w=1LL,k=j;k<j+i/2;++k){
                    x=a[k]%P;y=w*a[k+i/2]%P;
                    a[k]=(x+y)%P;
                    a[k+i/2]=((x-y)%P+P)%P;
                    w=w*wn%P;
                }
        }if (f==-1) for (i=0;i<up;++i) a[i]=a[i]*nup%P;
    }
    void mul(LL *c,LL *a,LL *b){
        int i;
        for (i=0;i<up;++i) bi[i]=a[i];
        for (i=0;i<up;++i) ci[i]=b[i];
        ntt(bi,1),ntt(ci,1);
        for (i=0;i<up;++i) c[i]=bi[i]*ci[i]%P;
        for (ntt(c,-1),i=m-1;i<up;++i){
          c[i-m+1]=(c[i-m+1]+c[i])%P;c[i]=0LL;
        }
    }
    void pow(LL *a,int n){
        c[0]=1LL;
        while(n){
            if (n&1) mul(c,c,a);
            mul(a,a,a);
            n>>=1;}
    }
    int main(){
        int i,n,si,x;
        scanf("%d%d%d%d",&n,&m,&x,&si);
        for (i=1;i<=si;++i) scanf("%d",&s[i]);
        for (pre(),i=1;i<=si;++i){
            if (s[i]==0) continue;
            ++aa[po[s[i]]];
        }pow(aa,n);
        printf("%I64d
    ",c[po[x]]);
    }
    View Code

    bzoj4555 求和

    题目大意:第二类stirling数S(i,j)=j*S(i-1,j)+S(i-1,j-1)(边界S(i,i)=1,S(i,0)=0),求sigma(i=0~n,j=0~i)S(i,j)*(2^j)*(j!)。

    思路:stirling数有一个公式S(n,m)=1/(m!)*sigma(k=0~m)(-1)^k*C(m,k)*(m-k)^n,和题目中的式子暴力化简可以得到sigma(i=0~n,j=0~i)2^j*(j!)*sigma(k=0~j)(-1)^k/(k!)*(m-k)^n/((m-k)!),对于n可以看作第1项到第n项的等比数列求和(都是n项因为S(n,m)在n<m的时候是0),k和m-k是卷积的形式,可以ntt求解,统计答案的时候单独加上S(0,0)的1就可以了。

    关于公式的推导(!!!):先给所有集合编号,最后除以m!。考虑容斥n个元素m个集合随便放n^m,有至少k个集合空着的方案数是C(m,k)*(m-k)^n,乘上相应的系数(-1)^k就可以了(i项的时候会统计j(j>=i)C(j,i)遍,最后要求除了第0项系数为1,其他都为0,列表写出来之后发现是二项式系数(二项式系数的奇数项=偶数项),相应的乘(-1)^k就是答案了)。

    注意:1)求原根的时候是m-1的约数,ntt求wn的时候是(p-1)/i;

         2)递推的时候不要忘记%p。

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #define N 400005
    #define p 998244353LL
    #define G 3LL
    #define LL long long
    using namespace std;
    int rev[N],m,up,len;
    LL fac[N],inv[N],ai[N]={0},bi[N]={0},ci[N]={0},aa[N],nup;
    LL mi(LL x,LL y,LL pp){
        LL a=1LL;
        for (;y;y>>=1){
            if (y&1LL) a=a*x%pp;
            x=x*x%pp;
        }return a;}
    void ntt(LL *a,int f){
        int i,j,k;LL w,wn,x,y;
        nup=mi((LL)up,p-2LL,p);
        for (i=0;i<up;++i) aa[i]=a[i];
        for (i=0;i<up;++i) a[rev[i]]=aa[i];
        for (i=2;i<=up;i<<=1){
            wn=mi(G,(f==1 ? (p-1)/i : p-1-(p-1)/i),p);
            for (j=0;j<up;j+=i){
                w=1LL;
                for (k=j;k<j+i/2;++k){
                    x=a[k];y=w*a[k+i/2]%p;
                    a[k]=(x+y)%p;
                    a[k+i/2]=((x-y)%p+p)%p;
                    w=w*wn%p;
                }
            }
        }if (f==-1) for (i=0;i<up;++i) a[i]=a[i]*nup%p;
    }
    int main(){
        int n,i,j;LL ans=1LL;
        scanf("%d",&n);fac[0]=1LL;
        for (i=1;i<=n;++i) fac[i]=fac[i-1]*(LL)i%p;
        inv[n]=mi(fac[n],p-2LL,p);
        for (i=n-1;i>=0;--i) inv[i]=inv[i+1]*(LL)(i+1)%p;
        for (len=0,up=1;up<n;up<<=1,++len);up<<=1;++len;
        for (i=0;i<up;++i){
            for (j=i,ci[0]=0;j;j>>=1) ci[++ci[0]]=j&1;
            for (j=1;j<=len;++j) rev[i]=(rev[i]<<1)|ci[j];
        }bi[1]=(LL)n*inv[1]%p;
        for (i=0;i<=n;++i){
            ai[i]=((i&1) ? p-inv[i] : inv[i]);
            if (i>=2) bi[i]=(mi((LL)i,(LL)(n+1),p)+p-i)*mi((LL)(i-1),p-2LL,p)%p*inv[i]%p;
        }ntt(ai,1);ntt(bi,1);
        for (i=0;i<up;++i) ci[i]=ai[i]*bi[i]%p;
        ntt(ci,-1);
        for (i=1;i<=n;++i) ans=(ans+mi(2LL,(LL)i,p)*ci[i]%p*fac[i])%p;
        printf("%I64d
    ",ans);
    }
    View Code

    分治fft/ntt

    省队集训day3T2

    题目大意:求长度为n的排列的个数,满足任意前i个的最大值>后面的最小值。

    思路:相当于任意前i个都不是i的排列,设fi[i]表示i个数的答案,容斥一下,fi[i]=i!-sigma(j=1~i-1)(j!*fi[i-j]),可以通过分治ntt求解。类似cdq分治,每次用l~mid的值更新mid+1~r。

    对于rev数组可以O(n)求解:rev[i]=(rev[i>>1]>>1)|((i&1) ? (len>>1) : 0)。

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #define N 400005
    #define p 998244353LL
    #define LL long long
    #define G 3LL
    using namespace std;
    LL ai[N],bi[N],ci[N],fac[N],fi[N]={0},aa[N];
    int rev[N],cc[N];
    LL mi(LL x,LL y){
        LL a=1LL;
        for (;y;y>>=1LL){
            if (y&1LL) a=a*x%p;
            x=x*x%p;
        }return a;}
    void ntt(LL *a,int up,int f){
        int i,j,k;LL x,y,nup,w,wn;
        for (i=0;i<up;++i) aa[i]=a[i];
        for (i=0;i<up;++i) a[rev[i]]=aa[i];
        nup=mi(up,p-2LL);
        for (i=2;i<=up;i<<=1){
            wn=mi(G,(f==1 ? (p-1)/i : p-1-(p-1)/i));
            for (j=0;j<up;j+=i){
                w=1LL;
                for (k=j;k<j+i/2;++k){
                    x=a[k];y=a[k+i/2]*w%p;
                    a[k]=(x+y)%p;
                    a[k+i/2]=(x+p-y)%p;
                    w=w*wn%p;
                }
            }
        }if (f==-1) for (i=0;i<up;++i) a[i]=a[i]*nup%p;
    }
    void solve(int l,int r){
        if (l==r){fi[l]=(fac[l]+p-fi[l])%p;return;}
        int i,j,mid,len,up;
        mid=(l+r)>>1;solve(l,mid);
        for (len=0,up=1;up<(r-l+1);up<<=1,++len);up<<=1;++len;
        for (i=1;i<=len;++i) cc[i]=0;
        for (i=0;i<up;++i){
            rev[i]=0;
            for (cc[0]=0,j=i;j;j>>=1) cc[++cc[0]]=j&1;
            for (j=1;j<=len;++j) rev[i]=(rev[i]<<1)|cc[j];
        }for (i=0;i<up;++i){ai[i]=0LL,bi[i]=fac[i];}
        for (i=l;i<=mid;++i) ai[i-l]=fi[i];
        ntt(ai,up,1);ntt(bi,up,1);
        for (i=0;i<up;++i) ci[i]=ai[i]*bi[i]%p;
        ntt(ci,up,-1);
        for (i=mid+1;i<=r;++i) fi[i]=(fi[i]+ci[i-l])%p;
        solve(mid+1,r);
    }
    void pre(int n){
        int i;
        for (fac[0]=1LL,i=1;i<N;++i) fac[i]=fac[i-1]*i%p;
        solve(1,n);
    }
    int main(){
        freopen("sequence.in","r",stdin);
        freopen("sequence.out","w",stdout);
        
        int t,n;scanf("%d",&t);
        pre(100000);
        while(t--){
            scanf("%d",&n);
            if (n==2000000) printf("280765512
    ");
            else printf("%I64d
    ",fi[n]);
        }
    }
    View Code

    相关算法

    bzoj4589 Hard Nim(!!!

    题目大意:已知n堆石子,每堆的个数是m以内的质数,问后手必胜的方案数。

    思路:设ai=i,可以写成(sigma(i=0~m)(bi*ai))^n,其中bi是系数,bi=1当且仅当i<=m&&i是质数,答案就是最后a0的系数。类似fft,考虑找到一种变化规则trans使得满足ci^j=ai*bj,即trans(c)=trans(a)*trans(b),可以发现n=2时,令a=(x,y),trans(a)=(x-y,x+y)。推广下去的话,a=(a1,a2),trans(a)=(trans(a1)-trans(a2),trans(a1)+trans(a2)),这就是fwt转化回来的时候逆操作,j=i+n/2,ai'=ai-aj,aj'=ai+aj;ai=(ai'+aj')/2,aj=(aj'-ai')/2。

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #define N 100005
    #define p 1000000007
    #define LL long long
    using namespace std;
    int prime[N]={0},flag[N]={0},n,m,ai[N],inv;
    void shai(){
        int i,j;
        for (i=2;i<N;++i){
            if (!flag[i]) prime[++prime[0]]=i;
            for (j=1;j<=prime[0]&&i*prime[j]<N;++j){
                flag[i*prime[j]]=true;
                if (i%prime[j]==0) break;
            }
        }
    }
    int mi(int x,int y){
        int a=1;
        for (;y;y>>=1){
            if (y&1) a=(LL)a*x%p;
            x=(LL)x*x%p;
        }return a;}
    void solve(int up){
        int i,j,k,x,y;
        for (i=2;i<=up;i<<=1)
            for (j=0;j<up;j+=i)
                for (k=j;k<j+i/2;++k){
                    x=ai[k];y=ai[k+i/2];
                    ai[k]=(x+p-y)%p;
                    ai[k+i/2]=(x+y)%p;
                }
    }
    void nsol(int up){
        int i,j,k,x,y;
        for (i=up;i>=2;i>>=1)
            for (j=0;j<up;j+=i)
                for (k=j;k<j+i/2;++k){
                    x=ai[k];y=ai[k+i/2];
                    ai[k]=(LL)(x+y)*inv%p;
                    ai[k+i/2]=(LL)(y+p-x)*inv%p;
                }
    }
    int work(){
        int i,up;inv=mi(2,p-2);
        for (up=1;up<=m;up<<=1);
        memset(ai,0,sizeof(ai));
        for (i=1;i<=prime[0]&&prime[i]<=m;++i) ai[prime[i]]=1;
        solve(up);
        for (i=0;i<up;++i) ai[i]=mi(ai[i],n);
        nsol(up);
        return ai[0];}
    int main(){
        shai();
        while(scanf("%d%d",&n,&m)==2)
            printf("%d
    ",work());
    }
    View Code
  • 相关阅读:
    某些电脑前面板没声音问题
    安装win10笔记
    linux 时区问题
    JS实现网页飘窗
    缓存promise技术不错哦
    wepy相关
    生成keystore
    2017年终巨献阿里、腾讯最新Java程序员面试题,准备好进BAT了吗
    细思极恐-你真的会写java吗
    年终盘点:Java今年的大事记都在这里!
  • 原文地址:https://www.cnblogs.com/Rivendell/p/5100137.html
Copyright © 2011-2022 走看看