zoukankan      html  css  js  c++  java
  • fft,ntt

    在被两题卡了常数之后,花了很久优化了自己的模板

    现在的一般来说任意模数求逆1s跑3e5,exp跑1e5是没啥问题的(自己电脑,可能比luogu慢一倍)

    当模数是$998244353,1004535809,9985661441$的时候(这$3$个的原根都是$3$)

    我们会使用$ntt$来求解

    $ntt$的模板本身常数不大 优化效果不明显

    const int mo=998244353;
    const int G=3;
    IL int fsp(int x,int y)
    {
        ll now=1;
        while (y)
        {
            if (y&1) now=now*x%mo;
            x=1ll*x*x%mo;
            y>>=1;
        }
        return now;
    }
    IL void ntt_init()
    {
        l=0; for (n=1;n<=m;n<<=1) l++;
        for (int i=0;i<n;i++) r[i]=(r[i/2]/2)|((i&1)<<(l-1));
    }
    IL void clear()
    {
        for (int i=0;i<=n;i++) a[i]=b[i]=0;
    }
    void ntt(int *a,int o)
    {
        for (int i=0;i<n;i++) if (i>r[i]) swap(a[i],a[r[i]]);
        for (int i=1;i<n;i<<=1)
        {
            int wn=fsp(G,(mo-1)/(i*2)); w[0]=1;
            rep(j,1,i-1) w[j]=(1ll*w[j-1]*wn)%mo;
            for (int j=0;j<n;j+=(i*2))
              for (int k=0;k<i;k++)
              {
                  int x=a[j+k],y=1ll*a[i+j+k]*w[k]%mo;
                  a[j+k]=(x+y)%mo; a[i+j+k]=(x-y)%mo;
              }
        }
        if (o==-1)
        {
          reverse(&a[1],&a[n]);
          for (int i=0,inv=fsp(n,mo-2);i<n;i++)
            a[i]=1ll*a[i]*inv%mo;
        }
    }
    IL void getcj(int *A,int *B,int len)
    {
        m=len*2; ntt_init();
        for (int i=0;i<len;i++) a[i]=A[i],b[i]=B[i];
        ntt(a,1); ntt(b,1);
        for(int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo;
        ntt(a,-1);
        for (int i=0;i<len;i++) B[i]=a[i];
        clear();
    }
     

    当模数不为这$3$个,我们就需要$mtt$来实现

    而$mtt$的实现为用$mx$的方法将数的实部虚部分别放$x & 65536,x(>>15)$

    另外一个重要的地方是要预处理出$w$,我们采用指针来存,避免使用vector

    代码$p$的初始值为$2*n$

    所有数组大小为$4*n$

    $getcj$的时候要先把数组中的负数变正

    IL void clear()
    {
        for (int i=0;i<=n;i++) a[i].a=a[i].b=b[i].a=b[i].b=c[i].a=c[i].b=d[i].a=d[i].b=0;
    }
    cp *w[N],tmp[N*2];
    int p;
    IL void init()
    {
        cp *now=tmp;
        for (int i=1;i<=p;i<<=1)
        {
            w[i]=now;
            for (int j=0;j<i;j++) w[i][j]=(cp){cos(pi*j/i),sin(pi*j/i)};
            now+=i;
        }
    }
    IL void fft_init()
    {
        l=0; for (n=1;n<=m;n<<=1) l++;
        for (int i=0;i<n;i++) r[i]=(r[i/2]/2)|((i&1)<<(l-1));
    }
    void fft(cp *a,int o)
    {
        for (int i=0;i<n;i++) if (i>r[i]) swap(a[i],a[r[i]]);
        for (int i=1;i<n;i<<=1)
            for (int j=0;j<n;j+=(i*2))
            {
              cp *x1=a+j,*x2=a+i+j,*W=w[i];
              for (int k=0;k<i;k++,x1++,x2++,W++)
              {
                  cp x=*x1,y=(cp){(*W).a,(*W).b*o}*(*x2); 
                  *x1=x+y,*x2=x-y;
              }
            }
        if (o==-1) for(int i=0;i<n;i++) a[i].a/=n;
    }
    IL void getcj(int *A,int *B,int len)
    {
        rep(i,0,len)
        {
            A[i]=(A[i]+mo)%mo,B[i]=(B[i]+mo)%mo;
        }
        for (int i=0;i<len;i++)
        {
           a[i]=(cp){A[i]&32767,A[i]>>15};
           b[i]=(cp){B[i]&32767,B[i]>>15};
        }
        m=len*2; fft_init();
        fft(a,1); fft(b,1);
        for (int i=0;i<n;i++)
        {
            int j=(n-1)&(n-i);
            c[j]=(cp){0.5*(a[i].a+a[j].a),0.5*(a[i].b-a[j].b)}*b[i];
            d[j]=(cp){0.5*(a[i].b+a[j].b),0.5*(a[j].a-a[i].a)}*b[i];
        }
        fft(c,1); fft(d,1);
        double inv=ee/n;
        rep(i,0,n) c[i].a*=inv,c[i].b*=inv;
        rep(i,0,n) d[i].a*=inv,d[i].b*=inv;
        rep(i,0,len)
        {
            ll a1=c[i].a+0.5,a2=c[i].b+0.5;
            ll a3=d[i].a+0.5,a4=d[i].b+0.5;
            B[i]=(a1+((a2+a3)%mo<<15)+((a4%mo)<<30))%mo;
        }
        clear();
    }

    对于其他的多项式函数

    用$fft$还是$ntt$是差不多的(除了数组类型)

  • 相关阅读:
    【前端】原生event对象和jquery event对象的区别
    【前端】js代码模拟用户键盘鼠标输入
    【前端】回到顶部
    【前端】Three.js
    【前端】三种复制数组的方法
    【Python】Django
    【前端】CommonJS的模块加载机制
    注释声明:TODO HACK XXX FIXME REVIEW
    【Python】Python3中的str和bytes
    【前端】iterable类型的 forEach方法
  • 原文地址:https://www.cnblogs.com/yinwuxiao/p/9417115.html
Copyright © 2011-2022 走看看