P4238 【模板】多项式求逆
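分析：牛顿迭代（倍增）。设 $B(x)$ 满足 $A(x)B(x) \equiv 1 \pmod{x^m}$，并且已经求得 $B_0(x)$ 满足 $A(x)B_0(x) \equiv 1 \pmod{x^{m/2}}$。两者在模 $x^{m/2}$ 下都是 $A$ 的逆，故 $B - B_0 \equiv 0 \pmod{x^{m/2}}$。两边平方得 $B^2 - 2BB_0 + B_0^2 \equiv 0 \pmod{x^m}$，再两边乘以 $A$ 并利用 $AB \equiv 1$，得到
$$B(x) \equiv 2B_0(x) - A(x)B_0(x)^2 \pmod{x^m}.$$
边界为 $B(x) \equiv a_0^{-1} \pmod{x}$（要求 $a_0 \not\equiv 0$）。每一步中，$A \bmod x^m$ 的次数 $< m$，$B_0$ 的次数 $< m/2$，故 $AB_0^2$ 的次数 $< 2m$，用长度为 $2m$ 的 NTT 即可精确计算。总复杂度为 $T(n) = T(n/2) + O(n\log n) = O(n\log n)$。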
代码（朴素实现，约 700ms）：
// Luogu P4238 "Polynomial inverse" -- straightforward version.
//
// Reads n and coefficients a_0..a_{n-1}, then prints b_0..b_{n-1} (mod P)
// such that A(x) * B(x) == 1 (mod x^n). Uses Newton iteration
//     B <- 2B - A*B^2   (mod x^m),   m = 2, 4, 8, ...
// starting from B = a_0^{-1}; each step evaluates A*B*B with length-2m NTTs.
// If a_0 == 0 (mod P) no inverse exists and this program prints all zeros.
#include <algorithm>
#include <cctype>
#include <cstdio>

const int N = 2100000;     // capacity of every array; must be >= the largest NTT length (2 * len)
const int P = 998244353;   // NTT-friendly prime: 119 * 2^23 + 1
const int G = 3;           // primitive root modulo P
const int Gi = 332748118;  // G^{-1} modulo P

int A[N], B[N], TA[N], TB[N];

// Reads one (optionally negative) decimal integer from stdin.
// Returns 0 if the input ends before any digit is seen (instead of spinning forever).
inline int read() {
    int x = 0, f = 1;
    int ch = getchar();
    for (; ch != EOF && !std::isdigit(ch); ch = getchar())
        if (ch == '-') f = -1;
    for (; std::isdigit(ch); ch = getchar()) x = x * 10 + (ch - '0');
    return x * f;
}

// Returns a^b mod P by binary exponentiation (expects a in [0, P), b >= 0).
int ksm(int a, int b) {
    int ans = 1;
    for (; b > 0; b >>= 1) {
        if (b & 1) ans = static_cast<int>(1LL * ans * a % P);
        a = static_cast<int>(1LL * a * a % P);
    }
    return ans;
}

// In-place number-theoretic transform of length n (a power of two) over Z_P.
// ty == 1: forward transform; otherwise: inverse transform, with the 1/n
// scaling applied only when ty == -1.
void NTT(int *a, int n, int ty) {
    // Bit-reversal permutation: j is kept equal to the bit-reversal of i by
    // performing a "reversed increment" (carry propagates from the top bit down).
    for (int i = 0, j = 0; i < n; ++i) {
        if (i < j) std::swap(a[i], a[j]);
        for (int k = n >> 1; (j ^= k) < k; k >>= 1) {
        }
    }
    // Iterative butterflies; m is the length of the sub-transforms being merged.
    for (int m = 2; m <= n; m <<= 1) {
        const int half = m >> 1;
        const int w1 = ksm(ty == 1 ? G : Gi, (P - 1) / m);  // m-th root of unity (or its inverse)
        for (int i = 0; i < n; i += m) {
            int w = 1;
            for (int k = 0; k < half; ++k) {
                const int u = a[i + k];
                const int t = static_cast<int>(1LL * w * a[i + k + half] % P);
                a[i + k] = u + t >= P ? u + t - P : u + t;
                a[i + k + half] = u - t < 0 ? u - t + P : u - t;
                w = static_cast<int>(1LL * w * w1 % P);
            }
        }
    }
    if (ty == -1) {
        const int inv = ksm(n, P - 2);
        for (int i = 0; i < n; ++i) a[i] = static_cast<int>(1LL * a[i] * inv % P);
    }
}

int main() {
    const int n = read();

    // Smallest power of two >= n: the inverse modulo x^len already determines
    // every coefficient we print. The bound on len keeps the shift from overflowing.
    int len = 1;
    while (len < n && len <= N) len <<= 1;
    if ((len << 1) > N) {  // the last Newton step uses an NTT of length 2 * len
        std::fprintf(stderr, "n = %d exceeds capacity\n", n);
        return 1;
    }

    for (int i = 0; i < n; ++i) A[i] = (read() % P + P) % P;  // normalize into [0, P)

    B[0] = ksm(A[0], P - 2);  // a_0^{-1} by Fermat's little theorem
    for (int m = 2; m <= len; m <<= 1) {
        // B holds the inverse modulo x^{m/2} (B[m/2..m) are still zero); lift it to x^m.
        const int sz = m << 1;  // A*B*B has degree < 2m, so a length-2m transform is exact
        std::copy(A, A + m, TA);
        std::copy(B, B + m, TB);
        std::fill(TA + m, TA + sz, 0);
        std::fill(TB + m, TB + sz, 0);
        NTT(TA, sz, 1);
        NTT(TB, sz, 1);
        for (int i = 0; i < sz; ++i)
            TA[i] = static_cast<int>(1LL * TA[i] * TB[i] % P * TB[i] % P);  // A * B * B
        NTT(TA, sz, -1);
        for (int i = 0; i < m; ++i)
            B[i] = static_cast<int>((2LL * B[i] - TA[i] + P) % P);  // 2B - A*B^2
    }
    for (int i = 0; i < n; ++i) std::printf("%d ", B[i]);
    return 0;
}
常数优化版代码（快读快写、预处理单位根、递推位逆序表；感觉已经优化到几乎不能再优化了），约 520ms：
// Luogu P4238 "Polynomial inverse" -- constant-optimized version.
//
// Same Newton iteration as the straightforward version
//     B <- 2B - A*B^2   (mod x^m),   m = 2, 4, 8, ...
// plus: buffered fread/fwrite I/O, roots of unity precomputed once per
// transform length, and a bit-reversal table built by the recurrence
//     rev[i] = (rev[i/2] >> 1) | ((i & 1) << (L - 1)).
// The former macros (register, getchar redefinition, add/dec, constants) are
// replaced by ordinary constants and functions; `register` is ill-formed in C++17.
#include <algorithm>
#include <cctype>
#include <cstdio>

const int G = 3;            // primitive root modulo P
const int Gi = 332748118;   // G^{-1} modulo P
const int N = 270000;       // array capacity; must exceed 2 * len (largest NTT length and KSMG index)
const int P = 998244353;    // NTT-friendly prime: 119 * 2^23 + 1

int A[N], B[N], TA[N], TB[N], rev[N], KSMG[N], KSMGI[N];

// ---- Buffered input -------------------------------------------------------
char buf[100000], *p1 = buf, *p2 = buf;

// Returns the next input byte as an unsigned char value (safe for isdigit), or EOF.
inline int gc() {
    if (p1 == p2) {
        p2 = (p1 = buf) + std::fread(buf, 1, sizeof buf, stdin);
        if (p1 == p2) return EOF;
    }
    return static_cast<unsigned char>(*p1++);
}

// Reads one (optionally negative) decimal integer; returns 0 at end of input.
inline int read() {
    int x = 0, f = 1;
    int ch = gc();
    for (; ch != EOF && !std::isdigit(ch); ch = gc())
        if (ch == '-') f = -1;
    for (; std::isdigit(ch); ch = gc()) x = x * 10 + (ch - '0');
    return x * f;
}

// ---- Buffered output ------------------------------------------------------
char obuf[1 << 24], *O = obuf;

// Appends the decimal digits of x (expects x >= 0) to obuf.
void print(int x) {
    if (x > 9) print(x / 10);
    *O++ = static_cast<char>(x % 10 + '0');
}

// ---- Modular arithmetic ---------------------------------------------------
// Both operands are expected in [0, P); 2P - 2 still fits in int.
inline int addMod(int a, int b) { return a + b >= P ? a + b - P : a + b; }
inline int subMod(int a, int b) { return a - b < 0 ? a - b + P : a - b; }

// Returns a^b mod P by binary exponentiation (expects a in [0, P), b >= 0).
inline int ksm(int a, int b) {
    int ans = 1;
    for (; b > 0; b >>= 1) {
        if (b & 1) ans = static_cast<int>(1LL * ans * a % P);
        a = static_cast<int>(1LL * a * a % P);
    }
    return ans;
}

// In-place NTT of length n == 2^L over Z_P.
// ty == 1: forward (uses KSMG); otherwise inverse (uses KSMGI), with the 1/n
// scaling applied only when ty == -1. KSMG/KSMGI must be filled for every
// power of two up to n.
void NTT(int *a, int n, int ty, int L) {
    rev[0] = 0;
    for (int i = 1; i < n; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (L - 1));
    for (int i = 1; i < n; ++i)
        if (i < rev[i]) std::swap(a[i], a[rev[i]]);
    for (int m = 2; m <= n; m <<= 1) {
        const int half = m >> 1;
        const int w1 = ty == 1 ? KSMG[m] : KSMGI[m];
        for (int i = 0; i < n; i += m) {
            int w = 1;
            for (int k = 0; k < half; ++k) {
                const int u = a[i + k];
                const int t = static_cast<int>(1LL * w * a[i + k + half] % P);
                a[i + k] = addMod(u, t);
                a[i + k + half] = subMod(u, t);
                w = static_cast<int>(1LL * w * w1 % P);
            }
        }
    }
    if (ty == -1) {
        const int inv = ksm(n, P - 2);
        for (int i = 0; i < n; ++i) a[i] = static_cast<int>(1LL * a[i] * inv % P);
    }
}

int main() {
    const int n = read();

    // Smallest power of two >= n (the bound on len keeps the shift from overflowing).
    int len = 1;
    while (len < n && len < N) len <<= 1;
    if ((len << 1) >= N) {  // KSMG[2 * len] and the length-2*len NTT must fit
        std::fprintf(stderr, "n = %d exceeds capacity\n", n);
        return 1;
    }

    for (int i = 0; i < n; ++i) A[i] = (read() % P + P) % P;  // normalize into [0, P)

    // Root of unity of order i (and its inverse) for every power of two i <= 2 * len.
    const int maxLen = len << 1;
    for (int i = 1; i <= maxLen; i <<= 1) {
        KSMG[i] = ksm(G, (P - 1) / i);
        KSMGI[i] = ksm(Gi, (P - 1) / i);
    }

    B[0] = ksm(A[0], P - 2);  // a_0^{-1} by Fermat's little theorem
    // Each step lifts the inverse from modulo x^{m/2} to modulo x^m; lg == log2(2m).
    for (int m = 2, lg = 2; m <= len; m <<= 1, ++lg) {
        const int sz = m << 1;  // A*B*B has degree < 2m, so a length-2m transform is exact
        std::copy(A, A + m, TA);
        std::copy(B, B + m, TB);
        std::fill(TA + m, TA + sz, 0);
        std::fill(TB + m, TB + sz, 0);
        NTT(TA, sz, 1, lg);
        NTT(TB, sz, 1, lg);
        for (int i = 0; i < sz; ++i)
            TA[i] = static_cast<int>(1LL * TA[i] * TB[i] % P * TB[i] % P);  // A * B * B
        NTT(TA, sz, -1, lg);
        for (int i = 0; i < m; ++i)
            B[i] = static_cast<int>((2LL * B[i] - TA[i] + P) % P);  // 2B - A*B^2
    }

    for (int i = 0; i < n; ++i) {
        print(B[i]);
        *O++ = ' ';
    }
    std::fwrite(obuf, 1, static_cast<std::size_t>(O - obuf), stdout);
    return 0;
}