zoukankan      html  css  js  c++  java
  • 多项式学习笔记(三): 多项式全家桶

    1.多项式求逆

    给你 (A(x))(A(x)B(x) equiv 1 pmod {x^n}) 。 (模 (x^n) 是为了把高次项舍掉)

    假设我们已经得到了满足 (C(x)A(x) equiv 1 pmod {x^{nover 2}}) 的一个多项式 (C)

    那么由题意可得 (A(x)B(x)equiv 1 pmod {x^{nover 2}})

    两式联立可得:

    (B(x) equiv C(x) pmod {x^{nover 2}})

    (B(x) - C(x) equiv 0 pmod {x^{nover 2}})

    两边同时平方可得:

    (B^2(x) + C^2(x) - 2B(x)C(x) equiv 0 pmod {x^{n}})

    在同时乘上一个 (A(x)) 得:

    (A(x)B^2(x) + A(x)C^2(x)-2A(x)B(x)C(x)equiv 0 pmod {x^{n}})

    然后由题意可得 (A(x)B(x)equiv 1 pmod {x^n}) ,代入化简可得:

    (B(x) + A(x)C^2(x)-2C(x) equiv 0 pmod {x^n})

    (B(x) = 2C(x) - A(x)C^2(x))

    然后,我们每次都可以把项数减半递归求解, 如果项数为 (1) 的话结果显然是零次项的逆元。

    复杂度 (T(n) = T({nover 2}) + nlogn = nlogn)

    Code

    #include<iostream>
    #include<cstdio>
    #include<algorithm>
    using namespace std;
    #define int long long
    const int N = 1e6+10;
    const int p = 998244353;
    int n,a[N],b[N],rev[N],c[N];
    inline int read()
    {
        int s = 0,w = 1; char ch = getchar();
        while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
        while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
        return s * w;
    }
    int ksm(int a,int b)
    {
        int res = 1;
        for(; b; b >>= 1)
        {
            if(b & 1) res = res * a % p;
            a = a * a % p;
        }
        return res;
    }
    void NTT(int *a,int len,int opt)
    {
        for(int i = 0; i < len; i++)
        {
            if(i < rev[i]) swap(a[i],a[rev[i]]);
        }
        for(int h = 1; h < len; h <<= 1)
        {
            int wn = ksm(3,(p-1)/(h<<1));
            if(opt == -1) wn = ksm(wn,p-2);
            for(int j = 0; j < len; j += (h<<1))
            {
                int w = 1;
                for(int k = 0; k < h; k++)
                {
                    int u = a[j + k];
                    int v = w * a[j + h + k] % p;
                    a[j + k] = (u + v) % p;
                    a[j + h + k] = (u - v + p) % p;
                    w = w * wn % p;
                }
            }
        }
        if(opt == -1)
        {
            int inv = ksm(len,p-2);
            for(int i = 0; i < len; i++) a[i] = a[i] * inv % p;
        }
    }
    void Inv(int n,int *a,int *b)//求 A(x)B(x) = 1 mod x^n
    {
        if(n == 1)//项数为1的情况
        {
            b[0] = ksm(a[0],p-2);
            return;
        }
        Inv((n+1)>>1,a,b);//递归求 C(x)
        int lim = 1, tim = 0;
        while(lim < (n<<1)) lim <<= 1, tim++;
        for(int i = 0; i < lim; i++)//预处理NTT的反转数组
        {
            rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
        }
        //注意,不能用 a 来做多项式乘法,因为如果拿 a 做了多项式乘法,那么 a 的值在递归过程中,就会发生改变。
        for(int i = 0; i < n; i++) c[i] = a[i];//把 a 赋给 c,用 c 来做多项式乘法
        for(int i = n; i < lim; i++) c[i] = 0;//多余的高次项舍去
        //此时的 B 数组存的是 B(x)A(x) = 1 mod x^{n/2},C数组存的是 A(x)
        NTT(c,lim,1); NTT(b,lim,1);//求 B 和 C 的点值表示法
        for(int i = 0; i < lim; i++) b[i] = (2 * b[i] % p - b[i] * b[i] % p * c[i] % p + p) % p;//计算 B的点值
        NTT(b,lim,-1);//把B转化为系数表示法
        for(int i = n; i < lim; i++) b[i] = 0;//高次项舍去 
    }
    signed main()
    {
        n = read();
        for(int i = 0; i < n; i++) a[i] = read();
        Inv(n,a,b);
        for(int i = 0; i < n; i++) printf("%lld ",b[i]);
        printf("
    ");
        return 0;
    }
    

    2.多项式开根

    (B^2(x) equiv A(x) pmod {x^n})

    假设,我们得到了满足 (C^2(x) equiv A(x) pmod {x^{nover 2}}) 的一个多项式 (C(x))

    又因为 (B^2(x) equiv A(x) pmod {x^{nover 2}})

    两式联立可得:

    (B^2(x) equiv C^2(x) pmod {x^{nover 2}})

    (B^2(x)-C^2(x) equiv 0 pmod {x^{nover 2}})

    两边同时平方可得:

    (B^4(x) + C^4(x) - 2B^2(x)C^2(x) equiv 0 pmod {x^n})

    两边同时加上 (4B^2(x)C^2(x)) 可得:

    (B^4(x) + C^4(x) + 2B^2(x)C^2(x) equiv 4B^2(x)C^2(x) pmod {x^n})

    ((B^2(x) + C^2(x))^2 equiv 4B^2(x)C^2(x) pmod {x^n})

    把右边的 (4C^2(x)) 除过去可得:

    ({(B^2(x) + C^2(x))^2 over 4C^2(x)}equiv B^2(x)pmod {x^n})

    (B(x) equiv {B^2(x) + C^2(x)over 2C(x)} pmod {x^n})

    又因为 (B^2(x) equiv A(x) pmod {x^n}) ,代入可得:

    (B(x) equiv {A(x) + C^2(x)over 2C(x)} pmod {x^n})

    还是像求逆一样每次项数减半,递归求解,当项数为 (1) 的时候答案为 (sqrt {常数项})

    多项式求逆加NTT即可。

    复杂度 (O(nlogn))

    Code(常数爆炸):

    #include<iostream>
    #include<cstdio>
    #include<algorithm>
    #include<cmath>
    using namespace std;
    #define int long long
    const int N = 1e6+10;
    const int p = 998244353;
    int n,a[N],b[N],c[N],d[N],rev[N];
    inline int read()
    {
        int s = 0,w = 1; char ch = getchar();
        while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
        while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
        return s * w;
    }
    int ksm(int a,int b)
    {
        int res = 1;
        for(; b; b >>= 1)
        {
            if(b & 1) res = res * a % p;
            a = a * a % p;
        }
        return res;
    }
    void NTT(int *a,int len,int opt)//NTT 板子
    {
        for(int i = 0; i < len; i++)
        {
            if(i < rev[i]) swap(a[i],a[rev[i]]);
        }
        for(int h = 1; h < len; h <<= 1)
        {
            int wn = ksm(3,(p-1)/(h<<1));
            if(opt == -1) wn = ksm(wn,p-2);
            for(int j = 0; j < len; j += (h<<1))
            {
                int w = 1;
                for(int k = 0; k < h; k++)
                {
                    int u = a[j + k];
                    int v = w * a[j + h + k] % p;
                    a[j + k] = (u + v) % p;
                    a[j + h + k] = (u - v + p) % p;
                    w = w * wn % p;
                }
            }
        }
        if(opt == -1)
        {
            int inv = ksm(len,p-2);
            for(int i = 0; i < len; i++) a[i] = a[i] * inv % p;
        }
    }
    void Inv(int n,int *a,int *b)//多项式求逆板子
    {
        if(n == 1)
        {
            b[0] = ksm(a[0],p-2);
            return;
        }
        Inv((n+1)>>1,a,b);
        int lim = 1, tim = 0;
        while(lim < (n<<1)) lim <<= 1, tim++;
        for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
        for(int i = 0; i < n; i++) c[i] = a[i];
        for(int i = n; i < lim; i++) c[i] = 0;
        NTT(c,lim,1); NTT(b,lim,1);
        for(int i = 0; i < lim; i++) b[i] = (2 * b[i] % p - b[i] * b[i] % p * c[i] % p + p) % p;
        NTT(b,lim,-1);
        for(int i = n; i < lim; i++) b[i] = 0;//记得清空
    }
    void sqrt(int n,int *a,int *b)
    {
        if(n == 1)//项数为 1的情况
        {
            b[0] = (int) sqrt(a[0]);
            return;
        } 
        sqrt((n+1)>>1,a,b);    
    	Inv(n,b,d);//这里求 mod x^n 下的逆元,而不是 mod x^lim 下的逆元 
        int lim = 1, tim = 0;
        while(lim < (n<<1)) lim <<= 1, tim++;
        for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
        for(int i = 0; i < n; i++) c[i] = a[i];//用c数组代替a来做多项式乘法
        for(int i = n; i < lim; i++) c[i] = 0;
        //这里 b 数组存的是 C^2(x) = A(x) mod x^{n/2}
        // c数组 存的是 A(x), d数组存的是 C(x) 的乘法逆
        NTT(b,lim,1); NTT(c,lim,1); NTT(d,lim,1);
        int inv2 = ksm(2,p-2);
        for(int i = 0; i < lim; i++) b[i] = (b[i] * b[i] % p + c[i] % p) * d[i] % p * inv2 % p;//根据柿子算出 B(x) 的点值
        NTT(b,lim,-1);//转换为系数表示法
        for(int i = n; i < lim; i++) b[i] = 0;   
        for(int i = 0; i < lim; i++) d[i] = 0;//多次调用要清空
    } 
    signed main()
    {
        n = read();
        for(int i = 0; i < n; i++) a[i] = read();
        sqrt(n,a,b);
        for(int i = 0; i < n; i++) printf("%lld ",b[i]);
        return 0;
    }
    

    3.多项式求导

    (A(x) = displaystylesum_{i=0}^{n} a_ix^i) , 则 (A^prime(x) = displaystylesum_{i=0}^{n} ia_{i}x^{i-1})

    void qiudao(int len,int *a,int *b)
    {
        for(int i = 1; i < len; i++) b[i-1] = i * a[i] % p;
        b[len-1] = 0;
    }
    

    5.多项式积分

    (A(x) = displaystylesum_{i=0}^{n}a_ix^i) ,则 (int A(x) = displaystylesum_{i=1}^{n} {a_iover i+1} x^{i+1})

    void jifen(int len,int *a,int *b)
    {
        for(int i = 1; i < len; i++) b[i] = a[i-1] * ksm(i,p-2) % p;
        b[0] = 0;
    }
    

    6.多项式 ln

    (B(x) equiv lnA(x) pmod {x^n})

    (F(x) = lnA(x)) ,则 对等式两边同时求导可得:

    (B^prime(x) equiv F^prime(x) pmod {x^n})

    根据复合函数求导公式 (f^prime(g(x)) = f^prime(g(x)) g^prime(x)) 可得:

    (B^prime(x) equiv {A^prime (x)over A(x)} pmod {x^n})

    先求出 (A(x)) 的导函数和乘法逆,在相乘得到 (B^prime(x)) ,最后在积分回去即可。

    多项式求逆,多项式求导,多项式积分,多项式乘法。

    复杂度 (O(nlogn))

    code

    #include<iostream>
    #include<cstdio>
    #include<algorithm>
    using namespace std;
    #define int long long
    const int p = 998244353;
    const int N = 1e6+10;
    int n,a[N],b[N],c[N],rev[N],A[N],B[N];
    inline int read()
    {
        int s = 0,w = 1; char ch = getchar();
        while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
        while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
        return s * w;
    }
    int ksm(int a,int b)
    {
        int res = 1;
        for(; b; b >>= 1)
        {
            if(b & 1) res = res * a % p;
            a = a * a % p;
        }
        return res;
    }
    void NTT(int *a,int len,int opt)
    {
        for(int i = 0; i < len; i++)
        {
            if(i < rev[i]) swap(a[i],a[rev[i]]);
        }
        for(int h = 1; h < len; h <<= 1)
        {
            int wn = ksm(3,(p-1)/(h<<1));
            if(opt == -1) wn = ksm(wn,p-2);
            for(int j = 0; j < len; j += (h<<1))
            {
                int w = 1;
                for(int k = 0; k < h; k++)
                {
                    int u = a[j + k];
                    int v = w * a[j + h + k] % p;
                    a[j + k] = (u + v) % p;
                    a[j + h + k] = (u - v + p) % p;
                    w = w * wn % p;
                }
            }
        }
        if(opt == -1)
        {
            int inv = ksm(len,p-2);
            for(int i = 0; i < len; i++) a[i] = a[i] * inv % p;
        }
    }
    void Inv(int n,int *a,int *b)
    {
        if(n == 1)
        {
            b[0] = ksm(a[0],p-2);
            return;
        }
        Inv((n+1)>>1,a,b);
        int lim = 1, tim = 0;
        while(lim < (n<<1)) lim <<= 1, tim++;
        for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
        for(int i = 0; i < n; i++) c[i] = a[i];
        for(int i = n; i < lim; i++) c[i] = 0;
        NTT(b,lim,1); NTT(c,lim,1);
        for(int i = 0; i < lim; i++) b[i] = (2 * b[i] % p - b[i] * b[i] % p * c[i] % p + p) % p;
        NTT(b,lim,-1);
        for(int i = n; i < lim; i++) b[i] = 0;
    }
    void qiudao(int len,int *a,int *b)
    {
        for(int i = 1; i < len; i++) b[i-1] = i * a[i] % p;
        b[len-1] = 0;
    }
    void jifen(int len,int *a,int *b)
    {
        for(int i = 1; i < len; i++) b[i] = a[i-1] * ksm(i,p-2) % p;
        b[0] = 0;
    }
    void Ln(int n,int *a,int *b)
    {
        Inv(n,a,A); qiudao(n,a,B);//A 存的是 a的乘法逆,B存的是 a的导函数
        int lim = 1, tim = 0;
        while(lim < (n<<1)) lim <<= 1, tim++;
        for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
        NTT(A,lim,1); NTT(B,lim,1);
        for(int i = 0; i < lim; i++) B[i] = B[i] * A[i] % p;
        NTT(B,lim,-1); jifen(lim,B,b);//B存的是 b 的导函数
        for(int i = n; i < lim; i++) b[i] = 0;
    }
    signed main()
    {
        n = read();
        for(int i = 0; i < n; i++) a[i] = read();
        Ln(n,a,b);
        for(int i = 0; i < n; i++) printf("%lld ",b[i]);
        return 0;
    }
    

    7.多项式除法

    给你一个 (n) 次多项式 (A(x)) 和一个 (m) 次的多项式 (B(x)),求多项式 (C(x))(D(x)) 满足:

    1. (C(x)) 的次数为 (n-m), (D(x)) 的次数小于 (m)
    2. (A(x) = C(x) * B(x) + D(x))

    (f(x)) 是一个 (n) 次多项式,则定义 (inv(f(x)) = x^nf({1over x}))

    (inv(f(x)) = x^n f({1over x}) = x^n(a_0+a_1x^{-1}+...a_nx^{-n}) = a_{n} + a_{n-1}x^1 + a_{n-2}x^2+....a_{1}x^{n-1} + a_0x^{n})

    所以 (inv(f(x))) 其实就是把 (f(x)) 的系数反转过来得到的结果。

    (ecause A(x) = C(x) * B(x) + D(x))

    所以有 (inv(A(x)) = inv(C(x) * B(x) + D(x)))

    展开可得:

    (x^nA({1over x}) = x^{n} (C({1over x}) * B({1over x}) + D({1over x})))

    (x^nA({1over x}) = x^mB({1over x}) x^{n-m} C({1over x}) + x^{n-m+1} x^{m-1} D({1over x}))

    在转化为 (inv(f(X))) 可得:

    (inv(A(x)) = inv(B(x))inv(C(x)) + x^{n-m+1}inv(D(x)))

    两边同时模上 (x^{n-m+1}) 可得:

    (invA(x) equiv inv(B(x))inv(C(x)) pmod {x^{n-m+1}})

    (inv(C(x)) equiv {inv(A(x))over invB(x)} pmod {x^{n-m+1}})

    多项式乘法和多项式求逆可以求出来 (inv(C(x))), 在把系数反转得到 (C(x)).

    最后把 (C(x)) 代入原式可得到 (D(x)).

    复杂度 (O(nlogn))

    一定要注意清空数组(我这个沙比就因为这个卡在了50分好几回)

    Code:

    #include<iostream>
    #include<cstdio>
    #include<algorithm>
    using namespace std;
    #define int long long
    const int p = 998244353;
    const int N = 1e6+10;
    int n,m,rev[N],a[N],b[N],c[N],d[N],A[N],B[N],invB[N];
    inline int read()
    {
        int s = 0,w = 1; char ch = getchar();
        while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
        while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
        return s * w;
    }
    int ksm(int a,int b)
    {
        int res = 1;
        for(; b; b >>= 1)
        {
            if(b & 1) res = res * a % p;
            a = a * a % p;
        }
        return res;
    }
    void NTT(int *a,int len,int opt)
    {
        for(int i = 0; i < len; i++)
        {
            if(i < rev[i]) swap(a[i],a[rev[i]]);
        }
        for(int h = 1; h < len; h <<= 1)
        {
            int wn = ksm(3,(p-1)/(h<<1));
            if(opt == -1) wn = ksm(wn,p-2);
            for(int j = 0; j < len; j += (h<<1))
            {
                int w = 1;
                for(int k = 0; k < h; k++)
                {
                    int u = a[j + k];
                    int v = w * a[j + h + k] % p;
                    a[j + k] = (u + v) % p;
                    a[j + h + k] = (u - v + p) % p;
                    w = w * wn % p;
                }
            }
        }
        if(opt == -1)
        {
            int inv = ksm(len,p-2);
            for(int i = 0; i < len; i++) a[i] = (a[i] * inv % p + p) % p;
        }
    }
    void Inv(int n,int *a,int *b)
    {
        if(n == 1)
        {
            b[0] = ksm(a[0],p-2);
            return;
        }
        Inv((n+1)>>1,a,b);
        int lim = 1, tim = 0;
        while(lim < (n<<1)) lim <<= 1, tim++;
        for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
        for(int i = 0; i < n; i++) c[i] = a[i];
        for(int i = n; i < lim; i++) c[i] = 0;
        NTT(c,lim,1); NTT(b,lim,1);
        for(int i = 0; i < lim; i++) b[i] = (2 * b[i] % p - b[i] * b[i] % p * c[i] % p + p) % p;
        NTT(b,lim,-1); 
        for(int i = n; i < lim; i++) b[i] = 0;
    }
    void mul(int n,int m,int *a,int *b)
    {
    	int lim = 1, tim = 0;
    	while(lim < (n<<1)) lim <<= 1, tim++;
    	for(int i = 0; i <lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
    	NTT(a,lim,1); NTT(b,lim,1);
    	for(int i = 0; i < lim; i++) a[i] = a[i] * b[i] % p;
    	NTT(a,lim,-1);
    	for(int i = n; i < lim; i++) a[i] = 0; 
    }
    void Chu(int n,int m,int *a,int *b)
    {
        for(int i = 0; i < n; i++) A[i] = a[n-i-1];//A 数组存的是 inv(A(x))
        for(int i = 0; i < m; i++) B[i] = b[m-i-1];//B 数组存的是 inv(B(x))
        Inv(n-m+1,B,invB); 
        for(int i = n-m+1; i < (n<<2); i++) A[i] = invB[i] = 0;
        mul(n-m+1,n-m+1,A,invB); 
        for(int i = 0; i < n-m+1; i++) c[i] = (A[n-m-i] % p + p) % p;
    	for(int i = 0; i < n-m+1; i++) printf("%lld ",c[i]); 
        printf("
    ");
        for(int i = n-m+1; i < (n<<2); i++) c[i] = 0;
        mul(n,n,c,b);
        for(int i = 0; i < m-1; i++) d[i] = ((a[i] - c[i]) % p + p) % p;
        for(int i = 0; i < m-1; i++) printf("%lld ",d[i]);
    } 
    signed main()
    {
        n = read() + 1; m = read() + 1;
        for(int i = 0; i < n; i++) a[i] = read();
        for(int i = 0; i < m; i++) b[i] = read();
        Chu(n,m,a,b);
        return 0;
    }
    
  • 相关阅读:
    LeetCode 252. Meeting Rooms
    LeetCode 161. One Edit Distance
    LeetCode 156. Binary Tree Upside Down
    LeetCode 173. Binary Search Tree Iterator
    LeetCode 285. Inorder Successor in BST
    LeetCode 305. Number of Islands II
    LeetCode 272. Closest Binary Search Tree Value II
    LeetCode 270. Closest Binary Search Tree Value
    LeetCode 329. Longest Increasing Path in a Matrix
    LintCode Subtree
  • 原文地址:https://www.cnblogs.com/genshy/p/14260473.html
Copyright © 2011-2022 走看看