zoukankan      html  css  js  c++  java
  • 洛谷P4238【模板】多项式求逆

    洛谷P4238

    多项式求逆:http://blog.miskcoo.com/2015/05/polynomial-inverse

    注意:直接在点值表达下做$B(x) equiv 2B'(x) - A(x)B'^2(x) pmod {x^n}$是可以的,但是一定要注意,这一步中有一个长度为n的和两个长度为(n/2)的多项式相乘,因此要在DFT前就扩展FFT点值表达的“长度”到2n,否则会出错(调了1.5个小时)

    备份

    版本1:

     1 #prag
     2 ma GCC optimize(2)
     3 #include<cstdio>
     4 #include<algorithm>
     5 #include<cstring>
     6 #include<vector>
     7 #include<cmath>
     8 using namespace std;
     9 #define fi first
    10 #define se second
    11 #define mp make_pair
    12 #define pb push_back
    13 typedef long long ll;
    14 typedef unsigned long long ull;
    15 const int md=998244353;
    16 const int N=2097152;
    17 int rev[N];
    18 void init(int len)
    19 {
    20     int bit=0,i;
    21     while((1<<(bit+1))<=len)    ++bit;
    22     for(i=0;i<len;++i)
    23         rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
    24 }
    25 ll poww(ll a,ll b)
    26 {
    27     ll base=a,ans=1;
    28     for(;b;b>>=1,base=base*base%md)
    29         if(b&1)
    30             ans=ans*base%md;
    31     return ans;
    32 }
    33 void dft(int *a,int len,int idx)//要求len为2的幂
    34 {
    35     int i,j,k,t1,t2;ll wn,wnk;
    36     for(i=0;i<len;++i)
    37         if(i<rev[i])
    38             swap(a[i],a[rev[i]]);
    39     for(i=1;i<len;i<<=1)
    40     {
    41         wn=poww(idx==1?3:332748118,(md-1)/(i<<1));
    42         for(j=0;j<len;j+=(i<<1))
    43         {
    44             wnk=1;
    45             for(k=j;k<j+i;++k,wnk=wnk*wn%md)
    46             {
    47                 t1=a[k];t2=a[k+i]*wnk%md;
    48                 a[k]+=t2;
    49                 (a[k]>=md) && (a[k]-=md);
    50                 a[k+i]=t1-t2;
    51                 (a[k+i]<0) && (a[k+i]+=md);
    52             }
    53         }
    54     }
    55     if(idx==-1)
    56     {
    57         ll ilen=poww(len,md-2);
    58         for(i=0;i<len;++i)
    59             a[i]=a[i]*ilen%md;
    60     }
    61 }
    62 int f[N],g[N],t1[N];
    63 int n,n1;
    64 void p_inv(int *f,int *g,int len)//g=f^(-1);f,g数组的长度不小于2^(ceil(log2(len))+1)(需要足够长用于临时存放元素) 
    65 {
    66     g[0]=poww(f[0],md-2);
    67     for(int i=2,j;i<(len<<1);i<<=1)
    68     {
    69         init(i<<1);
    70         memcpy(t1,f,sizeof(int)*i);
    71         memset(t1+i,0,sizeof(int)*i);
    72         memset(g+(i>>1),0,sizeof(int)*(i+(i>>1)));
    73         dft(t1,i<<1,1);dft(g,i<<1,1);
    74         for(j=0;j<(i<<1);++j)
    75             g[j]=ll(g[j])*(2+ll(md-g[j])*t1[j]%md)%md;
    76         dft(g,i<<1,-1);
    77     }
    78 }
    79 int main()
    80 {
    81     int i,t;
    82     scanf("%d",&n);n1=n;
    83     for(i=0;i<n;++i)
    84         scanf("%d",g+i);
    85     for(t=1;t<n;t<<=1);
    86     n=t;
    87     p_inv(g,f,n);
    88     for(i=0;i<n1;++i)
    89         printf("%d ",f[i]);
    90     return 0;
    91 }
    View Code

    资料:https://www.luogu.org/blog/user7035/duo-xiang-shi-zong-jie

    里面有一个迷之优化(代码好像和文字表述的不一样,很玄学,看不懂,被坑了...)

    牛顿迭代得到式子:$B(x) equiv B'(x)-B'(x)(A(x)B'(x)-1) pmod {x^n}$,其中B'(x)是上一次迭代的结果,B(x)是这一次的结果,A(x)是原多项式,n是这一次迭代得到的结果长度(设它是2的幂);设上一次迭代得到的结果长度为m=n/2

    看右边的$A(x)B'(x)-1$,可以知道它第0到m-1项都是0,现在只需要求它与B'(x)的乘积的前n位,可以把它”左移“m位,这样它和B'(x)长度都只有m,因此只需要做长度为n(而不是2n)的NTT,然后再”右移”回去

    如果与B'(x)相乘时不做长度为2n的NTT而做长度为n的NTT,那么可以发现结果刚好相当于正常结果(做长度为2n的NTT的结果取前n位)将前一半和后一半交换(未验证)

    (可以直接用算A(x)B'(x)时求出的B'(x)的DFT)(当然这样NTT次数从3次变成了5次...)

    版本2:(实测的确比版本1快)(另外把longlong都改成了unsignedlonglong)

      1 #prag
      2 ma GCC optimize(2)
      3 #include<cstdio>
      4 #include<algorithm>
      5 #include<cstring>
      6 #include<vector>
      7 #include<cmath>
      8 using namespace std;
      9 #define fi first
     10 #define se second
     11 #define mp make_pair
     12 #define pb push_back
     13 typedef long long ll;
     14 typedef unsigned long long ull;
     15 const int md=998244353;
     16 const int N=262144;
     17 #define delto(a,b) ((a)-=(b),((a)<0)&&((a)+=md))
     18 int rev[N];
     19 void init(int len)
     20 {
     21     int bit=0,i;
     22     while((1<<(bit+1))<=len)    ++bit;
     23     for(i=0;i<len;++i)
     24         rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
     25 }
     26 ull poww(ull a,ull b)
     27 {
     28     ull base=a,ans=1;
     29     for(;b;b>>=1,base=base*base%md)
     30         if(b&1)
     31             ans=ans*base%md;
     32     return ans;
     33 }
     34 void dft(int *a,int len,int idx)//要求len为2的幂
     35 {
     36     int i,j,k,t1,t2;ull wn,wnk;
     37     for(i=0;i<len;++i)
     38         if(i<rev[i])
     39             swap(a[i],a[rev[i]]);
     40     for(i=1;i<len;i<<=1)
     41     {
     42         wn=poww(idx==1?3:332748118,(md-1)/(i<<1));
     43         for(j=0;j<len;j+=(i<<1))
     44         {
     45             wnk=1;
     46             for(k=j;k<j+i;++k,wnk=wnk*wn%md)
     47             {
     48                 t1=a[k];t2=a[k+i]*wnk%md;
     49                 a[k]+=t2;
     50                 (a[k]>=md) && (a[k]-=md);
     51                 a[k+i]=t1-t2;
     52                 (a[k+i]<0) && (a[k+i]+=md);
     53             }
     54         }
     55     }
     56     if(idx==-1)
     57     {
     58         ull ilen=poww(len,md-2);
     59         for(i=0;i<len;++i)
     60             a[i]=a[i]*ilen%md;
     61     }
     62 }
     63 int t1[N],t2[N];
     64 void p_inv(int *f,int *g,int len)//g=f^(-1);f,g数组的长度不小于2^(ceil(log2(len))+1)(需要足够长用于临时存放元素) ;要求len是2的幂
     65 {
     66     g[0]=poww(f[0],md-2);
     67     for(int i=2,j;i<(len<<1);i<<=1)
     68     {
     69         memcpy(t1,f,sizeof(int)*i);
     70         memcpy(t2,g,sizeof(int)*(i>>1));
     71         memset(t2+(i>>1),0,sizeof(int)*(i>>1));
     72         init(i);
     73         dft(t1,i,1);dft(t2,i,1);
     74         for(j=0;j<i;++j)
     75             t1[j]=ull(t1[j])*t2[j]%md;
     76         dft(t1,i,-1);
     77         for(j=0;j<(i>>1);++j)
     78             t1[j]=t1[j+(i>>1)];
     79         memset(t1+(i>>1),0,sizeof(int)*(i>>1));
     80         dft(t1,i,1);
     81         for(j=0;j<i;++j)
     82             t1[j]=ull(t1[j])*t2[j]%md;
     83         dft(t1,i,-1);
     84         for(j=i>>1;j<i;++j)
     85             delto(g[j],t1[j-(i>>1)]);
     86     }
     87 }
     88 int f[N],g[N];
     89 int n,n1;
     90 int main()
     91 {
     92     int i,t;
     93     scanf("%d",&n);n1=n;
     94     for(i=0;i<n;++i)
     95         scanf("%d",g+i);
     96     for(t=1;t<n;t<<=1);
     97     n=t;
     98     p_inv(g,f,n);
     99     for(i=0;i<n1;++i)
    100         printf("%d ",f[i]);
    101     return 0;
    102 }
    View Code

    版本3:基于此题版本2,改了疑似bug

      1 #prag
      2 ma GCC optimize(2)
      3 #include<cstdio>
      4 #include<algorithm>
      5 #include<cstring>
      6 #include<vector>
      7 #include<cmath>
      8 using namespace std;
      9 #define fi first
     10 #define se second
     11 #define mp make_pair
     12 #define pb push_back
     13 typedef long long ll;
     14 typedef unsigned long long ull;
     15 const int md=998244353;
     16 const int N=262144;
     17 #define delto(a,b) ((a)-=(b),((a)<0)&&((a)+=md))
     18 int rev[N];
     19 void init(int len)
     20 {
     21     int bit=0,i;
     22     while((1<<(bit+1))<=len)    ++bit;
     23     for(i=0;i<len;++i)
     24         rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
     25 }
     26 ull poww(ull a,ull b)
     27 {
     28     ull base=a,ans=1;
     29     for(;b;b>>=1,base=base*base%md)
     30         if(b&1)
     31             ans=ans*base%md;
     32     return ans;
     33 }
     34 void dft(int *a,int len,int idx)//要求len为2的幂
     35 {
     36     int i,j,k,t1,t2;ull wn,wnk;
     37     for(i=0;i<len;++i)
     38         if(i<rev[i])
     39             swap(a[i],a[rev[i]]);
     40     for(i=1;i<len;i<<=1)
     41     {
     42         wn=poww(idx==1?3:332748118,(md-1)/(i<<1));
     43         for(j=0;j<len;j+=(i<<1))
     44         {
     45             wnk=1;
     46             for(k=j;k<j+i;++k,wnk=wnk*wn%md)
     47             {
     48                 t1=a[k];t2=a[k+i]*wnk%md;
     49                 a[k]+=t2;
     50                 (a[k]>=md) && (a[k]-=md);
     51                 a[k+i]=t1-t2;
     52                 (a[k+i]<0) && (a[k+i]+=md);
     53             }
     54         }
     55     }
     56     if(idx==-1)
     57     {
     58         ull ilen=poww(len,md-2);
     59         for(i=0;i<len;++i)
     60             a[i]=a[i]*ilen%md;
     61     }
     62 }
     63 int t1[N],t2[N];
     64 void p_inv(int *f,int *g,int len)//g=f^(-1);f,g数组的长度不小于2len(需要足够长用于临时存放元素) ;要求len是2的幂
     65 {
     66     g[0]=poww(f[0],md-2);
     67     for(int i=2,j;i<(len<<1);i<<=1)
     68     {
     69         memcpy(t1,f,sizeof(int)*i);
     70         memcpy(t2,g,sizeof(int)*(i>>1));
     71         memset(t2+(i>>1),0,sizeof(int)*(i>>1));
     72         init(i);
     73         dft(t1,i,1);dft(t2,i,1);
     74         for(j=0;j<i;++j)
     75             t1[j]=ull(t1[j])*t2[j]%md;
     76         dft(t1,i,-1);
     77         for(j=0;j<(i>>1);++j)
     78             t1[j]=t1[j+(i>>1)];
     79         memset(t1+(i>>1),0,sizeof(int)*(i>>1));
     80         dft(t1,i,1);
     81         for(j=0;j<i;++j)
     82             t1[j]=ull(t1[j])*t2[j]%md;
     83         dft(t1,i,-1);
     84         for(j=i>>1;j<i;++j)
     85             g[j]=md-t1[j-(i>>1)];
     86     }
     87 }
     88 int f[N],g[N];
     89 int n,n1;
     90 int main()
     91 {
     92     int i,t;
     93     scanf("%d",&n);n1=n;
     94     for(i=0;i<n;++i)
     95         scanf("%d",g+i);
     96     for(t=1;t<n;t<<=1);
     97     n=t;
     98     p_inv(g,f,n);
     99     for(i=0;i<n1;++i)
    100         printf("%d ",f[i]);
    101     return 0;
    102 }
    View Code
  • 相关阅读:
    try catch finally
    类的小练习
    易混淆概念总结
    C#中struct和class的区别详解
    Doing Homework again
    悼念512汶川大地震遇难同胞——老人是真饿了
    Repair the Wall
    Saving HDU
    JAVA-JSP隐式对象
    JAVA-JSP动作
  • 原文地址:https://www.cnblogs.com/hehe54321/p/10353385.html
Copyright © 2011-2022 走看看