zoukankan      html  css  js  c++  java
  • Fast Fourier Transform

    Introduction

    快速傅里叶变换(Fast Fourier Transform,FFT)是一种可在 (O(n log n)) 时间内完成的离散傅里叶变换 (Discrete Fourier Transform,DFT) 的算法,用来实现将信号从原始域(通常是时间或空间)到频域的互相转化。

    FFT 在算法竞赛中主要用来加速多项式乘法(循环卷积)。

    多项式

    形如

    [A(x) = a_0 + a_1x + a_2x^2 + dots + a_{n-1}x^{n - 1} ]

    的式子称为 (x)(n - 1) 次多项式,其中 (a_0, a_1, dots, a_{n - 1}) 称为多项式系数,(n-1) 称为多项式的次数,记为 (deg A(x))(deg A)

    点值

    (n - 1) 次多项式 (A(x))(x = m) 处的点值

    [A(m) = sum_{k=0}^{n-1} a_km^k ]

    多项式乘法

    (A(x) imes B(x)) 表示多项式 (A(x), B(x)) 做多项式乘法,可以简写为 (A(x)cdot B(x))(A(x)B(x))

    多项式乘法

    [egin{aligned} C(x) = &A(x) imes B(x)\ = &left(a_0 + a_1x + dots + a_{deg A}x^{deg A} ight)cdotleft(b_0 + b_1x + dots + b_{deg B}x^{deg B} ight)\ = &sum_{r = 0}^{deg A + deg B} sum_{k = 0}^r a_kb_{r - k} x^r end{aligned} ]

    用系数关系可以表示为

    [c_r = sum_{k = 0}^ra_kb_{r - k} ]

    其中 (deg C = deg A + deg B)

    易证它们的点值满足如下关系

    [C(m) = A(m)B(m) ]

    循环卷积

    (operatorname{conv}(A, B, n)) 表示多项式 (A(x), B(x)) 做长度为 (n) 的循环卷积。

    循环卷积

    [C(x) = operatorname{conv}(A, B, n) ]

    系数关系表示为

    [c_k = sum_{p, q}[(p + q) mod n = k]a_pb_q ]

    其中 (deg C = n - 1)

    容易发现,当 (n > deg A + deg B) 时,该运算等价于多项式乘法。

    DFT

    离散傅里叶变换(Discrete Fourier Transform, DFT) 将多项式 (A(x)=sum_{k=0}^{n-1}a_kx^k) 转换为一些特殊的点值。

    (n) 次单位复根

    [omega_n = e^{frac{2ipi}n}=cosdfrac{2pi}{n}+isindfrac{2pi}{n} ]

    (DFT(A)) 就是要计算点值 (A(omega_n^k), k = 0, 1, 2, dots, n-1)

    单位根自带的循环特性使得循环卷积 (C(x) = operatorname{conv}(A, B, n)) 的点值也满足:

    [C(omega_n^k) = A(omega_n^k)B(omega_n^k) ]

    IDFT

    IDFT 是 DFT 的逆变换。

    首先,用等比数列求和易证:

    [egin{align*} frac1nsum_{k = 0}^{n - 1}omega_n^{vk} &= [v mod n = 0] end{align*} ]

    考虑循环卷积 (C(x) = operatorname{conv}(A, B, n)) 的系数表示

    [egin{align*} c_r = &sum_{p, q}[(p + q) mod n = r]a_pb_q\ = &sum_{p, q}[(p + q - r) mod n = 0]a_pb_q\ = &sum_{p, q}frac1nsum_{k = 0}^{n - 1}omega_n^{pk+qk-rk}a_pb_q\ = &sum_{p, q}frac1nsum_{k = 0}^{n - 1}omega_n^{-rk}cdotomega_n^{pk}a_pcdotomega_n^{qk}b_q\ = &frac1nsum_{k = 0}^{n - 1}omega_n^{-rk}left(sum_{p}omega_n^{pk}a_psum_qomega_n^{qk}b_q ight)\ = &frac1nsum_{k = 0}^{n - 1}left(omega_n^{-r} ight)^kA(omega_n^k)B(omega_n^k)\ = &frac1nsum_{k = 0}^{n - 1}left(omega_n^{n-r} ight)^kC(omega_n^k) end{align*} ]

    设多项式

    [C'(x) = sum_{k=0}^{n-1}C(omega_n^k)x^k ]

    只要计算 (DFT(C')) 即可得到 (C(x)) 的系数,于是我们用 DFT 完成了逆变换 IDFT。

    用两次 DFT 和一次 IDFT就可以计算 (operatorname{conv}(A, B, n))

    暴力的复杂度是 (O(n^2)),此处不赘述。

    FFT

    现在尝试将 DFT 问题分解以优化时间复杂度。

    本部分认为 (n = deg A + 1)(2) 的整数次幂。对于更一般的情况,暂不考虑。

    DIF

    将序列 (a_i) 分成左右两半

    [egin{align*} A(omega_n^{r}) &= sum_{k = 0}^{n-1}a_komega_n^{rk}\ &= sum_{k = 0}^{n / 2 - 1} left(a_kcdotomega_n^{rk} + a_{k+n/2}cdotomega_n^{rk+rn/2} ight)\ &= sum_{k = 0}^{n / 2 - 1} left[a_kcdotomega_n^{rk} + (-1)^rcdot a_{k+n/2}cdotomega_n^{rk} ight]\ &= sum_{k = 0}^{n / 2 - 1} left[a_k+(-1)^ra_{k+n/2} ight]omega_{n}^{rk} end{align*} ]

    进一步,将 (A(omega_{n}^r)) 按奇偶分类

    [egin{align*} Aleft(omega_n^{2r} ight) &= sum_{k=0}^{n/2-1}left(a_k+a_{k+n/2} ight)omega_{n/2}^{rk}\ Aleft(omega_n^{2r+1} ight) &= sum_{k=0}^{n/2-1}left(omega_{n}^ka_k-omega_{n}^ka_{k+n/2} ight)omega_{n/2}^{rk} end{align*} ]

    [egin{align*} &p_k=a_k+a_{k+n/2}, &P(x) = sum_{k = 0}^{n/2-1}p_kx^k\ &q_k=omega_{n}^k(a_k-a_{k+n/2}), &Q(x) = sum_{k=0}^{n/2-1}q_kx^k end{align*} ]

    我们只需要求出 (P(omega_{n/2}^r))(Q(omega_{n/2}^r)) ,即求解规模为原来一半的两个子问题 (DFT(P), DFT(Q)),就能在 (O(n)) 时间内计算出 (DFT(A))

    DIT

    在算法竞赛中这种方法更常见。

    注意到在 DIF 中我们最后将 (A(omega_n^r)) 奇偶分类求解,那不妨思考将序列 (a_k) 按奇偶分类

    [A_0(x) = a_0 + a_2x + dots + a_{n - 2}x^{n / 2}\ A_1(x) = a_1 + a_3x+ dots + a_{n - 1}x^{n / 2} ]

    [A(x) = A_0(x^2) + xA_1(x^2) ]

    所以

    [egin{align*} A(omega_n^k) &= A_0(omega_n^{2k}) + omega_n^kA_1(omega_n^{2k})\ &= A_0(omega_{n/2}^k) + omega_n^kA_1(omega_{n/2}^k) end{align*} ]

    (A(omega_n^k)) 再细分为左右两半,这里运用了等式 (omega_{n/2}^k = omega_{n/2}^{k + n/2})(omega_n^k+omega_n{k+n/2} = 0) :

    [egin{align*} A(omega_n^k) &= A_0(omega_{n/2}^k) + omega_n^kA_1(omega_{n/2}^k)\ Aleft(omega_n^{k+n/2} ight) &= A_0(omega_{n/2}^k) - omega_n^kA_1(omega_{n/2}^k) end{align*} ]

    我们只需要求出 (A_0(omega_{n/2}^k))(A_1(omega_{n/2}^k)) ,即求解规模为原来一半的两个子问题 (DFT(A_0), DFT(A_1)),就能在 (O(n)) 时间内计算出 (DFT(A))

    Complexity

    设次数为 (n - 1) 的多项式做 DFT 的时间复杂度为 (T(n)),则

    [T(n) = 2T(frac{n}{2}) + O(n) ]

    根据主定理

    [T(n) = O(n log n) ]

    Implementation

    上述两种计算方式均可以使用递归实现,这里直接给出代码,不再赘述。
    DIF

    const double PI = acos(-1.0);
    void dft(std::vector<Complex> &a) {
      int n = a.size(), m = n >> 1;
      if (n == 1) return;
      std::vector<Complex> p(m), q(m);
      for (int i = 0; i < m; i++) {
        p[i] = a[i] + a[i + m];
        q[i] = (a[i] - a[i + m]) * Complex(cos(2 * PI * i / n), sin(2 * PI * i / n));
      }
      dft(p), dft(q);
      for (int i = 0; i < m; i++)
        a[i << 1] = p[i], a[i << 1 | 1] = q[i];
    }
    void idft(std::vector<Complex> &a) {
      dft(a);
      for (auto &v: a) v.a /= a.size(), v.b /= a.size();
      std::reverse(a.begin() + 1, a.end());
    }
    

    DIT

    const double PI = acos(-1.0);
    void dft(std::vector<Complex> &a) {
      int n = a.size(), m = n >> 1;
      if (n == 1) return;
      std::vector<Complex> p(m), q(m);
      for (int i = 0; i < m; i++) {
        p[i] = a[i << 1];
        q[i] = a[i << 1 | 1];
      }
      dft(p), dft(q);
      for (int i = 0; i < m; i++) {
        Complex &u = p[i], v = Complex(cos(2 * PI * i / n), sin(2 * PI * i / n)) * q[i];
        a[i] = u + v, a[i + m] = u - v;
      }
    }
    void idft(std::vector<Complex> &a) {
      dft(a);
      for (auto &v: a) v.a /= a.size(), v.b /= a.size();
      std::reverse(a.begin() + 1, a.end());
    }
    

    下面探讨以 非递归方式 实现 DIFDIT

    DIT

    由于 DIT 更易于理解(其实只是资料多),先讲这个。

    考虑递归的过程:

    [egin{align*} &0. (a_0, a_1, a_2, a_3, a_4, a_5, a_6, a_7)\ &1. (a_0, a_2, a_4, a_6)(a_1, a_3, a_5, a_7)\ &2. (a_0, a_4)(a_2, a_6)(a_1, a_5)(a_3, a_7)\ &3. (a_0)(a_4)(a_2)(a_6)(a_1)(a_5)(a_3)(a_7)\ &2' (A_{00}, A_{01})(A_{10},A_{11})(A_{00},A_{01})(A_{10}, A_{11})\ &1' (A_{00}, A_{01}, A_{02}, A_{03})(A_{10}, A_{11}, A_{12}, A_{13})\ &0' (A_0, A_1, A_2, A_3, A_4, A_5, A_6, A_7) end{align*} ]

    发现 (0 ightarrow 3) 只是在重新安排数据位置,并没有修改数据,如果我们能把映射关系找到,那就可以一步到位,直接从 (3) 开始。

    设一个数在第 (i) 个阶段 ((0 leq i leq log_2n)) 的位置为 (p_i),相对位置为 (p_i')相对位置指它在括号里的位置,例如上面第 (1) 阶段 (a_1) 的相对位置为 (0))。

    容易发现

    [p'_i = p_i mod frac{n}{2^i}\ p_{i + 1} = p_{i} - p'_i + egin{cases} dfrac{n}{2^{i+1}}+dfrac{p'_i-1}{2} & p_i' equiv 1 pmod 2\ dfrac{p'_i}{2} & p_i' equiv 0 pmod 2 end{cases} ]

    如果将 (p_0) 写成二进制 (overline{b_4b_3b_2b_1b_0})(这里以 (n = 32) 为例),那么单次变化的过程相当于把二进制的后几位向右 rotate 一位,总的变化过程可以描述为:

    [overline{|b_4b_3b_2b_1b_0} ightarrow overline{|b_0b_4b_3b_2b_1} \ overline{b_0|b_4b_3b_2b_1} ightarrow overline{b_0|b_1b_4b_3b_2} \ overline{b_0b_1|b_4b_3b_2} ightarrow overline{b_0b_1|b_2b_4b_3} \ overline{b_0b_1b_2|b_4b_3} ightarrow overline{b_0b_1b_2|b_3b_4} \ overline{b_0b_1b_2b_3|b_4} ightarrow overline{b_0b_1b_2b_3|b_4} ]

    可以发现,整个过程实际上是在做 reverse 操作!

    至此我们找到了映射关系,成功把前面的步骤都砍掉了,只剩回溯,可以改成循环。

    void dft(std::vector<Complex> &a) {
      int n = a.size();
      for (int i = 0, j = 0; i < n; i++) {
        if (i > j) std::swap(a[i], a[j]);
        for (int k = n >> 1; (j ^= k) < k; k >>= 1);
      }
      for (int k = 1; k < n; k <<= 1) {
        for (int i = 0; i < n; i += k << 1) {
          for (int j = 0; j < k; j++) {
            auto t = a[i + j + k] * Complex(cos(PI * j / k), sin(PI * j / k));
            a[i + j + k] = a[i + j] - t;
            a[i + j] = a[i + j] + t;
          }
        }
      }
    }
    

    DIF

    这里先将递归版DIF的过程简单复述:

    第一步:将序列 (a) 对半分

    [egin{align*} &p_k=a_k+a_{k+n/2}, &P(x) = sum_{k = 0}^{n/2-1}p_kx^k\ &q_k=omega_{n}^k(a_k-a_{k+n/2}), &Q(x) = sum_{k=0}^{n/2-1}q_kx^k end{align*} ]

    第二步:递归计算 (DFT(p), DFT(q))
    第三步:重新安排数据位置

    [egin{align*} Aleft(omega_n^{2r} ight) &= P(omega_{n/2}^r)\ Aleft(omega_n^{2r+1} ight) &= Q(omega_{n/2}^r) end{align*} ]

    发现回溯的过程(即第三步)实际上也只是在重新安排数据存储的位置,而且是上面 (DIT) 第一步的逆过程,所以就是位翻转的逆过程,所以还是位翻转。

    所以最后安排数据位置可以一步搞定,只剩递归压栈的过程,可以改成循环。

    void dft(vector<Complex> &a) {
      int n = a.size();
      for (int k = n >> 1; k; k >>= 1) {
        for (int i = 0; i < n; i += k << 1) {
          for (int j = 0; j < k; j++) {
            auto t = a[i + j + k];
            a[i + j + k] = (a[i + j] - t) * Complex(cos(PI * j / k), sin(PI * j / k));
            a[i + j] = a[i + j] + t;
          }
        }
      }
      for (int i = 0, j = 0; i < n; i++) {
        if (i > j) std::swap(a[i], a[j]);
        for (int k = n >> 1; (j ^= k) < k; k >>= 1);
      }
    }
    

    Combination

    发现 (DIF) 的最后一步和 (DIT) 的第一步都是位翻转,所以先 (DIF)(DIT),就可以省略位翻转。

    完整代码

    #include <bits/stdc++.h>
    
    template <class T>
    inline void readInt(T &w) {
      char c, p = 0;
      while (!isdigit(c = getchar())) p = c == '-';
      for (w = c & 15; isdigit(c = getchar());) w = w * 10 + (c & 15);
      if (p) w = -w;
    }
    
    struct Complex {
      double a, b; // a + bi
      Complex(double a = 0, double b = 0): a(a), b(b) {}
    };
    inline Complex operator+(const Complex &p, const Complex &q) {
      return Complex(p.a + q.a, p.b + q.b);
    }
    inline Complex operator-(const Complex &p, const Complex &q) {
      return Complex(p.a - q.a, p.b - q.b);
    }
    inline Complex operator*(const Complex &p, const Complex &q) {
      return Complex(p.a * q.a - p.b * q.b, p.a * q.b + p.b * q.a);
    }
    
    const double PI = acos(-1.0);
    void dft(std::vector<Complex> &a) {
      int n = a.size();
      for (int k = n >> 1; k; k >>= 1) {
        for (int i = 0; i < n; i += k << 1) {
          for (int j = 0; j < k; j++) {
            auto t = a[i + j + k];
            a[i + j + k] = (a[i + j] - t) * Complex(cos(PI * j / k), sin(PI * j / k));
            a[i + j] = a[i + j] + t;
          }
        }
      }
    }
    void idft(std::vector<Complex> &a) {
      int n = a.size();
      for (int k = 1; k < n; k <<= 1) {
        for (int i = 0; i < n; i += k << 1) {
          for (int j = 0; j < k; j++) {
            auto t = a[i + j + k] * Complex(cos(PI * j / k), sin(PI * j / k));
            a[i + j + k] = a[i + j] - t;
            a[i + j] = a[i + j] + t;
          }
        }
      }
      for (auto &v: a) v.a /= a.size(), v.b /= a.size();
      std::reverse(a.begin() + 1, a.end());
    }
    int main() {
      int n, m, k;
      readInt(n), readInt(m);
      k = 1 << std::__lg(n + m) + 1;
      std::vector<Complex> a(k), b(k), c(k);
      for (int i = 0; i <= n; i++) readInt(a[i].a);
      for (int i = 0; i <= m; i++) readInt(b[i].a);
      dft(a), dft(b);
      for (int i = 0; i < k; i++) c[i] = a[i] * b[i];
      idft(c);
      for (int i = 0; i <= n + m; i++) printf("%d ", (int)(c[i].a + 0.5));
      return 0;
    }
    

    Number Theory Transform

    如果一个质数存在 (2^n) 次单位根(其中 (n) 最大时的单位根称为原根),那么在这个质数的剩余系下上面的结论依旧成立,可以使用FFT,多称这种FFT为快速数论变换(Number Theory Transform, NTT)

    常见的质数是 (P = 998244353),它的原根 (g = 3)

    代码(预处理单位根,较好地平衡了代码复杂度和常数且有一定的封装度):

    #include <bits/stdc++.h>
    template <class T>
    inline void readInt(T &w) {
      char c, p = 0;
      while (!isdigit(c = getchar())) p = c == '-';
      for (w = c & 15; isdigit(c = getchar());) w = w * 10 + (c & 15);
      if (p) w = -w;
    }
    template <class T, class... U>
    inline void readInt(T &w, U &... a) { readInt(w), readInt(a...); }
    
    constexpr int P(998244353), G(3);
    inline void inc(int &x, int y) { (x += y) >= P ? x -= P : 0; }
    inline int sum(int x, int y) { return x + y >= P ? x + y - P : x + y; }
    inline int sub(int x, int y) { return x - y < 0 ? x - y + P : x - y; }
    inline int fpow(int x, int k = P - 2) {
      int r = 1;
      for (; k; k >>= 1, x = 1LL * x * x % P)
        if (k & 1) r = 1LL * r * x % P;
      return r;
    }
    
    
    namespace Polynomial {
    using Polynom = std::vector<int>;
    int n;
    std::vector<int> w;
    void getOmega(int k) {
      w.resize(k);
      w[0] = 1;
      int base = fpow(G, (P - 1) / (k << 1));
      for (int i = 1; i < k; i++) w[i] = 1LL * w[i - 1] * base % P;
    }
    void dft(Polynom &a) {
      for (int k = n >> 1; k; k >>= 1) {
        getOmega(k);
        for (int i = 0; i < n; i += k << 1) {
          for (int j = 0; j < k; j++) {
            int y = a[i + j + k];
            a[i + j + k] = (1LL * a[i + j] - y + P) * w[j] % P;
            inc(a[i + j], y);
          }
        }
      }
    }
    void idft(Polynom &a) {
      for (int k = 1; k < n; k <<= 1) {
        getOmega(k);
        for (int i = 0; i < n; i += k << 1) {
          for (int j = 0; j < k; j++) {
            int x = a[i + j], y = 1LL * a[i + j + k] * w[j] % P;
            a[i + j] = sum(x, y);
            a[i + j + k] = sub(x, y);
          }
        }
      }
      int inv = fpow(n);
      for (int i = 0; i < n; i++) a[i] = 1LL * a[i] * inv % P;
      std::reverse(a.begin() + 1, a.end());
    }
    } // namespace Polynom
    using Polynomial::dft;
    using Polynomial::idft;
    void poly_multiply(unsigned *A, int n, unsigned *B, int m, unsigned *C) {
      int k = Polynomial::n = 1 << std::__lg(n + m) + 1;
      std::vector<int> a(k), b(k);
      for (int i = 0; i <= n; i++) a[i] = A[i];
      for (int i = 0; i <= m; i++) b[i] = B[i];
      dft(a), dft(b);
      for (int i = 0; i < k; i++) a[i] = 1LL * a[i] * b[i] % P;
      idft(a);
      for (int i = 0; i <= n + m; i++) C[i] = a[i];
    }
    int main() {
      int n, m, k;
      readInt(n, m);
      Polynomial::n = k = 1 << std::__lg(n + m) + 1;
      std::vector<int> a(k), b(k);
      for (int i = 0; i <= n; i++) readInt(a[i]);
      for (int i = 0; i <= m; i++) readInt(b[i]);
      dft(a), dft(b);
      for (int i = 0; i < k; i++) a[i] = 1LL * a[i] * b[i] % P;
      idft(a);
      for (int i = 0; i <= n + m; i++) printf("%d ", a[i]);
      return 0;
    }
    
    

    Bluestein’s Algorithm

    上面提到的 FFT 算法虽然限制了 (n) 为 2 的次幂,但在大多数情况下已经足够解决问题。

    对于更一般的 (n) 需要用到 Bluestein’s Algorithm ,可以参考2016年国家集训队论文《再探快速傅里叶变换——毛啸》。

    后续可能会填这个坑。

  • 相关阅读:
    【Codeforces 349B】Color the Fence
    【Codeforces 459D】Pashmak and Parmida's problem
    【Codeforces 467C】George and Job
    【Codeforces 161D】Distance in Tree
    【Codeforces 522A】Reposts
    【Codeforces 225C】Barcode
    【Codeforces 446A】DZY Loves Sequences
    【Codeforces 429B】Working out
    【Codeforces 478C】Table Decorations
    【Codeforces 478C】Table Decorations
  • 原文地址:https://www.cnblogs.com/HolyK/p/13991949.html
Copyright © 2011-2022 走看看