zoukankan      html  css  js  c++  java
  • 快速构造FFT/NTT

    @(学习笔记)[FFT, NTT]

    问题概述

    给出两个次数为(n)的多项式(A)(B), 要求在(O(n log n))内求出它们的卷积, 即对于结果(C)的每一项, 都有$$c_i = sum_{j = 0}^{n}a_j cdot b_{i - j}$$

    问题求解

    大致思路

    • 朴素做法: 考虑按照上面的式子暴力运算, 时间复杂度: (O(n^2))
    • 考虑把多项式化作点值表达, 记$$A(x) =sum_{i = 0}^n a_i x^i$$ 我们把(A)(B)的点值表达乘起来, 得到的就是(C)的点值表达, 即$$A(x) cdot B(x) = C(x)$$
    • 我们把(x o A(x))的运算称作是DFT(离散傅立叶变换Discrete Fourier Transform)
    • 对于一个次数为(n)的多项式, 我们有它的(n)组不同点值表达, 通过点值表达求出原多项式的每一项的运算, 我们称之为IDFT(逆傅立叶变换)

    DFT

    考虑两个次数为(n)的多项式卷积, 得到的结果次数最高达到了(2n - 1). 所以我们至少需要(2n - 1)个结果的点值表达, 才足够把结果逆推出来(Hint: 为什么是(2n-1)个点值表达? 大体上可以从拉格朗日插值法来理解.).
    考虑如何化简运算.
    我们把多项式(A)拆分开奇数位和偶数位, 来计算它的点值表达. 我们令(x_k)为代入多项式计算的第(k)个值, 记$$f_0(x_k) = a_0 x_k^0 + a_2 x_k^1 + a_4 x_k^2 + ... + a_{2m} x_k^m$$

    [f_1(x_k) = a_1 x_k^1 + a_3 x_k^2 + a_5 x_k^3 + ... + a_{2m + 1} x_k^m ]

    则我们发现原多项式可以被表示作$$f(x_k) = f_0(x_k^2) + x_k cdot f_1(x_k^2)$$
    这样, 求原来长度为(len)的多项式的点值表达, 就变成求2个长度为(frac{len}{2})的多项式的点值表达.
    我们还注意到, 这里代入(f_0)(f_1)计算的值为(x_k^2). 假如我们代入的(x_i)(x_j)满足(x_i^2 = x_j^2)(x_i e x_j), 则只需要在(f_0)(f_1)中代入一个值进行运算, 再分别把(f_1)分别乘上(x_i)(x_j), 就可以一次处理出(f(x_i))(f(x_j))两个的结果. 这种优化手段就是FFT和NTT的基本思想.
    考虑如何构造(x_i^2 = x_j^2).
    这里我们以NTT为例. 在数论意义下, 根据费马小定理, 有$$g^{p - 1} equiv 1 mod p: p in 素数$$.
    当我们要代入(n)个值计算多项式的点值表达时, 令(x_0 = 1, x_1 = g^{frac{p - 1}{n}} ... x_k = g^{frac{p - 1}{n} cdot k}), 则有$$x_{k + frac{n}{2}}^2 = left( left(g^{frac{p - 1}{n}} ight)^{k + frac{n}{2}} ight)^2 = left( g^{frac{p - 1}{n} cdot k} ight)^2 cdot g^{p - 1} equiv left( g^{frac{p - 1}{n} cdot k} ight)^2 = x_k^2 mod p$$
    则每个(x_k)都可以与(x_{k + frac{n}{2}})分为一组, 一起计算.
    这样, 我们就可以在(O(n log n))内求出所需要的(n)个点值表达.

    IDFT

    我们把得到的点值表达看作是一个多项式, 再按照上面的DFT的做法搞一次, 得到这个点值表达的点值表达(大雾). 把每个点值表达都除以点值的个数, 即得到了(C)的每一项.
    不会证.
    结束.

    Code

    #include <cstdio>
    #include <cctype>
    #include <algorithm>
     
    const int N = (int)5e4, P = 998244353, G = 3;
     
    namespace Zeonfai
    {
        inline int getInt()
        {
            int sgn = 1, a = 0;
            char c;
             
            while(! isdigit(c = getchar()))
                if(c == '-')
                    sgn *= -1;
             
            while(isdigit(c))
                a = a * 10 + c - '0', c = getchar();
             
            return a * sgn;
        }
    }
     
    namespace convolution
    {
        const int DEG = N << 2;
        int deg, rev[DEG], omega[DEG], inv[DEG];
     
        inline int modPower(int a, int x)
        {
            int res = 1;
     
            for(; x; a = (long long)a * a % P, x >>= 1)
                if(x & 1)
                    res = (long long) res * a % P;
     
            return res;
        }
     
        inline void pretreat(int n, int m)
        {
            int sum = n + m;
            deg = 1;
            int bit = 0;
     
            for(; deg < sum; deg <<= 1, ++ bit);
     
            rev[0] = 0;
     
            for(int i = 1; i < deg; ++ i)
                rev[i] = rev[i >> 1] >> 1 | (i & 1) << bit - 1;
     
            for(int i = 0; 1 << i <= deg; ++ i)
                omega[i] = modPower(G, (P - 1) / (1 << i)), inv[i] = modPower(omega[i], P - 2);
        }
     
        inline void NTT(int *a, int opt)
        {
            for(int i = 0; i < deg; ++ i)
                if(rev[i] < i)
                    std::swap(a[i], a[rev[i]]);
     
            int cnt = 0;
     
            for(int i = 2; i <= deg; i <<= 1)
            {
                ++ cnt;
                int curOmega = ~ opt ? omega[cnt] : inv[cnt];
     
                for(int j = 0; j < deg; j += i)
                {
                    int omega = 1;
     
                    for(int k = j; k < j + i / 2; ++ k)
                    {
                        int u = a[k], t = (long long)omega * a[k + i / 2] % P;
                        a[k] = (u + t) % P, a[k + i / 2] = (u - t + P) % P;
                        omega = (long long)omega * curOmega % P;
                    }
                }
     
            }
     
            if(opt == -1)
            {
                int inv = modPower(deg, P - 2);
     
                for(int i = 0; i < deg; ++ i)
                    a[i] = (long long)a[i] * inv % P;
            }
        }
     
        inline void work(int *a, int n, int *b, int m)
        {
            pretreat(n, m);
            NTT(a, 1), NTT(b, 1);
     
            for(int i = 0; i < deg; ++ i)
                a[i] = (long long)a[i] * b[i] % P;
     
            NTT(a, -1);
     
            for(int i = 0; i <= n + m; ++ i)
                printf("%d ", a[i]);
        }
    }
     
    int main()
    {
        #ifndef ONLINE_JUDGE
        freopen("polynomial.in", "r", stdin);
        freopen("polynomial.out", "w", stdout);
        #endif
     
        using namespace Zeonfai;
        int n = getInt(), m = getInt(), tp = getInt();
        static int a[N << 2], b[N << 2];
         
        for(int i = 0; i <= n; ++ i)
            a[i] = getInt();
         
        for(int i = 0; i <= m; ++ i)
            b[i] = getInt();
     
        convolution::work(a, n, b, m);
    }
    
  • 相关阅读:
    汉字词组换行
    C#中获取Excel文件的第一个表名
    SQL查找某一条记录的方法
    C#数据库连接字符大全
    整理的asp.net资料!(不得不收藏)
    母版页的优点,及母版页与内容页中相互访问方法
    13范式
    使用 Jackson 树连接线形状
    word2007,取消显示回车符
    三张表之间相互的多对多关系
  • 原文地址:https://www.cnblogs.com/ZeonfaiHo/p/6790694.html
Copyright © 2011-2022 走看看