zoukankan      html  css  js  c++  java
  • 多项式相关

    省选前准备把多项式搞完。(似乎够折磨人的)


    首先FFT和NTT板子请出门右转:还是我的博客


    0.加减乘
    加减有人用我说吗。
    乘就是FFT。


    1.多项式求逆
    对于一个(n - 1)次多项式(A(x)),求另一个多项式(B(x)),满足(A(x) * B(x) equiv 1 (mod x ^n))。(这里的取模表示只保留小于(n)的项的系数)
    做法就叫倍增吧。
    假设已经求得(B(x)),满足(A(x) * B(x) equiv 1 (mod x ^{lceil frac{n}{2} ceil}))(我也不知道为啥上取整),要求(C(x))满足(A(x) * C(x) equiv 1 (mod x ^n))
    首先可以有(C(x) equiv B(x) (mod x ^ {lceil frac{n}{2} ceil}))
    然后移项平方得((B - C) ^ 2 equiv 0 (mod x ^ n))。这个觉得挺显然的,因为上面的多项式可以看成最高次项为(lceil frac{n}{2} ceil - 1)的一个多项式,只不过系数都是0。然后平方就搞出了一个(n)次多项式,系数自然也是0。
    接着拆开,两边同乘以A得:(AB ^ 2 - 2B + C = 0)
    于是(C = B *(2 - AB))
    到此为止就做完了。
    因为想求(n)意义下的就必须先求(mod lceil frac{n}{2} ceil),所以采用递归求解。递归边界就是只有一项,那么(B(0) = A(0) ^ {mod - 2})
    其中的乘法用NTT解决。
    时间复杂度:(T(n) = T(frac{n}{2}) + O(nlogn) = O(nlogn))……不会证。

    #include<cstdio>
    #include<iostream>
    #include<cmath>
    #include<algorithm>
    #include<cstring>
    #include<cstdlib>
    #include<cctype>
    #include<vector>
    #include<stack>
    #include<queue>
    using namespace std;
    #define enter puts("") 
    #define space putchar(' ')
    #define Mem(a, x) memset(a, x, sizeof(a))
    #define In inline
    typedef long long ll;
    typedef double db;
    const int INF = 0x3f3f3f3f;
    const db eps = 1e-8;
    const ll mod = 998244353;
    const ll G = 3;
    const int maxn = 2e6 + 5;
    inline ll read()
    {
      ll ans = 0;
      char ch = getchar(), last = ' ';
      while(!isdigit(ch)) last = ch, ch = getchar();
      while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
      if(last == '-') ans = -ans;
      return ans;
    }
    inline void write(ll x)
    {
      if(x < 0) x = -x, putchar('-');
      if(x >= 10) write(x / 10);
      putchar(x % 10 + '0');
    }
    
    int n;
    int rev[maxn];
    ll a[maxn], b[maxn], c[maxn];
    
    In ll quickpow(ll a, ll b)
    {
      ll ret = 1;
      for(; b; b >>= 1, a = a * a % mod)
        if(b & 1) ret = ret * a % mod;
      return ret;
    }
    
    In void ntt(ll* a, int len, int flg)
    {
      for(int i = 0; i < len; ++i) if(i < rev[i]) swap(a[i], a[rev[i]]);
      for(int i = 1; i < len; i <<= 1) 
        {
          ll gn = quickpow(G, (mod - 1) / (i << 1));
          for(int j = 0; j < len; j += (i << 1))
    	{
    	  ll tp1, tp2, g = 1;
    	  for(int k = 0; k < i; ++k, g = g * gn % mod)
    	    {
    	      tp1 = a[j + k], tp2 = g * a[j + k + i] % mod;
    	      a[j + k] = (tp1 + tp2) % mod, a[j + k + i] = (tp1 - tp2 + mod) % mod;
    	    }
    	}
        }
      if(flg == 1) return;
      int inv = quickpow(len, mod - 2); reverse(a + 1, a + len);
      for(int i = 0; i < len; ++i) a[i] = a[i] * inv % mod;
    }
    In void solve(int deg, ll* a, ll* b)	//最后的答案储存在B中,而不是C,这里的C只是充当了一个临时数组 
    {
      if(deg == 1) {b[0] = quickpow(a[0], mod - 2); return;}
      solve((deg + 1) >> 1, a, b);
      int lim = 0, len = 1;
      while(len < (deg << 1)) len <<= 1, ++lim;
      for(int i = 1; i < len; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lim - 1));
      for(int i = 0; i < deg; ++i) c[i] = a[i];
      for(int i = deg; i < len; ++i) c[i] = 0;
      ntt(c, len, 1), ntt(b, len, 1);
      for(int i = 0; i < len; ++i) b[i] = (2 * 1LL - c[i] * b[i] % mod + mod) % mod * b[i] % mod;
      ntt(b, len, -1);
      for(int i = deg; i < len; ++i) b[i] = 0;
    }
    
    int main()
    {
      n = read();
      for(int i = 0; i < n; ++i) a[i] = read();
      solve(n, a, b);
      for(int i = 0; i < n; ++i) write(b[i]), space; enter;
      return 0;
    }
    

    2.多项式开根

    对于(A(x)),找一个多项式(B(x))满足(B ^ 2 (x) equiv A(x) (mod n))
    做法和上面很像。
    首先有((B ^ 2 - A) ^ 2 equiv 0 (mod n))
    然后用初中知识得:((B ^ 2 + A) ^ 2 equiv 4AB ^ 2 (mod n))
    移项得(A = (frac{B ^ 2 + A}{2B}))
    (A)换成(C ^ 2)就完事了:(C = frac{B ^ 2 + A}{2B})
    这个除法就用刚学的多项式求逆就好啦。
    复杂度还是(O(nlogn))的,因为上面说了,每一层求逆元是(O(nlogn))的。
    “递归套递归,复杂度不变”(带劲)

    #include<cstdio>
    #include<iostream>
    #include<cmath>
    #include<algorithm>
    #include<cstring>
    #include<cstdlib>
    #include<cctype>
    #include<vector>
    #include<stack>
    #include<queue>
    using namespace std;
    #define enter puts("") 
    #define space putchar(' ')
    #define Mem(a, x) memset(a, x, sizeof(a))
    #define In inline
    typedef long long ll;
    typedef double db;
    const int INF = 0x3f3f3f3f;
    const db eps = 1e-8;
    const int maxn = 2e6 + 5;
    const ll mod = 998244353;
    const ll inv2 = 499122177;
    const ll G = 3;
    inline ll read()
    {
      ll ans = 0;
      char ch = getchar(), last = ' ';
      while(!isdigit(ch)) last = ch, ch = getchar();
      while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
      if(last == '-') ans = -ans;
      return ans;
    }
    inline void write(ll x)
    {
      if(x < 0) x = -x, putchar('-');
      if(x >= 10) write(x / 10);
      putchar(x % 10 + '0');
    }
    
    int n, rev[maxn];
    ll a[maxn], b[maxn], c[maxn], tp[maxn];
    
    In ll quickpow(ll a, ll b)
    {
      ll ret = 1;
      for(; b; b >>= 1, a = a * a % mod)
        if(b & 1) ret = ret * a % mod;
      return ret;
    }
    
    In void ntt(ll* a, int len, bool flg)
    {
      for(int i = 0; i < len; ++i) if(i < rev[i]) swap(a[i], a[rev[i]]);
      for(int i = 1; i < len; i <<= 1)
        {
          ll ng = quickpow(G, (mod - 1) / (i << 1));
          for(int j = 0; j < len; j += (i << 1))
    	{
    	  ll g = 1;
    	  for(int k = 0; k < i; ++k, g = g * ng % mod)
    	    {
    	      ll tp1 = a[k + j], tp2 = a[k + j + i] * g % mod;
    	      a[k + j] = (tp1 + tp2) % mod; a[k + j + i] = (tp1 - tp2 + mod) % mod;
    	    }
    	}
        }
      if(flg) return;
      ll inv = quickpow(len, mod - 2); reverse(a + 1, a + len);
      for(int i = 0; i < len; ++i) a[i] = a[i] * inv % mod;
    }
    
    In void sol_inv(ll* a, ll* b, int deg)
    {
      if(deg == 1) {b[0] = quickpow(a[0], mod - 2); return;}
      sol_inv(a, b, (deg + 1) >> 1);
      int len = 1, lim = 0;
      while(len < (deg << 1)) len <<= 1, ++lim;
      for(int i = 0; i < deg; ++i) tp[i] = a[i];
      for(int i = deg; i < len; ++i) tp[i] = 0;
      for(int i = 0; i < len; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lim - 1));
      ntt(tp, len, 1); ntt(b, len, 1);
      for(int i = 0; i < len; ++i) b[i] = b[i] * (2 * 1LL - tp[i] * b[i] % mod + mod) % mod;
      ntt(b, len, 0);
      for(int i = deg; i < len; ++i) b[i] = 0;
    }
    
    In void sol_sqrt(ll* a, ll* b, int deg)
    {
      if(deg == 1) {b[0] = 1; return;}
      sol_sqrt(a, b, (deg + 1) >> 1);
      for(int i = 0; i < (deg << 1); ++i) c[i] = 0;
      sol_inv(b, c, deg);  
      int len = 1, lim = 0;
      while(len < (deg << 1)) len <<= 1, ++lim;
      for(int i = 0; i < len; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lim - 1));
      ntt(b, len, 1);
      for(int i = 0; i < len; ++i) b[i] = b[i] * b[i] % mod;
      ntt(b, len, 0);
      for(int i = 0; i < deg; ++i) b[i] = (b[i] + a[i]) % mod;	//循环不是到len! 
      ntt(b, len, 1); ntt(c, len, 1);
      for(int i = 0; i < len; ++i) b[i] = b[i] * c[i] % mod * inv2 % mod;
      ntt(b, len, 0);
      for(int i = deg; i < len; ++i) b[i] = 0;
    }
    
    int main()
    {
      n = read();
      for(int i = 0; i < n; ++i) a[i] = read();
      sol_sqrt(a, b, n);
      for(int i = 0; i < n; ++i) write(b[i]), space; enter;
      return 0;
    }
    

    3.多项式求导
    (原谅我不会求导,现学的)
    首先得知道幂函数求导:((x ^ a)' = ax ^ {a - 1})
    所以每一项的导数就是((ax ^ b)' = abx ^ {b - 1})
    每一项的导数相加就是多项式的导数(简单不)。

    In void get_der(ll* a, ll* b, int n)	//导数英文derivation,特意查的 
    {
    	for(int i = 1; i < n; ++i) b[i - 1] = a[i] * i % mod;
    	b[n - 1] = 0;
    }
    

    4.多项式积分
    (积分就是谁的导数是我——送给像我一样不会积分的人)
    知道导数怎么求了,我们就可以反推积分(仔细想想):(int ax ^ b = frac{a}{b + 1} x ^ {b + 1})
    同理每一项加起来就是多项式积分

    In void get_int(ll* a, ll* b, int n)    //积分叫integral啦
    {
    	for(int i = 1; i < n; ++i) b[i] = a[i - 1] * inv[i] % mod;
    	b[0] = 0;
    }
    

    5.多项式取对数
    直接上公式:(In A = int frac{A'}{A})
    也就是先求导再乘以逆元最后积分一下。

    ll tp1[maxn], tp2[maxn];
    In void get_In(ll* a, ll* b, int n)
    {
    	get_der(a, tp1, n); sol_inv(n, a, tp2);
    	int len = 1, lim = 0;
    	while(len < (n << 1)) len <<= 1, ++lim;
    	for(int i = 0; i < len; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lim - 1));
    	ntt(tp1, len, 1); ntt(tp2, len, 1);
    	for(int i = 0; i < len; ++i) tp1[i] = tp1[i] * tp2[i] % mod;
    	ntt(tp1, len, 0);
    	get_int(tp1, b, n);
    }
    
  • 相关阅读:
    Ruby 2
    Ruby 1
    莱布尼兹:与牛顿争吵了一生的斗士 微积分的奠基人之一―莱布尼茨
    如何实现html页面自动刷新
    css z-index的层级关系
    让网页变灰的实现_网站蒙灰CSS样式总汇
    利用CSS变量实现炫酷的悬浮效果
    离线电商数仓(十四)之系统业务数据仓库数据采集(一)电商业务简介
    离线电商数仓(十三)之用户行为数据采集(十三)采集通道启动/停止脚本
    离线电商数仓(十)之用户行为数据采集(十)组件安装(六)采集日志Flume(二)消费Kafka数据Flume
  • 原文地址:https://www.cnblogs.com/mrclr/p/10367006.html
Copyright © 2011-2022 走看看