zoukankan      html  css  js  c++  java
  • 多项式总结(持续更新)

    多项式的一堆乱七八糟的操作学了一部分了……(多点求值和快速插值还没有)

    打算写下来整理一下。不过因为还有一些没学的以及没完全理解的……只好先持续更新了。

    不扯淡了,直接开始。

    1.NTT
    FFT咱就不说了,有兴趣可以看兔哥博客.
    NTT和FFT很相似。但是因为FFT涉及到复数运算所以会有一些精度误差,然后有的时候也会遇到需要取模的情况……于是快速数论变换NTT应运而生。因为单位根和原根有相似的性质,所以NTT使用原根取代了单位根进行运算。模数998244353的原根是3,每次取原根的(frac{mod-1}{2i})次方代替单位根即可。逆变换的时候就用原根的逆元。

    注意最后我们要像FFT一样乘以数组长度的逆元。其实还有另一种办法,就是直接进行一次reverse。据pinkrabbit大佬说,NTT数列是共轭对称的(然鹅我不知道啥是共轭对称),不过结果是正确的。

    写法啥的和FFT都一样。看一下代码。有兴趣还可以看Miskcoo的博客

    #include<bits/stdc++.h>
    #define rep(i,a,n) for(int i = a;i <= n;i++)
    #define per(i,n,a) for(int i = n;i >= a;i--)
    #define enter putchar('
    ')
    
    using namespace std;
    typedef long long ll;
    const int M = 400005;
    const int mod = 998244353;
    const int G = 3;
    const int invG = 332748118;
    
    int read()
    {
       int ans = 0,op = 1;char ch = getchar();
       while(ch < '0' || ch > '9') {if(ch == '-') op = -1;ch = getchar();}
       while(ch >= '0' && ch <= '9') ans = ans * 10 + ch - '0',ch = getchar();
       return ans * op;
    }
    
    int n,m,a[M],b[M],c[M],rev[M],len = 1,L;
    
    int add(int a,int b) {return a + b > mod ? a + b - mod : a + b;}
    int mul(int a,int b) {return 1ll * a * b % mod;}
    
    int qpow(int a,int b)
    {
       int p = 1;
       while(b)
       {
          if(b & 1) p = mul(p,a);
          a = mul(a,a),b >>= 1;
       }
       return p;
    }
    
    void NTT(int *a,int n,int f)
    {
       rep(i,0,n) if(i < rev[i]) swap(a[i],a[rev[i]]);
       for(int i = 1;i < n;i <<= 1)
       {
          int w1 = qpow(f ? G : invG,(mod-1) / (i<<1));
          for(int j = 0;j < n;j += (i<<1))
          {
    	 int w = 1;
    	 rep(k,0,i-1)
    	 {
    	    int kx = a[k+j],ky = mul(a[k+j+i],w);
    	    a[k+j] = add(kx,ky),a[k+j+i] = add(kx,mod-ky),w = mul(w,w1);
    	 }
          }
       }
       if(!f)
       {
          int inv = qpow(n,mod-2);
          rep(i,0,n) a[i] = mul(a[i],inv);
       }
    }
    
    int main()
    {
       n = read(),m = read();
       rep(i,0,n) a[i] = read();
       rep(i,0,m) b[i] = read();
       while(len <= n+m+2) len <<= 1,L++;
       rep(i,0,len) rev[i] = (rev[i>>1] >> 1) | ((i&1) << (L-1));
       NTT(a,len,1),NTT(b,len,1);
       rep(i,0,len) c[i] = mul(a[i],b[i]);
       NTT(c,len,0);
       rep(i,0,n+m) printf("%d ",c[i]);enter;
       return 0;
    }
    

    2.多项式求逆
    这玩意在很多奇怪的多项式操作中都要用。
    具体就是给你一个多项式(F(x)),让你求出它在(mod x^n)意义下的逆元,也就是求出多项式(G(x)),使得(F(x)G(x) equiv 1 (mod x^n)).系数对于998244353取模。

    个人认为其核心思想是递归。对于只有一项的,那么显然(G(x))的常数项就是(F(x))的常数项的逆元,否则对于n项的,我们可以递归求解。
    首先假设我们已经知道了(F(x)H(x)equiv 1(mod x^frac{n}{2}))
    那么显然也有(F(x)G(x) equiv 1(mod x^frac{n}{2})) 这个是根据(G(x))的定义来的。
    之后我们把两个式子相减,就可以得到:(F(x)(G(x)-H(x)) equiv 0(mod x^frac{n}{2}))
    自然就有((G(x)-H(x)) equiv 0(mod x^frac{n}{2}))
    把这个式子进行平方,我们就可以得到:((G(x)-H(x))^2 equiv 0 (mod x^n)).这里解释一下,因为一个在(mod x^frac{n}{2})情况下为0的多项式,指数小于(frac{n}{2})的项都为0,因为卷积的性质,其自乘的前n项也必然都为0,所以其在(mod x^n)的意义下也是0.
    式子展开就可以得到:(G(x)^2 + H(x)^2 - 2G(x)H(x) equiv 0(mod x^n))
    移项,同时乘以(F(x)),由(F(x)G(x) equiv 1(mod x^n))就可以得到(G(x) equiv 2H(x)-F(x)H(x)^2(mod x^n))
    我们就可以用NTT来解决啦。时间复杂度(O(nlogn))

    #include<bits/stdc++.h>
    #define rep(i,a,n) for(int i = a;i <= n;i++)
    #define per(i,n,a) for(int i = n;i >= a;i--)
    #define enter putchar('
    ')
    #define lowbit(x) (x & (-x))
    
    using namespace std;
    typedef long long ll;
    const int M = 400005;
    const ll mod = 998244353;
    const ll G = 3;
    const ll invG = 332748118;
    
    ll read()
    {
       ll ans = 0,op = 1;char ch = getchar();
       while(ch < '0' || ch > '9') {if(ch == '-') op = -1;ch = getchar();}
       while(ch >= '0' && ch <= '9') ans = ans * 10 + ch - '0',ch = getchar();
       return ans * op;
    }
    
    ll n,g[M],f[M],r[M],c[M],b[M],rev[M];
    
    ll inc(ll a,ll b) {return (a + b) % mod;}
    ll mul(ll a,ll b) {return a * b % mod;}
    
    ll qpow(ll a,ll b)
    {
       ll p = 1;
       while(b)
       {
          if(b & 1) p = mul(p,a);
          a = mul(a,a),b >>= 1;
       }
       return p;
    }
    
    void NTT(ll *a,ll l,ll f)
    {
       rep(i,0,l-1) if(i < rev[i]) swap(a[i],a[rev[i]]);
       for(int i = 1;i < l;i <<= 1)
       {
          ll w1 = qpow(f ? G : invG,(mod-1) / (i<<1));
          for(int j = 0;j < l;j += (i<<1))
          {
         ll w = 1;
         rep(k,0,i-1)
         {
            ll kx = a[k+j],ky = mul(a[k+j+i],w);
            a[k+j] = inc(kx,ky),a[k+j+i] = inc(kx,mod-ky);
            w = mul(w,w1);
         }
          }
       }
       if(!f)
       {
          ll inv = qpow(l,mod-2);
          rep(i,0,l-1) a[i] = mul(a[i],inv);
       }
    }
    
    void solve(int len)
    {
       if(len == 1) {g[0] = qpow(f[0],mod-2);return;}
       solve((len+1)>>1);
       ll l = 1,L = 0;
       while(l < (len<<1)) l <<= 1,L++;
       rep(i,0,l-1) rev[i] = (rev[i>>1] >> 1) | ((i&1) << (L-1));
       rep(i,0,len-1) c[i] = f[i];
       rep(i,len,l-1) c[i] = 0;
       NTT(c,l,1),NTT(g,l,1);
       rep(i,0,l-1) g[i] = mul(inc(2,mod-mul(c[i],g[i])),g[i]);
       NTT(g,l,0);
       rep(i,len,l-1) g[i] = 0;
    }
    
    int main()
    {
       n = read();
       rep(i,0,n-1) f[i] = read();
       solve(n);
       rep(i,0,n-1) printf("%lld ",g[i]);enter;
       return 0;
    }
    
    

    3.多项式对数函数
    多项式的对数函数是啥?(Misckoo)大佬说,你可以将其理解为多项式和麦克劳林级数的复合。(然鹅我不会高数啊2333)或许就是把(ln(1-x))进行一下泰勒展开?
    不说这些我不大会的了,反正你计算的时候其实不用。我们要求的就是多项式(B(x)) ,使得(B(x)equiv ln A(x) (mod x^n))
    这个直接算非常难。考虑同时对两边求导。就有(B'(x) equiv frac{A'(x)}{A(x)} (mod x^n))
    直接对(A(x))求导求逆,之后对(B(x))积分一下就行。
    求导和积分都是(O(n))的,所以复杂度就是求逆的复杂度。

    #include<bits/stdc++.h>
    #define rep(i,a,n) for(int i = a;i <= n;i++)
    #define per(i,n,a) for(int i = n;i >= a;i--)
    #define enter putchar('
    ')
    #define space putchar(' ')
    #define lowbit(x) (x & (-x))
    
    using namespace std;
    typedef long long ll;
    const int M = 400005;
    const ll mod = 998244353;
    const ll G = 3;
    const ll invG = 332748118;
    
    ll read()
    {
       ll ans = 0,op = 1;char ch = getchar();
       while(ch < '0' || ch > '9') {if(ch == '-') op = -1;ch = getchar();}
       while(ch >= '0' && ch <= '9') ans = ans * 10 + ch - '0',ch = getchar();
       return ans * op;
    }
    
    void write(ll x)
    {
       if(x < 10) {putchar(x+'0');return;}
       char k = x % 10 + '0';
       write(x / 10),putchar(k);
    }
    
    ll n,g[M],f[M],r[M],c[M],rev[M],h[M];
    
    ll inc(ll a,ll b) {return (a + b) % mod;}
    ll mul(ll a,ll b) {return 1ll * (a) * (b) % mod;}
    
    ll qpow(ll a,ll b)
    {
       ll p = 1;
       while(b)
       {
          if(b & 1) p = mul(p,a);
          a = mul(a,a),b >>= 1;
       }
       return p;
    }
    
    void NTT(ll *a,ll l,ll f)
    {
       rep(i,0,l-1) if(i < rev[i]) swap(a[i],a[rev[i]]);
       for(int i = 1;i < l;i <<= 1)
       {
          ll w1 = qpow(f ? G : invG,(mod - 1) / (i << 1));
          for(int j = 0;j < l;j += (i<<1))
          {
    	 ll w = 1;
    	 rep(k,0,i-1)
    	 {
    	    ll kx = a[k+j],ky = mul(w,a[k+j+i]);
    	    a[k+j] = inc(kx,ky),a[k+j+i] = inc(kx,mod-ky);
    	    w = mul(w,w1);
    	 }
          }
       }
       if(!f)
       {
          ll inv = qpow(l,mod-2);
          rep(i,0,l-1) a[i] = mul(a[i],inv);
       }
    }
    
    void derit(ll *a,ll *b,ll len) {rep(i,1,len-1) b[i-1] = mul(a[i],i);b[len-1] = 0;}
    
    void inter(ll *a,ll *b,ll len) {rep(i,1,len-1) b[i] = mul(a[i-1],qpow(i,mod-2));b[0] = 0;}
    
    void getinv(ll *a,ll *b,ll len)
    {
       if(len == 1) {b[0] = qpow(a[0],mod-2);return;}
       getinv(a,b,(len+1)>>1);
       ll l = 1,L = 0;
       while(l < (len << 1)) l <<= 1,L++;
       rep(i,0,l-1) rev[i] = (rev[i>>1] >> 1) | ((i&1) << (L-1));
       rep(i,0,len-1) g[i] = a[i];
       rep(i,len,l-1) g[i] = 0;
       NTT(g,l,1),NTT(b,l,1);
       rep(i,0,l-1) b[i] = mul(inc(2,mod-mul(g[i],b[i])),b[i]);
       NTT(b,l,0);
       rep(i,len,l-1) b[i] = 0; 
    }
    
    void getln(ll *a,ll *b,ll len)
    {
       derit(a,c,len),getinv(a,r,len);
       ll l = 1,L = 0;
       while(l < (len<<1)) l <<= 1,L++;
       rep(i,0,l-1) rev[i] = (rev[i>>1] >> 1) | ((i&1) << (L-1));
       NTT(c,l,1),NTT(r,l,1);
       rep(i,0,l-1) c[i] = mul(c[i],r[i]);
       NTT(c,l,0);
       inter(c,b,len);
    }
    
    int main()
    {
       n = read();
       rep(i,0,n-1) f[i] = read();
       ll l = 1;
       while(l < n) l <<= 1;
       getln(f,h,l);
       rep(i,0,n-1) write(h[i]),space;enter;
       return 0;
    }
    

    4.多项式指数函数
    理解方法的话……就照着对数函数理解就行。其实就是让你求(B(x)),满足(B(x) equiv e^{A(x)} (mod x^n))
    这个咋做?继续求导?(e^x)的导数还是(e^x)……
    于是我们需要一个前置知识:多项式牛顿迭代。

    具体可以参考Misckoo的博客 ,这里只给出式子了。
    假设我们知道(G(x)),我们想求一个多项式(F(x)),满足(G(F(x)) equiv 0 (mod x^n))
    首先只有一项的时候,(G(F(x)) equiv 0(mod x))是要单独计算的。
    还是倍增的思想……假设我们已经求出(G(H(x)) equiv 0 (mod x^frac{n}{2}))
    如何拓展到(mod x^n)下呢? 我们把(G(F(x)))(H(x))处进行泰勒展开。就有:
    (G(F(x)) = G(H(x)) + frac{G'(H(x))}{1!}(F(x)-H(x)) + frac{G''(H(x))}{2!}(F(x)-H(x))^2 + …)
    因为(F(x))(H(x))的后面(frac{n}{2})项相同,故((F(x)-H(x))^2)及以上次方项在(mod x^n)意义下均为0,所以就有:
    (G(F(x)) equiv G(H(x)) + G'(H(x))(F(x)-H(x)) (mod x^n))
    (G(F(x)) equiv 0 (mod x^n)),所以(F(x) equiv H(x) - frac{G(H(x))}{G'(H(x)} (mod x^n))
    我们就可以开始使用这个式子解决问题了。

    回到刚才的问题。我们把式子变一下形再移项,就是(ln B(x) - A(x) equiv 0 (mod x^n))
    我们相当于求函数零点。令(G(B(x)) = ln B(x)-A(x)),对函数求导得到:(G'(B(x)) = frac{1}{B(x)})
    因为这里(F(x))是变量,自然(A(x))可以看作常数。
    带回上面的多项式牛顿迭代的式子,就有:(B(x) equiv B_0(x)(1-ln B_0(x) + A(x)) (mod x^n))
    用多项式对数函数+递归即可解决。

    #include<bits/stdc++.h>
    #define rep(i,a,n) for(int i = a;i <= n;i++)
    #define per(i,n,a) for(int i = n;i >= a;i--)
    #define enter putchar('
    ')
    #define I inline 
    
    using namespace std;
    typedef long long ll;
    const int M = 400005; 
    const int INF = 200000000;
    const int mod = 998244353;
    const int T = 3;
    const int invT = 332748118;
    
    int read()
    {
       int ans = 0,op = 1;char ch = getchar();
       while(ch < '0' || ch > '9') {if(ch == '-') op = -1;ch = getchar();}
       while(ch >= '0' && ch <= '9') ans = ans * 10 + ch - '0',ch = getchar();
       return ans * op;
    }
    
    int n,F[M],G[M],Fi[M],A[M],H[M],rev[M],c[M],d[M],Gi[M];
    
    int add(int a,int b){return (a+b) % mod;}
    int mul(int a,int b){return 1ll * a * b % mod;}
    
    int qpow(int a,int b)
    {
       int p = 1;
       while(b)
       {
          if(b&1) p = mul(p,a);
          a = mul(a,a),b >>= 1;
       }
       return p;
    }
    
    void NTT(int *a,int n,int f)
    {
       rep(i,0,n-1) if(i < rev[i]) swap(a[i],a[rev[i]]);
       for(int i = 1;i < n;i <<= 1)
       {
          int w1 = qpow(f? T : invT,(mod-1) / (i<<1));
          for(int j = 0;j < n;j += (i<<1))
          {
             int w = 1;
             rep(k,0,i-1)
             {
                int kx = a[k+j],ky = mul(w,a[k+j+i]);
                a[k+j] = add(kx,ky),a[k+j+i] = add(kx,mod-ky),w = mul(w,w1);
             }
          }
       }
       if(!f)
       {
          int inv = qpow(n,mod-2);
          rep(i,0,n-1) a[i] = mul(a[i],inv);
       }
    }
    
    
    void derit(int *a,int *b,int len) {rep(i,1,len-1) b[i-1] = mul(a[i],i);b[len-1] = 0;}
    void inter(int *a,int *b,int len) {rep(i,1,len-1) b[i] = mul(a[i-1],qpow(i,mod-2));b[0] = 0;}
    
    void getrev(int l,int L){rep(i,0,l-1) rev[i] = (rev[i>>1] >> 1) | ((i&1) << (L-1));}
    
    void getinv(int *a,int *b,int len)
    {
       if(len == 1) {b[0] = qpow(a[0],mod-2);return;}
       getinv(a,b,(len+1)>>1);
       int l = 1,L = 0;
       while(l < (len<<1)) l <<= 1,L++;
       getrev(l,L);
       rep(i,0,len-1) Gi[i] = a[i];
       rep(i,len,l-1) Gi[i] = 0;
       NTT(Gi,l,1),NTT(b,l,1);
       rep(i,0,l-1) b[i] = mul(add(2,mod-mul(Gi[i],b[i])),b[i]);
       NTT(b,l,0);
       rep(i,len,l-1) b[i] = 0;
    }
    
    void getln(int *a,int *b,int len)
    {
       derit(a,c,len),getinv(a,G,len);
       int l = 1,L = 0;
       while(l < (len<<1)) l <<= 1,L++;
       getrev(l,L);
       NTT(c,l,1),NTT(G,l,1);
       rep(i,0,l-1) c[i] = mul(G[i],c[i]);
       NTT(c,l,0);
       inter(c,b,len);
       rep(i,0,l-1) c[i] = G[i] = 0;
    }
    
    void getexp(int *a,int *b,int len)
    {
       if(len == 1) {b[0] = 1;return;}
       getexp(a,b,(len+1)>>1),getln(b,F,len);
       F[0] = add(a[0]+1,mod-F[0]);
       rep(i,1,len-1) F[i] = add(a[i],mod-F[i]);
       int l = 1,L = 0;
       while(l < (len<<1)) l <<= 1,L++;
       getrev(l,L);
       NTT(F,l,1),NTT(b,l,1);
       rep(i,0,l-1) b[i] = mul(b[i],F[i]);
       NTT(b,l,0);
       rep(i,len,l-1) b[i] = F[i] = 0;
    }
    
    int main()
    {
       n = read();
       rep(i,0,n-1) A[i] = read();
       int l = 1;
       while(l <= n) l <<= 1;
       getexp(A,H,l);
       rep(i,0,n-1) printf("%d ",H[i]);enter;
       return 0;
    }
    
    

    5.多项式开根
    给定(F(x)),求(G(x)^2 equiv F(x) (mod x^n))
    这个有几种做法……首先说纯代数推导吧。
    同样用倍增的想法,假设我们知道(H(x)^2 equiv F(x) (mod x^frac{n}{2}))
    将两个式子相减得到(G(x)^2 - H(x)^2 equiv 0 (mod x^frac{n}{2}))
    则有(G(x) - H(x) equiv 0 (mod x^frac{n}{2}))
    将这个式子平方之后展开,用(F(x))来替换(G(x)^2),于是有了:(F(x) - 2G(x)H(x) + H(x)^2 equiv 0 (mod x^n))
    移项得到(G(x) equiv frac{H(x)^2 + F(x)}{2H(x)})
    于是我们就可以用求逆来解决这个问题了。注意你可以选择先约分再计算,或者直接计算,结果都是一样的,不过要特别注意的是,在做加法的时候不要加多了,加到len即可。

    这种做法还有另一种推导方式,就是多项式牛顿迭代。
    我们要求的式子都一样,那我们可以构造函数:(H(G(x)) = G(x)^2 - F(x)),求这个函数的零点。
    对之求导,(H'(G(x)) = 2G(x)),之后带入上面的多项式牛顿迭代的方程,立即得到:(G(x) equiv frac{G_0(x)^2 + F(x)}{2G_0(x)} (mod x^n))
    用与上面相同的方法解决即可。

    #include<bits/stdc++.h>
    #define rep(i,a,n) for(int i = a;i <= n;i++)
    #define per(i,n,a) for(int i = n;i >= a;i--)
    #define enter putchar('
    ')
    #define I inline
    
    using namespace std;
    typedef long long ll;
    const int M = 800005;
    const int mod = 998244353;
    const int G = 3;
    const int invG = 332748118;
    
    int read()
    {
       int ans = 0,op = 1;char ch = getchar();
       while(ch < '0' || ch > '9') {if(ch == '-') op = -1;ch = getchar();}
       while(ch >= '0' && ch <= '9') ans = ans * 10 + ch - '0',ch = getchar();
       return ans * op;
    }
    
    int rev[M],A[M],B[M],C[M],D[M],F[M],n,inv2;
    
    int inc(int a,int b){return (a+b) % mod;}
    int mul(int a,int b){return 1ll * a * b % mod;}
    
    int qpow(int a,int b)
    {
       int p = 1;
       while(b)
       {
          if(b & 1) p = mul(p,a);
          a = mul(a,a),b >>= 1;
       }
       return p;
    }
    
    void getrev(int l,int L) {rep(i,0,l-1) rev[i] = (rev[i>>1] >> 1) | ((i&1) << (L-1));}
    
    void NTT(int *a,int n,int f)
    {
       rep(i,0,n-1) if(i < rev[i]) swap(a[i],a[rev[i]]);
       for(int i = 1;i < n;i <<= 1)
       {
          int w1 = qpow(f ? G : invG,(mod-1) / (i<<1));
          for(int j = 0;j < n;j += (i<<1))
          {
    	 int w = 1;
    	 rep(k,0,i-1)
    	 {
    	    int kx = a[k+j],ky = mul(a[k+j+i],w);
    	    a[k+j] = inc(kx,ky),a[k+j+i] = inc(kx,mod-ky),w = mul(w,w1);
    	 }
          }
       }
       if(!f)
       {
          int inv = qpow(n,mod-2);
          rep(i,0,n-1) a[i] = mul(a[i],inv);
       }
    }
    
    void getinv(int *a,int *b,int len)
    {
       if(len == 1) {b[0] = qpow(a[0],mod-2);return;}
       getinv(a,b,(len+1) >> 1);
       int l = 1,L = 0;
       while(l < (len<<1)) l <<= 1,L++;
       getrev(l,L);
       rep(i,0,len-1) C[i] = a[i];
       rep(i,len,l-1) C[i] = 0;
       NTT(C,l,1),NTT(b,l,1);
       rep(i,0,l-1) b[i] = mul(inc(2,mod-mul(C[i],b[i])),b[i]);
       NTT(b,l,0);
       rep(i,len,l-1) b[i] = 0;
    }
    /*
    void getsqrt(int *a,int *b,int len)
    {
       if(len == 1) {b[0] = 1;return;}
       getsqrt(a,b,(len+1)>>1);
       rep(i,0,len<<1) F[i] = 0;
       getinv(b,F,len);
       int l = 1,L = 0;
       while(l < len << 1) l <<= 1,L++;
       getrev(l,L);
       rep(i,0,len-1) D[i] = a[i];
       rep(i,len,l-1) D[i] = 0;
       NTT(D,l,1),NTT(b,l,1),NTT(F,l,1);
       rep(i,0,l-1) b[i] = mul(inc(b[i],mul(D[i],F[i])),inv2);
       NTT(b,l,0);
       rep(i,len,l-1) b[i] = 0;
    }
    */
    
    void getsqrt(int *a,int *b,int len)
    {
       if(len == 1) {b[0] = 1;return;}
       getsqrt(a,b,(len+1)>>1);
       rep(i,0,len<<1) F[i] = 0;
       getinv(b,F,len);
       int l = 1,L = 0;
       while(l < len<<1) l <<= 1,L++;
       getrev(l,L);
       NTT(b,l,1);
       rep(i,0,l-1) b[i] = mul(b[i],b[i]);
       NTT(b,l,0);
       rep(i,0,len-1) b[i] = inc(b[i],a[i]);
       NTT(b,l,1),NTT(F,l,1);
       rep(i,0,l-1) b[i] = mul(mul(b[i],F[i]),inv2);
       NTT(b,l,0);
       rep(i,len,l-1) b[i] = 0;
    }
    
    int main()
    {
       n = read(),inv2 = qpow(2,mod-2);
       rep(i,0,n-1) A[i] = read();
       getsqrt(A,B,n);
       rep(i,0,n-1) printf("%d ",B[i]);enter;
       return 0;
    }
    

    另一种做法是用(ln)(exp),它可以推广到任意次幂,不过常数会很大。
    对于(G(x)^k equiv F(x)),我们把式子变形可以得到:(F(x) equiv e^{kln(G(x))})
    于是我们就可以用(ln)(exp)来解决这个问题了……
    (代码暂时咕了,可能某些时候补上)

    6.多项式除法与取模
    这玩意的做法很神奇。
    给定一个多项式(A(x))和一个多项式(B(x)),求多项式(D(x))(R(x)),使得(A(x) = B(x)D(x) + R(x))
    首先我们要想办法消除(R(x))的影响。如何消除呢……
    我们假设A有n项,B有m项,且(m < n),那么显然D应该有n-m项,R有m-1项。
    我们引入一个神奇的操作:将所有的(x)(frac{1}{x})来替代.两边再同时乘以(x^n)
    容易发现,我们相当于将多项式的系数进行了反转。原来的式子变成了这样:
    (x^nA(frac{1}{x}) = x^mB(frac{1}{x})x^{n-m}D(frac{1}{x}) + x^{n-m+1}x^{m-1}R(frac{1}{x}))
    我们定义(A^r(x) = x^nA(frac{1}{x})),注意这个n对于每个多项式是不同的,指的是多项式自身的项数。
    于是就有(A^r(x) = B^r(x)D^r(x) + x^{n-m+1}R^r(x))
    我们发现,这个式子在(mod (n-m))意义下,(R(x))的影响就会被消除。而(D(x))在反转后,次数仍然不高于(n-m),所以我们有(A^r(x) equiv B^r(x)D^r(x) (mod x^{n-m+1}))
    然后求一下在模意义下B的逆元,倒着推回去就能求出(D(x))(R(x))了。

    #include<bits/stdc++.h>
    #define rep(i,a,n) for(int i = a;i <= n;i++)
    #define per(i,n,a) for(int i = n;i >= a;i--)
    #define enter putchar('
    ')
    #define I inline 
    
    using namespace std;
    typedef long long ll;
    const int M = 800005; 
    const int INF = 200000000;
    const int mod = 998244353;
    const int T = 3;
    const int invT = 332748118;
    
    int read()
    {
       int ans = 0,op = 1;char ch = getchar();
       while(ch < '0' || ch > '9') {if(ch == '-') op = -1;ch = getchar();}
       while(ch >= '0' && ch <= '9') ans = ans * 10 + ch - '0',ch = getchar();
       return ans * op;
    }
    
    int n,m,F[M],G[M],rev[M],Q[M],R[M],Fr[M],Gr[M],Gi[M],c[M];
    
    int add(int a,int b){return (a+b) % mod;}
    int mul(int a,int b){return 1ll * a * b % mod;}
    
    int qpow(int a,int b)
    {
       int p = 1;
       while(b)
       {
          if(b&1) p = mul(p,a);
          a = mul(a,a),b >>= 1;
       }
       return p;
    }
    
    void NTT(int *a,int n,int f)
    {
       rep(i,0,n-1) if(i < rev[i]) swap(a[i],a[rev[i]]);
       for(int i = 1;i < n;i <<= 1)
       {
          int w1 = qpow(f ? T : invT,(mod-1) / (i<<1));
          for(int j = 0;j < n;j += (i<<1))
          {
         int w = 1;
         rep(k,0,i-1)
         {
            int kx = a[k+j],ky = mul(a[k+j+i],w);
            a[k+j] = add(kx,ky),a[k+j+i] = add(kx,mod-ky),w = mul(w,w1);
         }
          }
       }
       if(!f)
       {
          int inv = qpow(n,mod-2);
          rep(i,0,n-1) a[i] = mul(a[i],inv);
       }
    }
    
    void getrev(int n,int L)
    {
       rep(i,0,n-1) rev[i] = (rev[i>>1] >> 1) | ((i&1) << (L-1));
    }
    
    void getinv(int *a,int *b,int len)
    {
       if(len == 1) {b[0] = qpow(a[0],mod-2);return;}
       getinv(a,b,(len+1)>>1);
       int l = 1,L = 0;
       while(l < (len << 1)) l <<= 1,L++;
       getrev(l,L);
       rep(i,0,len-1) c[i] = a[i];
       rep(i,len,l-1) c[i] = 0;
       NTT(c,l,1),NTT(b,l,1);
       rep(i,0,l-1) b[i] = mul(add(2,mod-mul(c[i],b[i])),b[i]);
       NTT(b,l,0);
       rep(i,len,l-1) b[i] = 0;
    }
    
    int main()
    {
       n = read(),m = read();
       rep(i,0,n) F[i] = read(),Fr[n-i] = F[i];
       rep(i,0,m) G[i] = read(),Gr[m-i] = G[i];
       rep(i,n-m+2,m) Gr[i] = 0;
       getinv(Gr,Gi,n-m+1);
       int l = 1,L = 0;
       while(l <= (n<<1)) l <<= 1,L++;
       getrev(l,L);
       NTT(Fr,l,1),NTT(Gi,l,1);
       rep(i,0,l-1) Q[i] = mul(Fr[i],Gi[i]);
       NTT(Q,l,0);reverse(Q,Q+n-m+1);
       rep(i,n-m+1,n) Q[i] = 0;
       rep(i,0,n-m) printf("%d ",Q[i]);enter;
       l = 1,L = 0;
       while(l <= (n << 1)) l <<= 1,L++;
       getrev(l,L);
       NTT(G,l,1),NTT(Q,l,1);
       rep(i,0,l-1) G[i] = mul(G[i],Q[i]);
       NTT(G,l,0);
       rep(i,0,l-1) R[i] = add(F[i],mod-G[i]);
       rep(i,0,m-1) printf("%d ",R[i]);enter;
       return 0;
    }
    

    7.任意模数NTT。
    这个不知道会不会有毒瘤出题人这么出……就是他给你的模数不是NTT模数……
    MTT……?我好像不大会。我只会三模NTT,这种做法好像被(Shadowice1984)大佬疯狂批评,不过我暂时还不会别的哎……只好先写这个了。
    这其实是比较投机取巧的做法,因为每个数不超过(1e9),所以我们可以选三个大的NTT模数,在每个模数的模意义下算出答案,之后用CRT合并。
    咋子合并……假设答案分别为(x_1,x_2,x_3),模数分别为(A,B,C)
    就有如下方程:

    [x equiv x_1 (mod A) ]

    [x equiv x_2(mod B) ]

    [x equiv x_3(mod C) ]

    先合并前两个。

    [x_1 + k_1A = x_2 + k_2B ]

    [x_1 + k_1A equiv x_2 (mod B) ]

    [k_1 equiv frac{x_2 - x_1}{A} (mod B) ]

    求出(k_1)之后,令(x_4 = x_1 + k_1A),继续合并:

    [x_4 + k_4AB = x_3 + k_3C ]

    [x_4 + k_4AB equiv x_3(mod C) ]

    [k_4 equiv frac{x_3 - x_4}{AB} ]

    得到(k_4),就知道(x = x_4 + k_4AB (mod ABC)),因为(x < ABC),所以就有(x = x_4 + k_4AB (mod p))
    直接算完答案合并就可以。注意超出longlong的时候要用那个神奇的合并方式。

    // luogu-judger-enable-o2
    #include<bits/stdc++.h>
    #define rep(i,a,n) for(int i = a;i <= n;i++)
    #define per(i,n,a) for(int i = n;i >= a;i--)
    #define enter putchar('
    ')
    #define I inline 
    
    using namespace std;
    typedef long long ll;
    const int M = 800005; 
    const int INF = 200000000;
    const ll mod1 = 998244353,mod2 = 469762049,mod3 = 1004535809;
    const ll T = 3;
    
    ll read()
    {
       ll ans = 0,op = 1;char ch = getchar();
       while(ch < '0' || ch > '9') {if(ch == '-') op = -1;ch = getchar();}
       while(ch >= '0' && ch <= '9') ans = ans * 10 + ch - '0',ch = getchar();
       return ans * op;
    }
    
    ll n,m,p,F[M],G[M],rev[M],A[M],B[M],C[M],D[M];
    
    ll add(ll a,ll b,ll mod){return (a+b) % mod;}
    ll mul(ll a,ll b,ll mod){return ((a*b-(ll)((long double)a/mod*b+1e-8)*mod)+mod)%mod;}
    
    ll qpow(ll a,ll b,ll mod)
    {
       a %= mod;
       ll t = 1;
       while(b)
       {
          if(b&1) t = mul(t,a,mod);
          a = mul(a,a,mod),b >>= 1;
       }
       return t;
    }
    
    void getrev(int l,int L){rep(i,0,l-1) rev[i] = (rev[i>>1]>>1) | ((i&1) << (L-1));}
    
    void NTT(ll *a,ll n,ll f,ll mod)
    {
       rep(i,0,n-1) if(i < rev[i]) swap(a[i],a[rev[i]]);
       ll invT = qpow(T,mod-2,mod);
       for(int i = 1;i < n;i <<= 1)
       {
          ll w1 = qpow(f? T : invT,(mod-1) / (i<<1),mod);
          for(int j = 0;j < n;j += (i<<1))
          {
         ll w = 1;
         rep(k,0,i-1)
         {
            ll kx = a[k+j]%mod,ky = mul(a[k+j+i],w,mod);
            a[k+j] = add(kx,ky,mod),a[k+j+i] = add(kx,mod-ky,mod),w = mul(w,w1,mod);
         }
          }
       }
       if(!f)
       {
          ll inv = qpow(n,mod-2,mod);
          //reverse(a+1,a+n);
          rep(i,0,n-1) a[i] = mul(a[i],inv,mod);
       }
    }
    
    void Polymul(ll *a,ll *b,ll l,ll mod)
    {
       NTT(a,l,1,mod),NTT(b,l,1,mod);
       rep(i,0,l-1) a[i] = mul(a[i],b[i],mod);
       NTT(a,l,0,mod);
    }
    
    void merge(ll *a,ll *b,ll *c,ll l,ll moda,ll modb,ll modc)
    {
       ll inv = qpow(moda,modb-2,modb);
       rep(i,0,l-1)
       {
          ll k1 = add(b[i],modb-(a[i]%modb),modb);
          k1 = mul(k1,inv,modb);
          a[i] = add(a[i],k1*moda,moda*modb);
       }
       inv = qpow(moda*modb,modc-2,modc);
       rep(i,0,l-1)
       {
          ll k2 = add(c[i],modc-(a[i]%modc),modc);
          k2 = mul(k2,inv,modc);
          k2 = mul(k2,moda*modb,p);
          a[i] = add(a[i],k2,p);
       }
    }
    
    int main()
    {
       n = read(),m = read(),p = read();
       rep(i,0,n) A[i] = read(),A[i] %= p,C[i] = F[i] = A[i];
       rep(i,0,m) B[i] = read(),B[i] %= p,D[i] = G[i] = B[i];
       int l = 1,L = 0;
       while(l <= n+m) l <<= 1,L++;
       getrev(l,L);
       Polymul(A,B,l,mod1);
       Polymul(C,D,l,mod2);
       Polymul(F,G,l,mod3);
       merge(A,C,F,n+m+1,mod1,mod2,mod3);
       rep(i,0,n+m) printf("%lld ",A[i]);enter;
       return 0;
    }
    

    8.分治FFT
    这玩意还是挺有用的。一般用于:已知(G(x)),求(F[i] = sum_{j=1}^{i}F[i-j]G[j]),F[0] = 1.
    这个怎么求呢?直接做应该是不行的……会超时。我们考虑分治,考虑一下式子的左半边计算出结果之后对于右半边的贡献,在计算之前把这些贡献加上即可,也就是一个CDQ分治套FFT。
    假设我们已经求出(l-mid)的答案,对于(mid-r)之中的一点x,其所获得的贡献为:(w_x = sum_{i=l}^{mid}f[i]g[x-i])
    所以我们做一遍CDQ套FFT就可以解决了。注意边界问题。

    #include<bits/stdc++.h>
    #define rep(i,a,n) for(int i = a;i <= n;i++)
    #define per(i,n,a) for(int i = n;i >= a;i--)
    #define enter putchar('
    ')
    #define lowbit(x) (x & (-x))
    
    using namespace std;
    typedef long long ll;
    const int M = 400005;
    const ll mod = 998244353;
    const ll G = 3;
    const ll invG = 332748118;
    
    ll read()
    {
       ll ans = 0,op = 1;char ch = getchar();
       while(ch < '0' || ch > '9') {if(ch == '-') op = -1;ch = getchar();}
       while(ch >= '0' && ch <= '9') ans = ans * 10 + ch - '0',ch = getchar();
       return ans * op;
    }
    
    ll n,g[M],f[M],r[M],a[M],b[M],rev[M];
    
    ll inc(ll a,ll b) {return (a + b) % mod;}
    ll mul(ll a,ll b) {return a * b % mod;}
    
    ll qpow(ll a,ll b)
    {
       ll p = 1;
       while(b)
       {
          if(b & 1) p = mul(p,a);
          a = mul(a,a),b >>= 1;
       }
       return p;
    }
    
    void NTT(ll *a,ll l,ll f)
    {
       rep(i,0,l-1) if(i < rev[i]) swap(a[i],a[rev[i]]);
       for(int i = 1;i < l;i <<= 1)
       {
          ll w1 = qpow(f ? G : invG,(mod - 1) / (i<<1));
          for(int j = 0;j < l;j += (i<<1))
          {
         ll w = 1;
         rep(k,0,i-1)
         {
            ll kx = a[k+j],ky = mul(a[k+j+i],w);
            a[k+j] = inc(kx,ky),a[k+j+i] = inc(kx,mod-ky);
            w = mul(w,w1);
         }
          }
       }
       if(!f) 
       {
          ll inv = qpow(l,mod-2);
          rep(i,0,l-1) a[i] = mul(a[i],inv);
       }
    }
    
    void CDQ(int kl,int kr)
    {
       if(kl == kr) return;
       int mid = (kl + kr) >> 1;
       CDQ(kl,mid);
       int l = 1,L = 0;
       while(l < (kr - kl + 1) << 1) l <<= 1,L++;
       rep(i,0,l-1) rev[i] = (rev[i>>1] >> 1) | ((i&1) << (L-1));
       rep(i,0,l-1) a[i] = b[i] = 0;
       rep(i,kl,mid) a[i-kl] = f[i];
       rep(i,0,kr-kl) b[i] = g[i];
       NTT(a,l,1),NTT(b,l,1);
       rep(i,0,l-1) a[i] = mul(a[i],b[i]);
       NTT(a,l,0);
       rep(i,mid+1,kr) f[i] = inc(f[i],a[i-kl]);
       CDQ(mid+1,kr);
    }
    
    int main()
    {
       n = read(),f[0] = 1;
       rep(i,1,n-1) g[i] = read();
       CDQ(0,n-1);
       rep(i,0,n-1) printf("%lld ",f[i]);enter;
       return 0;
    }
    
  • 相关阅读:
    个人作业——软件评测
    软件工程实践2019第五次作业
    18年今日头条笔试第一题题解:球迷(fans)
    游戏2.1版本
    游戏2.0版本 代码
    游戏2.0版本
    改进版游戏代码
    改进版游戏
    2017.1.13之审判日
    找朋友 的内存超限代码
  • 原文地址:https://www.cnblogs.com/captain1/p/10350115.html
Copyright © 2011-2022 走看看