zoukankan      html  css  js  c++  java
  • 多项式求逆学习笔记

    用法及推导

    这个主要用于在满足多项式(~A(x) * B(x) = C(x)~)且已知(~A(x), C(x)~)时来求多项式(~B(x)~)。可知(~B(x) = C(x) * A ^{-1}(x)~),其中(~A ^ {-1}(x))(~A(x)~)在模(~x ^ n~)意义下的逆元。
    考虑如何来求这个(~A^ {-1}(x)~).显然当(~n = 1~)时,(~A(x)~)中只有一个元素,根据费马小定理,(~A ^ {-1}(x)~)就是(~A(0) ^ {mod - 2}~),现在知道了这一点,我们再来推导,设(~B(x) = A ^ {-1}(x)~​),则有

    [A(x) * B(x) equiv 1 ~(mod ~x ^ n) ]

    而如果我们知道了(~A(x)~)在模(~x ^ {frac{n}{2}}~)意义下的逆元为(~B'(x)~),那么(B(x)~)在此意义下也成立,有

    [A(x) * B'(x) equiv 1 ~(mod~{x ^ {frac{n}{2}}}) ]

    [A(x) * B(x) equiv 1 ~(mod~{x ^ {frac{n}{2}}}) ]

    两式相减得

    [A(x) * (B(x) - B'(x)) equiv 0 ~(mod~{x ^ {frac{n}{2}}}) Leftrightarrow B(x) - B'(x) equiv 0 ~(mod~ x ^ {frac{n}{2}}) ]

    两边同乘平方,得

    [B^2(x) - (2B*B')(x) + B' ^ 2(x) equiv 0~(mod~ x ^ {frac{n}{2}}) ]

    再同乘个(~A(x)~),可以消掉所有的(~B(x)~),得

    [B(x) - 2B'(x) + (A * B'^2)(x)equiv 0 ~(mod~ x ^ {frac{n}{2}}) ]

    移项可得

    [B(x) equiv 2B'(x) - (A * B'^2)(x)~(mod~ x ^ {frac{n}{2}}) ]

    到这里式子就推完了。其实很好推也很好记。观察这个式子,可以发现(~B(x)~)是由(~A(x)~)(~B'(x)~)计算而来的。那么考虑递归,在回溯的时候往前代计算答案就行了,代码这样写。

    inline void Inv(int *a, int *b, int len) { // b is the inv of a
    
        if (len == 1) { b[0] = qpow(a[0], mod - 2); return; }
    
        Inv(a, b, len >> 1);
        
        bit = 0; for (siz = 1; siz <= len; siz <<= 1) ++ bit;
        For(i, 0, siz - 1) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
    
        For(i, 0, len - 1) p[i] = a[i], q[i] = b[i];
        
        NTT(p, 1), NTT(q, 1);
        For(i, 0, siz - 1) p[i] = 1ll * q[i] * q[i] % mod * p[i] % mod;
        NTT(p, -1);
        
        For(i, 0, len - 1) b[i] = add(2 * b[i] % mod, mod - p[i]);
    }
    

    完整代码洛谷模板题

    #include<bits/stdc++.h>
    #define For(i, j, k) for (int i = j; i <= k; ++i)
    #define Forr(i, j, k) for (int i = j; i >= k; --i)
    using namespace std;
    
    inline int read() {
        int x = 0, p = 1; char c = getchar();
        while(!isdigit(c)) { if(c == '-') p = -1; c = getchar(); }
        while(isdigit(c)) x = (x << 1) + (x << 3) + (c ^ 48), c = getchar();
        return x *= p;
    }
    
    inline void File() {
        freopen("luogu4238.in", "r", stdin);
        freopen("luogu4238.out", "w", stdout);
    }
    
    const int N = 4e5 + 10, mod = 998244353;
    int powg[N], invg[N], a[N], b[N], rev[N], bit, siz, n;
    int p[N], q[N];
    
    inline int qpow(int a, int b) {
        static int res;
        for (res = 1; b; b >>= 1, a = 1ll * a * a % mod)
            if (b & 1) res = 1ll * res * a % mod;
        return res;
    }
    
    inline int add(int x, int y) { return (x += y) >= mod ? x -= mod : x; }
    
    inline void NTT(int *a, int flag) {
        For(i, 0, siz - 1) if (rev[i] > i) swap(a[rev[i]], a[i]);	
    
        for (int i = 2; i <= siz; i <<= 1) {
            int wn = flag > 0 ? powg[i] : invg[i];
    
            for (int j = 0; j < siz; j += i) {
                int w = 1;
                for (int k = 0; k < i >> 1; w = 1ll * w * wn % mod, ++ k) {
                    int x = a[j + k], y = 1ll * w * a[j + k + (i >> 1)] % mod;
                    a[k + j] = add(x, y), a[k + j + (i >> 1)] = add(x, mod - y);
                }
            }
        }
    
        if (flag == -1) {
            int g = qpow(siz, mod - 2);
            For(i, 0, siz - 1) a[i] = 1ll * a[i] * g % mod;
        }
    }
    
    inline void Inv(int *a, int *b, int len) {
    
        if (len == 1) { b[0] = qpow(a[0], mod - 2); return; }
    
        Inv(a, b, len >> 1);
        
        bit = 0; for (siz = 1; siz <= len; siz <<= 1) ++ bit;
        For(i, 0, siz - 1) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
    
        For(i, 0, len - 1) p[i] = a[i], q[i] = b[i];
        
        NTT(p, 1), NTT(q, 1);
        For(i, 0, siz - 1) p[i] = 1ll * q[i] * q[i] % mod * p[i] % mod;
        NTT(p, -1);
        
        For(i, 0, len - 1) b[i] = add(2 * b[i] % mod, mod - p[i]);
    }
    
    int main() {
        n = read() - 1;
        For(i, 0, n) a[i] = read();
        
        for (siz = 1; siz <= n << 1; siz <<= 1) ++ bit;
    
        int g = qpow(3, mod - 2);
        for (int i = 1; i <= siz; i <<= 1) {
            powg[i] = qpow(3, (mod - 1) / i);
            invg[i] = qpow(g, (mod - 1) / i);
        }
    
        int len = siz >> 1;
        Inv(a, b, len);	
    
        For(i, 0, n) printf("%d ", b[i]);
    
        return 0;
    }
    
    
  • 相关阅读:
    线程应用示例
    Microsoft Visual Studio 2005 BETA2最新资源大杂烩
    135,139,445端口的关闭方法
    开源软件新时代 55个经典开源Windows工具
    图书商城项目总论
    无处不在的XML
    ADO.NET实例教学一
    递归
    手写代码生成器
    数据库的应用详解三
  • 原文地址:https://www.cnblogs.com/LSTete/p/9575570.html
Copyright © 2011-2022 走看看