zoukankan      html  css  js  c++  java
  • 「学习笔记」FFT快速傅里叶变换

    基本信息

    用途 : 多项式乘法

    时间复杂度 : (O(nlogn)) (常数略大)

    算法过程

    基本思路

    (H(x) = G(x) imes F(x))

    直接从系数表达式转化为系数表达式比较难搞, 所以考虑先把 (F(x), G(x)) 转化为点值表达式, 再 (O(n)) 求出 (H(x)) 的点值表达式, 然后从 (H(x)) 的点值表达式转化为 (H(x)) 的系数表达式.

    其中, 从系数表达式转化为点集表达式的过程叫 (DFT), 又叫 求值运算.

    从系数表达式转化为点集表达式的过程叫 (IDFT), 又叫 插值运算.

    求值运算

    先考虑求值运算的过程, 设 (F(x),G(x)) 分别为 (n) 次, (m) 次的多项式, 则 (H(x))(n+m) 次的多项式,

    所以我们需要求出 (F(x),G(x))(n+m-1) 个不同的点处的值, 才能保证最终求得的 (H(x)) 的唯一性, (可以类比求函数解析式所需的条件).

    如果直接硬算, 复杂度会达到 (O(n^2)), 所以我们需要借助一个叫做单位根的神奇东西.


    复数

    引入单位根之前, 得先介绍一下复数.

    首先, 我们定义一个数 (i), 使 (i^2=-1) (下文中的所有 (i) 都表示这个东西).

    形如 (a+bi) 的数就叫做复数, 其中 (a,b in mathbb{R}).

    复数和实数一样, 也有四则运算 (其实可以类比成多项式的运算).

    (x = a+bi, y=c+di), 则

    1. $ x+y = (a+c)+(b+d)i $
    2. $ x-y = (a-c)+(b-d)i $
    3. $ x imes y = (ac-bd)+(ad+cb)i$ (把 (x,y) 当成多项式乘开即可).
    4. $ frac{x}{y} = frac{a+bi}{c+di} = frac{(a+bi)(c-di)}{(c+di)(c-di)} = frac{(ac+bd)+(ad+cb)i}{c2+d2} $ (类似于无理数运算中分母有理化的过程).

    接下来, 我们介绍一个叫 "复平面" 的东西.

    长这样

    img

    和数轴上的一个点能唯一地表示一个实数类似, 复平面上的一个点能唯一地表示一个复数.

    其中, (x) 轴上的数为实数 ((real axis)), (y) 轴上的数为虚数 ((imaginary axis)).

    我们设一个复数的辐角为该复数在复平面上的点对应的向量(x) 轴逆时针的夹角,

    一个复数的模长为该复数对应向量的模长.

    我们会得到一个神奇的性质 :

    (x,y,z) 都为复数, 且 (x imes y = z), 则 (z) 的幅角等于 (x,y)幅角相加, (z) 的模长等于 (x,y) 的模长相乘.

    如下图 (图源)

    img

    幅角相加可以用三角函数证明, 模长相乘可以把坐标带入直接算就好. (证明过程写出来比较麻烦, 原谅我时间有限)

    单位根

    有了上面的基础后, 我们就可以来认识单位根了.

    定义 : 若复数 (x^n = 1, ( n in mathbb{N+})), 则称 (x)(n) 次单位根.

    考虑一下复数相乘的性质, 可以发现, (x) 的模长必然为 (1), (大于 (1) 的话会越乘越大, 小于 (1) 的话会越乘越小),

    (x) 的幅角为 (frac{2pi k}{n}, (k in [0,n) )).

    那也就意味着, (x) 一定在复平面的单位圆上, 并且将单位圆 (n) 等分.

    3次单位根

    为了便于称呼, 我们用 (omega_n) 来表示 (n) 单位根, 并从 (1) 开始将他们逐个编上号, (omega_n^0 = 1).

    接下来, 我们介绍一些单位根的性质 (原谅我真的没时间....)

    1. (omega_n^k = (omega_n^1)^k)
    2. $omega_n^0 omega_n^1 dots omega_n^{n-1} $ 互不相等.
    3. (omega_n^{k+frac{n}{2}} = -omega_n^k) ((n) 为偶数)
    4. (omega_{2n}^{2k} = omega_n^k)
    5. (sum_{k=0}^{n-1} omega_n^k = 0) (带入等差数列求和公式即可)

    好了, 复数和单位根就介绍到这里, 还记得我们原来要干什么吗?

    我们想把 (F(x))系数表达式 转化为 点值表达式 .

    求点值表达式, 就需要选择 (n+m-1) 个自变量 (x) 带入求值.

    通常情况下, 这个操作的复杂度是 (O(n^2)) 级别的, 但我们的傅里叶大大发现, 把单位根带入求值, 会有神奇的效果.

    为了方便描述, 我们这里把 (n) 重定义为大于 (n+m-1) 的第一个 (2) 的正整数次方, 并把 (F(x)) 重定义为 (n-1) 次多项式, 后面多出的系数默认为 (0).

    (omega_n^k) ($ k in [0,frac{n}{2})$)带入 (F(x)), 得到

    [F(omega_n^k) = f[0]omega_n^0 + f[1]omega_n^k + dots + f[n-1]omega_n^{(n-1)k} ]

    尝试使用分值的思想, 把奇偶次项分开, 得到

    [F(omega_n^k) = f[0]omega_n^0 + f[2]omega_n^{2k} + dots + f[n-2]omega_n^{(n-2)k} + f[1]omega_n^k + f[3]omega_n^{3k} + dots + f[n-1]omega_n^{(n-1)k} ]

    两部分似乎有相似之处,

    (G1(x) = f[0]x^0 + f[2]x^1 + f[n-2]x^{frac{n}{2}-1})

    (G2(x) = f[1]x^0 + f[1]x^1 + f[n-1]x^{frac{n}{2}-1})

    [egin{aligned} F(omega_n^k) & = G1(omega_n^{2k}) + omega_n^kG2(omega_n^{2k}) \ & = G1(omega_{frac{n}{2}}^{k}) + omega_n^kG2(omega_{frac{n}{2}}^{k}) end{aligned} ]

    若再把 (omega_n^{k+frac{n}{2}}) 带入 (F(x)), 由于 (omega_n^{k+frac{n}{2}} = -omega_n^k), 所以他们的偶次项是相同的, 而奇次项是相反的.

    也就是

    [egin{aligned} F(omega_n^{k+frac{n}{2}}) & = G1(omega_n^{2k + n}) + omega_n^{k+frac{n}{2}}G2(omega_n^{2k + n}) \ & = G1(omega_{frac{n}{2}}^{k}) - omega_n^kG2(omega_{frac{n}{2}}^{k}) end{aligned} ]

    发现 (F(omega_n^k))(F(omega_n^{k+frac{n}{2}})) 化简后得到的式子只有一个符号的差别, 那么意味着, 我们只需算出当 (k in [0,frac{n}{2})) 时的

    [G1(omega_{frac{n}{2}}^{k}) ]

    [G2(omega_{frac{n}{2}}^{k}) ]

    这两个式子, 就可以算出 (omega_n^0)(omega_n^{n-1}) 的所有点值.

    而上面那两个式子显然 (应该显然吧...) 是可以递归处理的, 那么每次就减少计算一半的点, 时间复杂度就降低到了 (O(nlog n)).

    放个代码

    void trans(cn *f,int len,bool id){
      if(len==1) return;
      cn *g1=f,*g2=f+len/2;   // 直接在 f 数组的地址上修改, 防止使用内存过多
      for(int i=0;i<len;i++) tmp[i]=f[i];  // 由于是之间在 f 数组的地址上修改, 所以要备份 
      for(int i=0;2*i<len;i++){ g1[i]=tmp[i<<1]; g2[i]=tmp[i<<1|1]; }
      trans(g1,len/2,id);	// 递归处理
      trans(g2,len/2,id);
      cn w1=(cn){cos(2*Pi/len),sin(2*Pi/len)},wi=(cn){1,0};
      if(id) w1.b*=-1;
      for(int i=0;2*i<len;i++){
        tmp[i]=g1[i]+wi*g2[i];			// 上面的两个式子
        tmp[i+len/2]=g1[i]-wi*g2[i];
        wi=wi*w1;	// 处理出每个单位根
      }
      for(int i=0;i<len;i++) f[i]=tmp[i];
    }
    

    那么求值运算, 也就是 (DFT) 就大功告成了.


    差值运算

    我们先用矩阵乘法来表示一下求点值的过程.

    设 矩阵(A) 为要带入的 (n)自变量以及它们的 (0 sim n) 次方,

    矩阵 (B)(F(x))系数,

    矩阵 (C) 为自变量对应的 (n)点值.

    则有

    [AB = C ]

    image-20191229195524817

    现在我们知道了 (A), 知道了 (C), 要求 (B), 那一般思路就是把 (A) 除过去, 即

    [B = CA^{-1} ]

    其中 (A^{-1})(A)逆矩阵, 它们的乘积为单位矩阵.

    经过一系列复杂的运算后, 发现 (A^{-1}) 是长这样的, (可以尝试自己手推一下, 需要用到上面单位根的第 4 个性质)

    image-20191229195821568

    是不是很眼熟,

    没错, 实际上就是把 (A)(omega_n^k) 全都换成了 (omega_n^{-k}), 并在前面加了个系数.

    (CA^{-1}) 究竟要怎么算呢?

    是不是完全没有头绪? (还是只有我一个人是这样)

    答案是, 把 (A^{-1}) 看做 (A), 把 (C) 看做 (B), 把 (B) 看做 (C) , 再进行一遍 (DFT) 就行了. (说人话).

    就是 把点值看做一个新函数的系数, 然后把 (omega_n^0 sim omega_n^{-(n-1)}) 带入这个新函数, 求值, 得到的点值再乘上一个 (frac{1}{n}) 就得到了(H(x)), 也就是 (F(x) imes G(x)) 的系数.


    ok, 到此为止, 我们搞定了 (DFT)(IDFT) ,(FFT) 的流程也就到这里了,

    放代码.

    #include<bits/stdc++.h>
    #define _USE_MATH_DEFINES
    using namespace std;
    const int N=3e6+7; 
    const double Pi=M_PI;
    struct cn{
      double a,b;
      cn operator + (const cn &x) const{
        return (cn){x.a+a,x.b+b};
      }
      cn operator - (const cn &x) const{
        return (cn){a-x.a,b-x.b};
      }
      cn operator * (const cn &x) const{
        return (cn){x.a*a-x.b*b,x.a*b+a*x.b};
      }
      cn operator *= (const cn &x) const{
        return (cn){x.a*a-x.b*b,x.a*b+a*x.b};
      }
    };
    int n,m;
    cn f[N],g[N],tmp[N];
    void trans(cn *f,int len,bool id){
      if(len==1) return;
      cn *g1=f,*g2=f+len/2;   // 直接在 f 数组的地址上修改, 防止使用内存过多
      for(int i=0;i<len;i++) tmp[i]=f[i];  // 由于是之间在 f 数组的地址上修改, 所以要备份 
      for(int i=0;2*i<len;i++){ g1[i]=tmp[i<<1]; g2[i]=tmp[i<<1|1]; }
      trans(g1,len/2,id);	// 递归处理
      trans(g2,len/2,id);
      cn w1=(cn){cos(2*Pi/len),sin(2*Pi/len)},wi=(cn){1,0};
      if(id) w1.b*=-1;
      for(int i=0;2*i<len;i++){
        tmp[i]=g1[i]+wi*g2[i];			// 上面的两个式子
        tmp[i+len/2]=g1[i]-wi*g2[i];
        wi=wi*w1;	// 处理出每个单位根
      }
      for(int i=0;i<len;i++) f[i]=tmp[i];
    }
    int main(){
      //  freopen("FFT.in","r",stdin);
      cin>>n>>m;
      for(int i=0;i<=n;i++) scanf("%lf",&f[i].a);
      for(int i=0;i<=m;i++) scanf("%lf",&g[i].a);
      int t=1;
      while(t<=n+m) t<<=1;  
      trans(f,t,0);
      trans(g,t,0);
      for(int i=0;i<t;i++) f[i]=f[i]*g[i];
      trans(f,t,1);
      for(int i=0;i<=n+m;i++) printf("%d ",(int)(f[i].a/t+0.49));   //+0.49 减小因精度产生的误差 (我也不知道为什么这样就可减小误差...)
      return 0;
    }
    

    但是, 当你把这份代码交上去后, 会发现只有 77pts, 后面两点会 TLE.

    这是因为复数运算的常数本身就比较大, 再加上递归带来的常数, 你不T谁T.

    所以, 继续下一个内容.

    FFT的优化

    复数运算带来的常数是优化不了了, 毕竟 (FFT) 的关键步骤 ---- 分治 要依靠它才能进行.

    (当然, 有人用其他更优的东西把它替代了, 不过这属于下一个内容 ---- (NTT) )

    那我们就考虑如何优化递归带来的常数吧.

    我们发现, 递归的下传过程并没有进行什么操作, 在上传过程中才处理出了点值.

    那我们可以这样理解 : 递归的下传过程就是为了寻找每个数的对应位置.

    那么, 这个对应位置是否存在某种规律, 能让我们免去递归的过程, 直接把它们放在应该放的位置?

    经过前人的不懈努力和细心观察发现, 每个数最终的位置是该数的 二进制翻转

    比如, 当 (n = 8) 的时候.

    0   1   2   3   4   5   6   7   
    0   2   4   6 | 1   3   5   7
    0   4 | 2   6 | 1   5 | 3   7
    0 | 4 | 2 | 6 | 1 | 5 | 3 | 7
    

    化为二进制就是

    000 001 010 011 100 101 110 111
    
    000 100 010 110 001 101 011 111
    

    是不是非常神奇

    然后我们可以用一个类似递归的过程来处理他们的位置

    for(int i=0;i<n;i++)
        num[i]=(num[i>>1]>>1])|((i&1) ?n>>1 :0)
    

    可以这样理解,

    假设你有一个数 (x), 它的二进制为

    xxxxxxxxxx
    

    把它拆成这两部分

    xxxxxxxxx | x
    

    前半部分的翻转, 就相当于 (x>>1) 的翻转再左移一位. (可以自己模拟一下)

    然后再根据最后一位是 (0)(1) , 在前面补上相应的一位.

    ok, 这样, 我们就避免了递归带来的常数.

    还有一个小地方

    for(int i=0;2*i<len;i++){
        tmp[i]=g1[i]+wi*g2[i];			// 上面的两个式子
        tmp[i+len/2]=g1[i]-wi*g2[i];
        wi=wi*w1;	// 处理出每个单位根
      }
    

    我们可以把它改成

    for(int i=0;2*i<len;i++){
        cn tmp=wi*g2[i];
        tmp[i]=g1[i]+tmp;			// 上面的两个式子
        tmp[i+len/2]=g1[i]-tmp;
        wi=wi*w1;	// 处理出每个单位根
      }
    

    减少了一下复数的运算量.

    最终代码 【模板】多项式乘法(FFT)

    #include<bits/stdc++.h>
    #define _USE_MATH_DEFINES
    using namespace std;
    const int N=3e6+7; 
    const double Pi=M_PI;
    struct cn{
      double a,b;
      cn operator + (const cn &x) const{
        return (cn){x.a+a,x.b+b};
      }
      cn operator - (const cn &x) const{
        return (cn){a-x.a,b-x.b};
      }
      cn operator * (const cn &x) const{
        return (cn){x.a*a-x.b*b,x.a*b+a*x.b};
      }
    };
    int n,m,t=1,num[N];
    cn f[N],g[N],tmp[N];
    void trans(cn *f,int id){
      for(int i=0;i<t;i++)
        if(i<num[i]) swap(f[i],f[num[i]]);
      for(int len=2;len<=t;len<<=1){
        int gap=len>>1;
        cn w1=(cn){cos(2*Pi/len),sin(2*Pi/len)*id};
        for(int i=0;i<t;i+=len){
          cn wj=(cn){1,0};
          for(int j=i;j<i+gap;j++){
    		cn tt=wj*f[j+gap];
    		f[j+gap]=f[j]-tt;	// 这里需要注意一下赋值的顺序
    		f[j]=f[j]+tt;
    		wj=wj*w1;
          }
        }
      }
    }
    int main(){
      //freopen("FFT.in","r",stdin);
      //freopen("x.out","w",stdout);
      cin>>n>>m;
      for(int i=0;i<=n;i++) scanf("%lf",&f[i].a);
      for(int i=0;i<=m;i++) scanf("%lf",&g[i].a);
      while(t<=n+m) t<<=1;   // 保证 t > n+m
      for(int i=1;i<t;i++) num[i]=(num[i>>1]>>1)|((i&1)?t>>1:0);
      trans(f,1);
      trans(g,1);  
      for(int i=0;i<t;i++) f[i]=f[i]*g[i];
      trans(f,-1);
      for(int i=0;i<=n+m;i++) printf("%d ",(int)(f[i].a/t+0.49));
      return 0;
    }
    
    

    upd 2020,08,16

    NOI考前复习, 敲了个封装了的 NTT, 比之前的快那么一些.

    #include <bits/stdc++.h>
    
    #define pb push_back
    typedef long long ll;
    
    using namespace std;
    
    const int _ = (1 << 21) + 7;
    const int mod = 998244353;
    const int _g = 3, _invg = 332748118;
    
    int n, m;
    vector<int> f, g;
    
    int gi() {
      int x = 0; char c = getchar();
      while (!isdigit(c)) c = getchar();
      while (isdigit(c)) x = (x << 3) + (x << 1) + c - '0', c = getchar();
      return x;
    }
    
    namespace ploy {
      int n, invn, num[_];
      int pw(int a, int p) {
        int res = 1;
        while (p) {
          if (p & 1) res = (ll)res * a % mod;
          a = (ll)a * a % mod;
          p >>= 1;
        }
        return res;
      }
    
      void init(int x) {
        n = 1; while (n < x) n <<= 1;
        for (int i = 0; i < n; ++i)
          num[i] = (num[i >> 1] >> 1) | (i & 1 ? (n >> 1) : 0);
        invn = pw(n, mod - 2);
      }
    
      void NTT(vector<int>& f, bool ty) {
        for (int i = 0; i < n; ++i)
          if (i < num[i]) swap(f[i], f[num[i]]);
        for (int len = 2; len <= n; len <<= 1) {
          int gap = len >> 1;
          int w1 = pw(ty ? _invg : _g, (mod - 1) / len);
          for (int i = 0; i < n; i += len) {
            int w = 1;
            for (int j = i; j < i + gap; ++j, w = (ll)w * w1 % mod) {
              int tmp = (ll)w * f[j + gap] % mod;
              f[j + gap] = (ll)(f[j] - tmp + mod) % mod;
              f[j] = (ll)(f[j] + tmp) % mod;
            }
          }
        }
      }
    
      vector<int> Mul(vector<int> f, vector<int> g) {
        f.resize(n), g.resize(n);
        NTT(f, 0), NTT(g, 0);
        for (int i = 0; i < n; ++i) f[i] = (ll)f[i] * g[i] % mod;
        NTT(f, 1);
        for (int i = 0; i < n; ++i) f[i] = (ll)f[i] * invn % mod;
        return f;
      }
    }
    
    int main() {
      cin >> n >> m; ++n, ++m;
      for (int i = 0; i < n; ++i) f.pb(gi());
      for (int i = 0; i < m; ++i) g.pb(gi());
      ploy::init(n + m - 1);
      f = ploy::Mul(f, g);
      for (int i = 0; i < n + m - 1; ++i) printf("%d ", f[i]); putchar('
    ');
      return 0;
    }
    

    (因为变量名错误调了好久...)


    推荐题目

    [ZJOI2014]力

    下面三道是 (NTT) 的题.

    [AH2017/HNOI2017]礼物

    [SDOI2015]序列统计

    幼儿园篮球题

    参考资料

    傅里叶变换(FFT)学习笔记 by command_block

    对了, 还有一件事,

    Typora真好用

  • 相关阅读:
    896. 单调数列
    819. 最常见的单词
    collections.Counter()
    257. 二叉树的所有路径
    万里长征,始于足下——菜鸟程序员的学习总结(三)
    Ogre启动过程&原理
    Ogre导入模型
    四元数
    Ogre3D嵌入Qt框架
    如何搭建本地SVN服务
  • 原文地址:https://www.cnblogs.com/BruceW/p/12116397.html
Copyright © 2011-2022 走看看