zoukankan      html  css  js  c++  java
  • P4238 【模板】多项式求逆

    P4238 【模板】多项式求逆

    链接

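    分析:

      多项式求逆元：倍增 + 牛顿迭代。设已求得 $B_0(x)$ 满足 $A(x)B_0(x) \equiv 1 \pmod{x^{m/2}}$，则 $B(x) = 2B_0(x) - A(x)B_0(x)^2 \bmod x^{m}$ 满足 $A(x)B(x) \equiv 1 \pmod{x^{m}}$。每一轮用长度为 $2m$ 的 NTT 计算 $A(x)B_0(x)^2$，总复杂度 $T(n) = T(n/2) + O(n\log n) = O(n\log n)$。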

    代码:700ms

     1 #include<cstdio>
     2 #include<algorithm>
     3 #include<cstring>
     4 #include<cmath>
     5 #include<iostream>
     6 
     7 using namespace std;
     8 
     typedef long long LL;

     // NTT-friendly prime P = 119 * 2^23 + 1, its primitive root G = 3,
     // and Gi = G^{-1} mod P (3 * 332748118 = P + 1).
     const int P = 998244353;
     const int G = 3;
     const int Gi = 332748118;

     // Capacity shared by every polynomial buffer below.
     const int N = 2100000;

     // A: input polynomial, B: its inverse, TA/TB: NTT scratch used by main.
     int A[N], B[N], TA[N], TB[N];
    16 
    // Reads one (optionally negative) decimal integer from stdin.
    // Non-digit characters before the number are skipped; a '-' among them makes
    // the result negative. Returns 0 if input ends before any digit is found.
    //
    // `ch` is an int (not char) so EOF stays distinguishable from data bytes. The
    // old version spun forever at end of input, because getchar() keeps returning
    // EOF and !isdigit(EOF) is true. Explicit range checks also avoid passing
    // negative char values (bytes >= 0x80 stored in a char) to isdigit, which is
    // undefined behavior.
    inline int read() {
        int x = 0, f = 1;
        int ch = getchar();
        for (; ch != EOF && (ch < '0' || ch > '9'); ch = getchar())
            if (ch == '-') f = -1;
        for (; ch >= '0' && ch <= '9'; ch = getchar())
            x = x * 10 + (ch - '0');
        return x * f;
    }
    23 int ksm(int a,int b) {
    24     int ans = 1;
    25     while (b) {
    26         if (b & 1) ans = (1ll * ans * a) % P;
    27         a = (1ll * a * a) % P;
    28         b >>= 1;
    29     }
    30     return ans % P;
    31 }
    32 void NTT(int *a,int n,int ty) {
    33     for (int i=0,j=0; i<n; ++i) {
    34         if (i < j) swap(a[i],a[j]);
    35         for (int k=(n>>1); (j^=k)<k; k>>=1);
    36     }
    37     for (int w1,w,m=2; m<=n; m<<=1) {
    38         if (ty==1) w1 = ksm(G,(P-1)/m);
    39         else w1 = ksm(Gi,(P-1)/m);
    40         for (int i=0; i<n; i+=m) {
    41             w = 1;
    42             for (int k=0; k<(m>>1); ++k) {
    43                 int u = a[i+k],t = 1ll * w * a[i+k+(m>>1)] % P;
    44                 a[i+k] = (u + t) % P;
    45                 a[i+k+(m>>1)] = (u - t + P) % P;
    46                 w = 1ll * w * w1 % P;
    47             }
    48         }
    49     }
    50     if (ty==-1) {
    51         int inv = ksm(n,P-2);
    52         for (int i=0; i<n; ++i) a[i] = 1ll * a[i] * inv % P;
    53     }
    54 }
    55 int main() {
    56     int n = read(),len = 1;
    57     for (int i=0; i<n; ++i) A[i] = read();
    58     
    59     while (len <= n) len <<= 1;
    60     
    61     B[0] = ksm(A[0],P-2);
    62     for (int m=2; m<=len; m<<=1) {
    63         for (int i=0; i<m; ++i) TA[i] = A[i],TB[i] = B[i];
    64         NTT(TA,m<<1,1);
    65         NTT(TB,m<<1,1);
    66         for (int i=0; i<(m<<1); ++i) TA[i] = 1ll*TA[i]*TB[i]%P*TB[i]%P; // A * B * B
    67         NTT(TA,m<<1,-1);
    68         for (int i=0; i<m; ++i) B[i] = (1ll*2*B[i]%P-TA[i]+P)%P; // 多项式减法 
    69     }
    70     for (int i=0; i<n; ++i) printf("%d ",B[i]);    
    71     return 0;
    72 }
    View Code

    感觉已优化到极致的代码（预处理单位根与位逆序表，fread 快读、fwrite 快写）:520ms

     1 #include<cstdio>
     2 #include<algorithm>
     3 #include<cctype>
     4 
     #define G 3                 // primitive root of P
     #define Gi 332748118        // G^{-1} mod P
     #define N 270000            // capacity of every global array
     #define P 998244353         // NTT prime 119 * 2^23 + 1
     #define LL long long
     // `register` was removed in C++17 (clang rejects it, GCC warns), and modern
     // compilers ignore the hint anyway, so `rg` now expands to nothing.
     #define rg
     // Reduce a sum / difference of two values in [0, P) back into [0, P).
     // Arguments are parenthesized so that e.g. dec(x, y + z) expands correctly.
     #define add(a, b) ((a) + (b) >= P ? (a) + (b) - P : (a) + (b))
     #define dec(a, b) ((a) - (b) < 0 ? (a) - (b) + P : (a) - (b))
     // Buffered getchar(): refills buf via fread when [p1, p2) is empty, EOF once exhausted.
     #define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2) ? EOF :*p1++)
    14     
    15 using namespace std;
    16 
    17 int A[N],B[N],TA[N],TB[N],rev[N],KSMG[N],KSMGI[N];
    18 
    19 char ch,buf[100000],*p1 = buf,*p2 = buf;;
    // Reads one (optionally negative) decimal integer through the buffered
    // getchar() macro. Non-digit characters are skipped, and any '-' among them is
    // treated as a sign. Returns 0 if input ends before a digit is found.
    //
    // `ch` is an int so the macro's EOF result is not truncated. The previous
    // version looped forever at end of input, because getchar() keeps yielding EOF
    // and !isdigit(EOF) is true. Digits are tested with explicit ranges because
    // `*p1++` yields a plain char, and passing negative non-EOF values to isdigit
    // is undefined. (A raw 0xFF byte still reads as EOF through that macro.)
    inline int read() {
        int x = 0, f = 1;
        int ch = getchar();
        while (ch != EOF && (ch < '0' || ch > '9')) {
            if (ch == '-') f = -1;
            ch = getchar();
        }
        while (ch >= '0' && ch <= '9') {
            x = x * 10 + (ch - '0');
            ch = getchar();
        }
        return x * f;
    }
    // Output buffer flushed once by main via fwrite; O is the write cursor.
    char obuf[1<<24], *O=obuf;
    // Appends the decimal digits of x at O (no separator, no terminator).
    // Intended for x >= 0; a negative x emits a single non-digit character.
    void print(int x) {
        char digits[12];
        int cnt = 0;
        do {
            digits[cnt++] = x % 10 + '0';
            x /= 10;
        } while (x > 0);
        while (cnt > 0) *O++ = digits[--cnt];
    }
    31 inline int ksm(int a,int b) {
    32     int ans = 1;
    33     while (b) {
    34         if (b & 1) ans = (1ll * ans * a) % P;
    35         a = (1ll * a * a) % P;
    36         b >>= 1;
    37     }
    38     return ans % P;
    39 }
    40 void NTT(int *a,int n,int ty,int L) {
    41     for(rg int i=1; i<n; ++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<L-1);
    42     for(rg int i=1; i<n; ++i) if(i<rev[i]) std::swap(a[i],a[rev[i]]);
    43     for (rg int w1,w,m=2; m<=n; m<<=1) {
    44         if (ty==1) w1 = KSMG[m];else w1 = KSMGI[m];
    45         for (int i=0; i<n; i+=m) {
    46             w = 1;
    47             for (rg int k=0; k<(m>>1); ++k) {
    48                 int u = a[i+k],t = 1ll * w * a[i+k+(m>>1)] % P;
    49                 a[i+k] = add(u, t);
    50                 a[i+k+(m>>1)] = dec(u, t);
    51                 w = 1ll * w * w1 % P;
    52             }
    53         }
    54     }
    55     if (ty==-1) {
    56         int inv = ksm(n,P-2);
    57         for (rg int i=0; i<n; ++i) a[i] = 1ll * a[i] * inv % P;
    58     }
    59 }
    60 int main() {
    61     int n = read(),len = 1;
    62     for (rg int i=0; i<n; ++i) A[i] = read();
    63     
    64     while (len <= n) len <<= 1;
    65     int tmp = len << 1;
    66     for (rg int i=1; i<=tmp; i<<=1)    
    67         KSMG[i] = ksm(G,(P-1)/i),KSMGI[i] = ksm(Gi,(P-1)/i);
    68         
    69     B[0] = ksm(A[0],P-2);
    70     int t = 1;
    71     for (rg int m=2; m<=len; m<<=1) { // 求长度为m的逆元 
    72         t ++;
    73         for (rg int i=0; i<m; ++i) TA[i] = A[i],TB[i] = B[i];
    74         NTT(TA,m<<1,1,t);
    75         NTT(TB,m<<1,1,t);
    76         for (rg int i=0; i<(m<<1); ++i) TA[i] = 1ll*TA[i]*TB[i]%P*TB[i]%P; // A * B * B
    77         NTT(TA,m<<1,-1,t);
    78         for (rg int i=0; i<m; ++i) B[i] = (1ll*2*B[i]%P-TA[i]+P)%P; // 多项式减法 
    79     }
    80     for(rg int i = 0; i < n; i++) print(B[i]), *O++ = ' ';
    81     fwrite(obuf, O-obuf, 1 , stdout);
    82     return 0;
    83 }
    View Code
  • 相关阅读:
    数据库如何部署上线阅读总结
    Nginx解决防盗链,服务器宕机,跨域,防DDOS
    跨域和表单重复提交
    Socet
    Redis发布订阅
    MySQL和Oracle的区别
    Redis事务、持久化、发布订阅
    Redis主从复制和哨兵模式
    Idea中使用Redis的Java客户端和Jedis
    Redis介绍及命令
  • 原文地址:https://www.cnblogs.com/mjtcn/p/9155806.html
Copyright © 2011-2022 走看看