zoukankan      html  css  js  c++  java
  • FFT,NTT

    注意:这是一篇个人学习笔记,如果有人因为某些原因点了进来并且要看一下,请一定谨慎地阅读,因为可能存在各种奇怪的错误,如果有人发现错误请指出谢谢!


    20180303

    https://www.cnblogs.com/rvalue/p/7351400.html (资料1)

    http://blog.csdn.net/ACdreamers/article/details/39005227

    http://blog.csdn.net/leo_h1104/article/details/51615710

    http://blog.csdn.net/tt2767/article/details/47301849

    http://blog.csdn.net/Tag_king/article/details/46351821


    20190118

    复数

    形如$a+bi$,其中$i^2=-1$。有几何意义:复平面上一个向量(a,b)

    乘法:$(a + bi) (c + di) = ac + bci + adi + bd i^2 = (ac - bd) + (bc + ad)i$

    乘法几何意义:幅角(对应向量与x轴正方向有向夹角,设x轴是实轴)相加, 模(对应向量长度)相乘

    欧拉公式(暂时不研究原因,证明参考欧拉公式):$e^{ix} = cos x+isin x$

    注意:复数作为底数时不一定满足指数运算律!应当把复数用欧拉公式表示后运算

    (复数作为$(xy)^z = x^zy^z$中的x,y, $(x^y)^z = x^{yz}$中的x, $x^yx^z = x^{y+z}$中的x时,这些式子不一定成立)

    (仅非0复数的整数幂函数满足指数运算律,当指数为分数时结果是”多值“的,参考http://blog.sciencenet.cn/blog-826653-900633.html

    单位复根

    n次单位复根指满足$omega ^n=1$的复数$omega$。

    可以得到,它们共有n个,可以表示为$e^{ix*2pi /n},x=0,1,2,..,(n-1)$(具体参考资料1)

    设$omega _n=e^{i*2pi/n}$,可以得到$x=k$时的那个n次单位复根为$(omega_n)^k$(可以写成$omega_n^k$)

    引理

    1.引理(消去引理)

    对任意整数$n geq 0$, $kgeq 0$, 以及$dgeq 0$, 有$omega_{dn}^{dk}=omega_n^k$

    证明:由定义易得

    2.引理(??)

    $omega _n^m=-omega _n^{m+frac{n}{2}}$

    证明:$omega _n^{m+frac{n}{2}}=e^{2pi i/n*(m+frac{n}{2})}=-e^{2pi im/n}=-omega _n^m$

    3.引理(求和引理)

    对于任意整数n>=1与不能被n整除的非负整数k,有$sum_{j=0}^{n-1}(omega_n^k)^j=0$

    证明:等比数列求和

    (以下均假设n已经被补成2的幂)

    DFT

    就是要求出多项式的点值表示法(对于x=$omega _n^k$,k=0,1,2,..,n-1求出多项式的值,具体定义参见资料)

    用分治实现。对于多项式$A(x) = sum _{i=0} ^{n-1} a_i x^i$,拆分成$A_0(x)=a_0+a_2x+a_4x^2+...+a_{n-2}x^{n/2-1}$和$A_1(x)=a_1+a_3x+a_5x^2+...+a_{n-1}x^{n/2-1}$。

    那么$A(x)=A_0(x^2)+A_1(x^2)x$

    当k<n/2时,$A(omega _n^k)=A_0(omega _{n/2}^k)+A_1(omega _{n/2}^k)omega _n^k$,$A(omega _n^{k+n/2})=A(-omega _n^k)=A_0(omega _{n/2}^k)-A_1(omega _{n/2}^k)omega _n^k$

    那么,只要得到$A_0(x)$和$A_1(x)$的点值表示法,就可以推出$A(x)$的点值表示法,就可以分治了

    资料:https://www.cnblogs.com/RabbitHu/p/FFT.html

    IDFT

    从系数表示法回到点值表示法。

    直接将A(x)的点值表示法的点值当成DFT时候的系数,DFT里面的x改成x=$omega _n^{-k}$,k=0,1,2,..,n-1,最后结果各个项再除以n,得到的就是A(x)

    原因:结果中的第k项=$sum _{i=0}^{n-1}(sum _{j=0}^{n-1}a_j omega_n^{ij})omega_n^{-ik}=sum _{j=0}^{n-1}a_jsum_{i=0}^{n-1}(omega_n^{j-k})^i$$=sum_{j=0}^{n-1}a_j*[n|j-k]*n=a_k*n$

    FFT实现

    具体实现有一些方法,例如非递归fft,参见资料

    模板

    loj108  洛谷P3803  uoj34

    提醒:直接%.0f输出非常慢;在这里可以加0.5再取整后输出;多项式卷积的长度可能达到两个多项式长度之和级别

    来自毛啸2016论文:

     1 #include<cstdio>
     2 #include<algorithm>
     3 #include<cstring>
     4 #include<vector>
     5 #include<cmath>
     6 using namespace std;
     7 #define fi first
     8 #define se second
     9 #define mp make_pair
    10 #define pb push_back
    11 typedef long long ll;
    12 typedef unsigned long long ull;
    13 typedef pair<int,int> pii;
    14 struct cpl
    15 {
    16     double x,y;
    17     cpl(double x=0,double y=0):x(x),y(y){}
    18 };
    19 cpl operator+(const cpl &a,const cpl &b)
    20 {
    21     return (cpl){a.x+b.x,a.y+b.y};
    22 }
    23 cpl operator-(const cpl &a,const cpl &b)
    24 {
    25     return (cpl){a.x-b.x,a.y-b.y};
    26 }
    27 cpl operator*(const cpl &a,const cpl &b)
    28 {
    29     return (cpl){a.x*b.x-a.y*b.y,a.y*b.x+a.x*b.y};
    30 }
    31 const double pi=acos(-1);
    32 const int N=2097152;
    33 int rev[N];
    34 void init(int len)
    35 {
    36     int bit=0,i;
    37     while((1<<(bit+1))<=len)    ++bit;
    38     for(i=0;i<len;++i)
    39         rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
    40 }
    41 void dft(cpl *a,int len,int idx)//要求len为2的幂
    42 {
    43     int i,j,k;cpl t1,t2,wn,wnk;
    44     for(i=0;i<len;++i)
    45         if(i<rev[i])
    46             swap(a[i],a[rev[i]]);
    47     for(i=1;i<len;i<<=1)
    48     {
    49         wn=cpl(cos(pi/i),idx*sin(pi/i));
    50         //合并[j,j+i)和[j+i,j+2i),注意wn的n是2i而不是i
    51         for(j=0;j<len;j+=(i<<1))
    52         {
    53             wnk=cpl(1,0);
    54             for(k=j;k<j+i;++k,wnk=wnk*wn)
    55             {
    56                 t1=a[k];t2=a[k+i]*wnk;
    57                 a[k]=t1+t2;a[k+i]=t1-t2;
    58             }
    59         }
    60     }
    61     if(idx==-1)
    62     {
    63         for(i=0;i<len;++i)
    64             a[i].x/=len,a[i].y/=len;
    65     }
    66 }
    67 cpl a[N],b[N];
    68 int n,m;
    69 int main()
    70 {
    71     int i;
    72     init(N);
    73     scanf("%d%d",&n,&m);
    74     for(i=0;i<=n;++i)
    75         scanf("%lf",&a[i].x);
    76     for(i=0;i<=m;++i)
    77         scanf("%lf",&b[i].x);
    78     dft(a,N,1);dft(b,N,1);
    79     for(i=0;i<N;++i)
    80         a[i]=a[i]*b[i];
    81     dft(a,N,-1);
    82     for(i=0;i<=n+m;++i)
    83         printf("%d ",int(a[i].x+0.5));
    84     return 0;
    85 }
    View Code

    NTT

    原根:有数a,p,如果a是p的一个原根,那么$a^0\,mod\,p,a^1\,mod\,p,..,a^{p-2}\,mod\,p$刚好是1,2,3,..,p-1各出现一次(这里不需要更深入了解原因)

    现在有质数$p=r*2^k+1$以及它的一个原根a

    (以下默认n是2的幂)(NTT部分的运算全部默认为模p意义下运算)

    那么考虑令$omega_n=a^{frac{p-1}{n}}$,当然,可以发现这里要求$n<=2^k$

    可以发现以上引理1,3显然仍然成立

    对于引理2,根据一些奇奇怪怪的理论可以知道$a^{frac{p-1}{2}}=-1$

    奇奇怪怪的理论(不懂):

    https://math.stackexchange.com/questions/353741/how-to-derive-this-expression-r-p-1-2-equiv-1-pmod-p-for-primitive-r/353747

    https://blog.csdn.net/feynman1999/article/details/82117243

    那么$omega_n^{n/2}=a^{frac{p-1}{2}}=-1$,显然引理2也仍然成立

    其他跟FFT基本是一样的...无非就是除法变成乘逆元

    NTT质数表:转自http://blog.miskcoo.com/2014/07/fft-prime-table

    $r*2^k+1$ r k 原根g
    469762049 7 26 3
    998244353 119 23 3
    1004535809 479 21 3

    再来个模板,仍然是上面那道题

    普通版本

    卡常记录:洛谷上测的时候,49和50行换一下会慢接近一倍

     1 #prag
     2 ma GCC optimize(2)
     3 #include<cstdio>
     4 #include<algorithm>
     5 #include<cstring>
     6 #include<vector>
     7 #include<cmath>
     8 using namespace std;
     9 #define fi first
    10 #define se second
    11 #define mp make_pair
    12 #define pb push_back
    13 typedef long long ll;
    14 typedef unsigned long long ull;
    15 const int md=998244353;
    16 const int N=2097152;
    17 int rev[N];
    18 void init(int len)
    19 {
    20     int bit=0,i;
    21     while((1<<(bit+1))<=len)    ++bit;
    22     for(i=0;i<len;++i)
    23         rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
    24 }
    25 ll poww(ll a,ll b)
    26 {
    27     ll base=a,ans=1;
    28     for(;b;b>>=1,base=base*base%md)
    29         if(b&1)
    30             ans=ans*base%md;
    31     return ans;
    32 }
    33 void dft(int *a,int len,int idx)//要求len为2的幂
    34 {
    35     int i,j,k,t1,t2;ll wn,wnk;
    36     for(i=0;i<len;++i)
    37         if(i<rev[i])
    38             swap(a[i],a[rev[i]]);
    39     for(i=1;i<len;i<<=1)
    40     {
    41         wn=poww(idx==1?3:332748118,(md-1)/(i<<1));
    42         for(j=0;j<len;j+=(i<<1))
    43         {
    44             wnk=1;
    45             for(k=j;k<j+i;++k,wnk=wnk*wn%md)
    46             {
    47                 t1=a[k];t2=a[k+i]*wnk%md;
    48                 a[k]+=t2;
    49                 (a[k]>=md) && (a[k]-=md);
    50                 a[k+i]=t1-t2;
    51                 (a[k+i]<0) && (a[k+i]+=md);
    52             }
    53         }
    54     }
    55     if(idx==-1)
    56     {
    57         ll ilen=poww(len,md-2);
    58         for(i=0;i<len;++i)
    59             a[i]=a[i]*ilen%md;
    60     }
    61 }
    62 int a[N],b[N];
    63 int n,m;
    64 int main()
    65 {
    66     int i;
    67     init(N);
    68     scanf("%d%d",&n,&m);
    69     for(i=0;i<=n;++i)
    70         scanf("%d",&a[i]);
    71     for(i=0;i<=m;++i)
    72         scanf("%d",&b[i]);
    73     dft(a,N,1);dft(b,N,1);
    74     for(i=0;i<N;++i)
    75         a[i]=ll(a[i])*b[i]%md;
    76     dft(a,N,-1);
    77     for(i=0;i<=n+m;++i)
    78         printf("%d ",a[i]);
    79     return 0;
    80 }
    View Code

    预处理版本(慢,暂时先废弃着)(实测并没有快多少,不知道为什么)

     1 #include<cstdio>
     2 #include<algorithm>
     3 #include<cstring>
     4 #include<vector>
     5 #include<cmath>
     6 using namespace std;
     7 #define fi first
     8 #define se second
     9 #define mp make_pair
    10 #define pb push_back
    11 typedef long long ll;
    12 typedef unsigned long long ull;
    13 const ll md=998244353;
    14 const int bit=21,N=2097152,invN=998243877;
    15 int rev[N];
    16 ll wnk[21][N],iwnk[21][N];
    17 //wnk[i][j]:w_{2^(i+1)}^j;iwnk[i][j]:inv(wnk[i][j])
    18 //w_{2^(i+1)}=3^{(md-1)/(2^(i+1))}
    19 ll poww(ll a,ll b)
    20 {
    21     ll base=a,ans=1;
    22     for(;b;b>>=1,base=base*base%md)
    23         if(b&1)
    24             ans=ans*base%md;
    25     return ans;
    26 }
    27 void init(int len)
    28 {
    29     int i,j,ed;ll wn;
    30     for(i=0;i<len;++i)
    31         rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
    32     for(i=0;i<bit;++i)
    33     {
    34         ed=1<<(i+1);
    35         wnk[i][0]=iwnk[i][0]=1;
    36         wn=wnk[i][1]=poww(3,(md-1)/(1<<(i+1)));
    37         for(j=2;j<ed;++j)
    38             wnk[i][j]=wnk[i][j-1]*wn%md;
    39         wn=iwnk[i][1]=poww(332748118,(md-1)/(1<<(i+1)));
    40         for(j=2;j<ed;++j)
    41             iwnk[i][j]=iwnk[i][j-1]*wn%md;
    42     }
    43 }
    44 void dft(ll *a,int len,int idx)//要求len为2的幂
    45 {
    46     int i,i1,j,k;ll t1,t2;
    47     for(i=0;i<len;++i)
    48         if(i<rev[i])
    49             swap(a[i],a[rev[i]]);
    50     for(i=1,i1=0;i<len;i<<=1,++i1)
    51     {
    52         const ll *wnk=idx==1?(::wnk[i1]):iwnk[i1];
    53         for(j=0;j<len;j+=(i<<1))
    54         {
    55             for(k=j;k<j+i;++k)
    56             {
    57                 t1=a[k];t2=a[k+i]*wnk[k-j]%md;
    58                 a[k]+=t2;a[k+i]=t1-t2;
    59                 (a[k]>=md) && (a[k]-=md);
    60                 (a[k+i]<0) && (a[k+i]+=md);
    61             }
    62         }
    63     }
    64     if(idx==-1)
    65     {
    66         for(i=0;i<len;++i)
    67             (a[i]*=invN)%=md;
    68     }
    69 }
    70 ll a[N],b[N];
    71 int n,m;
    72 int main()
    73 {
    74     int i;
    75     init(N);
    76     scanf("%d%d",&n,&m);
    77     for(i=0;i<=n;++i)
    78         scanf("%lld",&a[i]);
    79     for(i=0;i<=m;++i)
    80         scanf("%lld",&b[i]);
    81     dft(a,N,1);dft(b,N,1);
    82     for(i=0;i<N;++i)
    83         a[i]=a[i]*b[i]%md;
    84     dft(a,N,-1);
    85     for(i=0;i<=n+m;++i)
    86         printf("%lld ",a[i]);
    87     return 0;
    88 }
    View Code

    todo

    矩阵:范德蒙矩阵?行列式?可逆条件?“奇异”?

    拉格朗日插值?

  • 相关阅读:
    Spark ML 文本的分类
    Linxu 安装Scala
    Nginx访问非常慢
    mysql:unknown variable 'default-character-set=utf8'
    mysql 1045
    mysql: error while loading shared libraries: libnuma.so
    elasticsearch: can not run elasticsearch as root
    Java中的Class类
    ClassLoader工作机制
    遍历机器查日志
  • 原文地址:https://www.cnblogs.com/hehe54321/p/8503294.html
Copyright © 2011-2022 走看看