zoukankan      html  css  js  c++  java
  • Fast Fourier Transform

    Introduction

    快速傅里叶变换(Fast Fourier Transform,FFT)是一种可在 (O(n log n)) 时间内完成的离散傅里叶变换 (Discrete Fourier Transform,DFT) 的算法,用来实现将信号从原始域(通常是时间或空间)到频域的互相转化。

    FFT 在算法竞赛中主要用来加速多项式乘法(循环卷积)。

    多项式

    形如

    [A(x) = a_0 + a_1x + a_2x^2 + dots + a_{n-1}x^{n - 1} ]

    的式子称为 (x)(n - 1) 次多项式,其中 (a_0, a_1, dots, a_{n - 1}) 称为多项式系数,(n-1) 称为多项式的次数,记为 (deg A(x))(deg A)

    点值

    (n - 1) 次多项式 (A(x))(x = m) 处的点值

    [A(m) = sum_{k=0}^{n-1} a_km^k ]

    多项式乘法

    (A(x) imes B(x)) 表示多项式 (A(x), B(x)) 做多项式乘法,可以简写为 (A(x)cdot B(x))(A(x)B(x))

    多项式乘法

    [egin{aligned} C(x) = &A(x) imes B(x)\ = &left(a_0 + a_1x + dots + a_{deg A}x^{deg A} ight)cdotleft(b_0 + b_1x + dots + b_{deg B}x^{deg B} ight)\ = &sum_{r = 0}^{deg A + deg B} sum_{k = 0}^r a_kb_{r - k} x^r end{aligned} ]

    用系数关系可以表示为

    [c_r = sum_{k = 0}^ra_kb_{r - k} ]

    其中 (deg C = deg A + deg B)

    易证它们的点值满足如下关系

    [C(m) = A(m)B(m) ]

    循环卷积

    (operatorname{conv}(A, B, n)) 表示多项式 (A(x), B(x)) 做长度为 (n) 的循环卷积。

    循环卷积

    [C(x) = operatorname{conv}(A, B, n) ]

    系数关系表示为

    [c_k = sum_{p, q}[(p + q) mod n = k]a_pb_q ]

    其中 (deg C = n - 1)

    容易发现,当 (n > deg A + deg B) 时,该运算等价于多项式乘法。

    DFT

    离散傅里叶变换(Discrete Fourier Transform, DFT) 将多项式 (A(x)=sum_{k=0}^{n-1}a_kx^k) 转换为一些特殊的点值。

    (n) 次单位复根

    [omega_n = e^{frac{2ipi}n}=cosdfrac{2pi}{n}+isindfrac{2pi}{n} ]

    (DFT(A)) 就是要计算点值 (A(omega_n^k), k = 0, 1, 2, dots, n-1)

    单位根自带的循环特性使得循环卷积 (C(x) = operatorname{conv}(A, B, n)) 的点值也满足:

    [C(omega_n^k) = A(omega_n^k)B(omega_n^k) ]

    IDFT

    IDFT 是 DFT 的逆变换。

    首先,用等比数列求和易证:

    [egin{align*} frac1nsum_{k = 0}^{n - 1}omega_n^{vk} &= [v mod n = 0] end{align*} ]

    考虑循环卷积 (C(x) = operatorname{conv}(A, B, n)) 的系数表示

    [egin{align*} c_r = &sum_{p, q}[(p + q) mod n = r]a_pb_q\ = &sum_{p, q}[(p + q - r) mod n = 0]a_pb_q\ = &sum_{p, q}frac1nsum_{k = 0}^{n - 1}omega_n^{pk+qk-rk}a_pb_q\ = &sum_{p, q}frac1nsum_{k = 0}^{n - 1}omega_n^{-rk}cdotomega_n^{pk}a_pcdotomega_n^{qk}b_q\ = &frac1nsum_{k = 0}^{n - 1}omega_n^{-rk}left(sum_{p}omega_n^{pk}a_psum_qomega_n^{qk}b_q ight)\ = &frac1nsum_{k = 0}^{n - 1}left(omega_n^{-r} ight)^kA(omega_n^k)B(omega_n^k)\ = &frac1nsum_{k = 0}^{n - 1}left(omega_n^{n-r} ight)^kC(omega_n^k) end{align*} ]

    设多项式

    [C'(x) = sum_{k=0}^{n-1}C(omega_n^k)x^k ]

    只要计算 (DFT(C')) 即可得到 (C(x)) 的系数,于是我们用 DFT 完成了逆变换 IDFT。

    用两次 DFT 和一次 IDFT就可以计算 (operatorname{conv}(A, B, n))

    暴力的复杂度是 (O(n^2)),此处不赘述。

    FFT

    现在尝试将 DFT 问题分解以优化时间复杂度。

    本部分认为 (n = deg A + 1)(2) 的整数次幂。对于更一般的情况,暂不考虑。

    DIF

    将序列 (a_i) 分成左右两半

    [egin{align*} A(omega_n^{r}) &= sum_{k = 0}^{n-1}a_komega_n^{rk}\ &= sum_{k = 0}^{n / 2 - 1} left(a_kcdotomega_n^{rk} + a_{k+n/2}cdotomega_n^{rk+rn/2} ight)\ &= sum_{k = 0}^{n / 2 - 1} left[a_kcdotomega_n^{rk} + (-1)^rcdot a_{k+n/2}cdotomega_n^{rk} ight]\ &= sum_{k = 0}^{n / 2 - 1} left[a_k+(-1)^ra_{k+n/2} ight]omega_{n}^{rk} end{align*} ]

    进一步,将 (A(omega_{n}^r)) 按奇偶分类

    [egin{align*} Aleft(omega_n^{2r} ight) &= sum_{k=0}^{n/2-1}left(a_k+a_{k+n/2} ight)omega_{n/2}^{rk}\ Aleft(omega_n^{2r+1} ight) &= sum_{k=0}^{n/2-1}left(omega_{n}^ka_k-omega_{n}^ka_{k+n/2} ight)omega_{n/2}^{rk} end{align*} ]

    [egin{align*} &p_k=a_k+a_{k+n/2}, &P(x) = sum_{k = 0}^{n/2-1}p_kx^k\ &q_k=omega_{n}^k(a_k-a_{k+n/2}), &Q(x) = sum_{k=0}^{n/2-1}q_kx^k end{align*} ]

    我们只需要求出 (P(omega_{n/2}^r))(Q(omega_{n/2}^r)) ,即求解规模为原来一半的两个子问题 (DFT(P), DFT(Q)),就能在 (O(n)) 时间内计算出 (DFT(A))

    DIT

    在算法竞赛中这种方法更常见。

    注意到在 DIF 中我们最后将 (A(omega_n^r)) 奇偶分类求解,那不妨思考将序列 (a_k) 按奇偶分类

    [A_0(x) = a_0 + a_2x + dots + a_{n - 2}x^{n / 2}\ A_1(x) = a_1 + a_3x+ dots + a_{n - 1}x^{n / 2} ]

    [A(x) = A_0(x^2) + xA_1(x^2) ]

    所以

    [egin{align*} A(omega_n^k) &= A_0(omega_n^{2k}) + omega_n^kA_1(omega_n^{2k})\ &= A_0(omega_{n/2}^k) + omega_n^kA_1(omega_{n/2}^k) end{align*} ]

    (A(omega_n^k)) 再细分为左右两半,这里运用了等式 (omega_{n/2}^k = omega_{n/2}^{k + n/2})(omega_n^k+omega_n{k+n/2} = 0) :

    [egin{align*} A(omega_n^k) &= A_0(omega_{n/2}^k) + omega_n^kA_1(omega_{n/2}^k)\ Aleft(omega_n^{k+n/2} ight) &= A_0(omega_{n/2}^k) - omega_n^kA_1(omega_{n/2}^k) end{align*} ]

    我们只需要求出 (A_0(omega_{n/2}^k))(A_1(omega_{n/2}^k)) ,即求解规模为原来一半的两个子问题 (DFT(A_0), DFT(A_1)),就能在 (O(n)) 时间内计算出 (DFT(A))

    Complexity

    设次数为 (n - 1) 的多项式做 DFT 的时间复杂度为 (T(n)),则

    [T(n) = 2T(frac{n}{2}) + O(n) ]

    根据主定理

    [T(n) = O(n log n) ]

    Implementation

    上述两种计算方式均可以使用递归实现,这里直接给出代码,不再赘述。
    DIF

    const double PI = acos(-1.0);
    void dft(std::vector<Complex> &a) {
      int n = a.size(), m = n >> 1;
      if (n == 1) return;
      std::vector<Complex> p(m), q(m);
      for (int i = 0; i < m; i++) {
        p[i] = a[i] + a[i + m];
        q[i] = (a[i] - a[i + m]) * Complex(cos(2 * PI * i / n), sin(2 * PI * i / n));
      }
      dft(p), dft(q);
      for (int i = 0; i < m; i++)
        a[i << 1] = p[i], a[i << 1 | 1] = q[i];
    }
    void idft(std::vector<Complex> &a) {
      dft(a);
      for (auto &v: a) v.a /= a.size(), v.b /= a.size();
      std::reverse(a.begin() + 1, a.end());
    }
    

    DIT

    const double PI = acos(-1.0);
    void dft(std::vector<Complex> &a) {
      int n = a.size(), m = n >> 1;
      if (n == 1) return;
      std::vector<Complex> p(m), q(m);
      for (int i = 0; i < m; i++) {
        p[i] = a[i << 1];
        q[i] = a[i << 1 | 1];
      }
      dft(p), dft(q);
      for (int i = 0; i < m; i++) {
        Complex &u = p[i], v = Complex(cos(2 * PI * i / n), sin(2 * PI * i / n)) * q[i];
        a[i] = u + v, a[i + m] = u - v;
      }
    }
    void idft(std::vector<Complex> &a) {
      dft(a);
      for (auto &v: a) v.a /= a.size(), v.b /= a.size();
      std::reverse(a.begin() + 1, a.end());
    }
    

    下面探讨以 非递归方式 实现 DIFDIT

    DIT

    由于 DIT 更易于理解(其实只是资料多),先讲这个。

    考虑递归的过程:

    [egin{align*} &0. (a_0, a_1, a_2, a_3, a_4, a_5, a_6, a_7)\ &1. (a_0, a_2, a_4, a_6)(a_1, a_3, a_5, a_7)\ &2. (a_0, a_4)(a_2, a_6)(a_1, a_5)(a_3, a_7)\ &3. (a_0)(a_4)(a_2)(a_6)(a_1)(a_5)(a_3)(a_7)\ &2' (A_{00}, A_{01})(A_{10},A_{11})(A_{00},A_{01})(A_{10}, A_{11})\ &1' (A_{00}, A_{01}, A_{02}, A_{03})(A_{10}, A_{11}, A_{12}, A_{13})\ &0' (A_0, A_1, A_2, A_3, A_4, A_5, A_6, A_7) end{align*} ]

    发现 (0 ightarrow 3) 只是在重新安排数据位置,并没有修改数据,如果我们能把映射关系找到,那就可以一步到位,直接从 (3) 开始。

    设一个数在第 (i) 个阶段 ((0 leq i leq log_2n)) 的位置为 (p_i),相对位置为 (p_i')相对位置指它在括号里的位置,例如上面第 (1) 阶段 (a_1) 的相对位置为 (0))。

    容易发现

    [p'_i = p_i mod frac{n}{2^i}\ p_{i + 1} = p_{i} - p'_i + egin{cases} dfrac{n}{2^{i+1}}+dfrac{p'_i-1}{2} & p_i' equiv 1 pmod 2\ dfrac{p'_i}{2} & p_i' equiv 0 pmod 2 end{cases} ]

    如果将 (p_0) 写成二进制 (overline{b_4b_3b_2b_1b_0})(这里以 (n = 32) 为例),那么单次变化的过程相当于把二进制的后几位向右 rotate 一位,总的变化过程可以描述为:

    [overline{|b_4b_3b_2b_1b_0} ightarrow overline{|b_0b_4b_3b_2b_1} \ overline{b_0|b_4b_3b_2b_1} ightarrow overline{b_0|b_1b_4b_3b_2} \ overline{b_0b_1|b_4b_3b_2} ightarrow overline{b_0b_1|b_2b_4b_3} \ overline{b_0b_1b_2|b_4b_3} ightarrow overline{b_0b_1b_2|b_3b_4} \ overline{b_0b_1b_2b_3|b_4} ightarrow overline{b_0b_1b_2b_3|b_4} ]

    可以发现,整个过程实际上是在做 reverse 操作!

    至此我们找到了映射关系,成功把前面的步骤都砍掉了,只剩回溯,可以改成循环。

    void dft(std::vector<Complex> &a) {
      int n = a.size();
      for (int i = 0, j = 0; i < n; i++) {
        if (i > j) std::swap(a[i], a[j]);
        for (int k = n >> 1; (j ^= k) < k; k >>= 1);
      }
      for (int k = 1; k < n; k <<= 1) {
        for (int i = 0; i < n; i += k << 1) {
          for (int j = 0; j < k; j++) {
            auto t = a[i + j + k] * Complex(cos(PI * j / k), sin(PI * j / k));
            a[i + j + k] = a[i + j] - t;
            a[i + j] = a[i + j] + t;
          }
        }
      }
    }
    

    DIF

    这里先将递归版DIF的过程简单复述:

    第一步:将序列 (a) 对半分

    [egin{align*} &p_k=a_k+a_{k+n/2}, &P(x) = sum_{k = 0}^{n/2-1}p_kx^k\ &q_k=omega_{n}^k(a_k-a_{k+n/2}), &Q(x) = sum_{k=0}^{n/2-1}q_kx^k end{align*} ]

    第二步:递归计算 (DFT(p), DFT(q))
    第三步:重新安排数据位置

    [egin{align*} Aleft(omega_n^{2r} ight) &= P(omega_{n/2}^r)\ Aleft(omega_n^{2r+1} ight) &= Q(omega_{n/2}^r) end{align*} ]

    发现回溯的过程(即第三步)实际上也只是在重新安排数据存储的位置,而且是上面 (DIT) 第一步的逆过程,所以就是位翻转的逆过程,所以还是位翻转。

    所以最后安排数据位置可以一步搞定,只剩递归压栈的过程,可以改成循环。

    void dft(vector<Complex> &a) {
      int n = a.size();
      for (int k = n >> 1; k; k >>= 1) {
        for (int i = 0; i < n; i += k << 1) {
          for (int j = 0; j < k; j++) {
            auto t = a[i + j + k];
            a[i + j + k] = (a[i + j] - t) * Complex(cos(PI * j / k), sin(PI * j / k));
            a[i + j] = a[i + j] + t;
          }
        }
      }
      for (int i = 0, j = 0; i < n; i++) {
        if (i > j) std::swap(a[i], a[j]);
        for (int k = n >> 1; (j ^= k) < k; k >>= 1);
      }
    }
    

    Combination

    发现 (DIF) 的最后一步和 (DIT) 的第一步都是位翻转,所以先 (DIF)(DIT),就可以省略位翻转。

    完整代码

    #include <bits/stdc++.h>
    
    template <class T>
    inline void readInt(T &w) {
      char c, p = 0;
      while (!isdigit(c = getchar())) p = c == '-';
      for (w = c & 15; isdigit(c = getchar());) w = w * 10 + (c & 15);
      if (p) w = -w;
    }
    
    struct Complex {
      double a, b; // a + bi
      Complex(double a = 0, double b = 0): a(a), b(b) {}
    };
    inline Complex operator+(const Complex &p, const Complex &q) {
      return Complex(p.a + q.a, p.b + q.b);
    }
    inline Complex operator-(const Complex &p, const Complex &q) {
      return Complex(p.a - q.a, p.b - q.b);
    }
    inline Complex operator*(const Complex &p, const Complex &q) {
      return Complex(p.a * q.a - p.b * q.b, p.a * q.b + p.b * q.a);
    }
    
    const double PI = acos(-1.0);
    void dft(std::vector<Complex> &a) {
      int n = a.size();
      for (int k = n >> 1; k; k >>= 1) {
        for (int i = 0; i < n; i += k << 1) {
          for (int j = 0; j < k; j++) {
            auto t = a[i + j + k];
            a[i + j + k] = (a[i + j] - t) * Complex(cos(PI * j / k), sin(PI * j / k));
            a[i + j] = a[i + j] + t;
          }
        }
      }
    }
    void idft(std::vector<Complex> &a) {
      int n = a.size();
      for (int k = 1; k < n; k <<= 1) {
        for (int i = 0; i < n; i += k << 1) {
          for (int j = 0; j < k; j++) {
            auto t = a[i + j + k] * Complex(cos(PI * j / k), sin(PI * j / k));
            a[i + j + k] = a[i + j] - t;
            a[i + j] = a[i + j] + t;
          }
        }
      }
      for (auto &v: a) v.a /= a.size(), v.b /= a.size();
      std::reverse(a.begin() + 1, a.end());
    }
    int main() {
      int n, m, k;
      readInt(n), readInt(m);
      k = 1 << std::__lg(n + m) + 1;
      std::vector<Complex> a(k), b(k), c(k);
      for (int i = 0; i <= n; i++) readInt(a[i].a);
      for (int i = 0; i <= m; i++) readInt(b[i].a);
      dft(a), dft(b);
      for (int i = 0; i < k; i++) c[i] = a[i] * b[i];
      idft(c);
      for (int i = 0; i <= n + m; i++) printf("%d ", (int)(c[i].a + 0.5));
      return 0;
    }
    

    Number Theory Transform

    如果一个质数存在 (2^n) 次单位根(其中 (n) 最大时的单位根称为原根),那么在这个质数的剩余系下上面的结论依旧成立,可以使用FFT,多称这种FFT为快速数论变换(Number Theory Transform, NTT)

    常见的质数是 (P = 998244353),它的原根 (g = 3)

    代码(预处理单位根,较好地平衡了代码复杂度和常数且有一定的封装度):

    #include <bits/stdc++.h>
    template <class T>
    inline void readInt(T &w) {
      char c, p = 0;
      while (!isdigit(c = getchar())) p = c == '-';
      for (w = c & 15; isdigit(c = getchar());) w = w * 10 + (c & 15);
      if (p) w = -w;
    }
    template <class T, class... U>
    inline void readInt(T &w, U &... a) { readInt(w), readInt(a...); }
    
    constexpr int P(998244353), G(3);
    inline void inc(int &x, int y) { (x += y) >= P ? x -= P : 0; }
    inline int sum(int x, int y) { return x + y >= P ? x + y - P : x + y; }
    inline int sub(int x, int y) { return x - y < 0 ? x - y + P : x - y; }
    inline int fpow(int x, int k = P - 2) {
      int r = 1;
      for (; k; k >>= 1, x = 1LL * x * x % P)
        if (k & 1) r = 1LL * r * x % P;
      return r;
    }
    
    
    namespace Polynomial {
    using Polynom = std::vector<int>;
    int n;
    std::vector<int> w;
    void getOmega(int k) {
      w.resize(k);
      w[0] = 1;
      int base = fpow(G, (P - 1) / (k << 1));
      for (int i = 1; i < k; i++) w[i] = 1LL * w[i - 1] * base % P;
    }
    void dft(Polynom &a) {
      for (int k = n >> 1; k; k >>= 1) {
        getOmega(k);
        for (int i = 0; i < n; i += k << 1) {
          for (int j = 0; j < k; j++) {
            int y = a[i + j + k];
            a[i + j + k] = (1LL * a[i + j] - y + P) * w[j] % P;
            inc(a[i + j], y);
          }
        }
      }
    }
    void idft(Polynom &a) {
      for (int k = 1; k < n; k <<= 1) {
        getOmega(k);
        for (int i = 0; i < n; i += k << 1) {
          for (int j = 0; j < k; j++) {
            int x = a[i + j], y = 1LL * a[i + j + k] * w[j] % P;
            a[i + j] = sum(x, y);
            a[i + j + k] = sub(x, y);
          }
        }
      }
      int inv = fpow(n);
      for (int i = 0; i < n; i++) a[i] = 1LL * a[i] * inv % P;
      std::reverse(a.begin() + 1, a.end());
    }
    } // namespace Polynom
    using Polynomial::dft;
    using Polynomial::idft;
    void poly_multiply(unsigned *A, int n, unsigned *B, int m, unsigned *C) {
      int k = Polynomial::n = 1 << std::__lg(n + m) + 1;
      std::vector<int> a(k), b(k);
      for (int i = 0; i <= n; i++) a[i] = A[i];
      for (int i = 0; i <= m; i++) b[i] = B[i];
      dft(a), dft(b);
      for (int i = 0; i < k; i++) a[i] = 1LL * a[i] * b[i] % P;
      idft(a);
      for (int i = 0; i <= n + m; i++) C[i] = a[i];
    }
    int main() {
      int n, m, k;
      readInt(n, m);
      Polynomial::n = k = 1 << std::__lg(n + m) + 1;
      std::vector<int> a(k), b(k);
      for (int i = 0; i <= n; i++) readInt(a[i]);
      for (int i = 0; i <= m; i++) readInt(b[i]);
      dft(a), dft(b);
      for (int i = 0; i < k; i++) a[i] = 1LL * a[i] * b[i] % P;
      idft(a);
      for (int i = 0; i <= n + m; i++) printf("%d ", a[i]);
      return 0;
    }
    
    

    Bluestein’s Algorithm

    上面提到的 FFT 算法虽然限制了 (n) 为 2 的次幂,但在大多数情况下已经足够解决问题。

    对于更一般的 (n) 需要用到 Bluestein’s Algorithm ,可以参考2016年国家集训队论文《再探快速傅里叶变换——毛啸》。

    后续可能会填这个坑。

  • 相关阅读:
    tlb、tlh和tli文件的关系
    String算法
    Reverse A String by STL string
    windows内存管理复习(加深了理解得很!)
    [转载]有关DLL中New和外部Delete以以及跨DLL传递对象的若干问题
    顺势工作时间
    C++箴言:绝不在构造或析构期调用虚函数
    inline函数复习
    从编译器的角度更加深入考虑封装的使用
    复习:constructor和destructor的compiler实现
  • 原文地址:https://www.cnblogs.com/HolyK/p/13991949.html
Copyright © 2011-2022 走看看