写在前面
多项式求逆
前置知识:NTT
多项式求逆
给定一个 \(n\) 项的多项式 \(F\left(x\right)\)(要求 \(F\left(0\right)\not\equiv 0\),否则逆元不存在),求一个多项式 \(G\left(x\right)\),使得 \(F\left(x\right)G\left(x\right)\equiv 1\left(\bmod x^n\right)\),所有系数均对 \(998244353\) 取模。
考虑递归求解。
假定现在已经求出了 \(G_0\left(x\right)\),满足
\[F\left(x\right)G_0\left(x\right)\equiv 1\left(\bmod x^{\lceil\frac{n}{2}\rceil}\right)\tag{1}\]
根据要求的 \(G\left(x\right)\) 的定义,显然有
\[F\left(x\right)G\left(x\right)\equiv 1\left(\bmod x^{\lceil\frac{n}{2}\rceil}\right)\tag{2}\]
\((2)-(1)\),得
\[F\left(x\right)\left(G\left(x\right)-G_0\left(x\right)\right)\equiv 0\left(\bmod x^{\lceil\frac{n}{2}\rceil}\right)\]
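由 \((1)\) 知 \(G_0\left(x\right)\) 正是 \(F\left(x\right)\) 在模 \(x^{\lceil\frac{n}{2}\rceil}\) 意义下的逆元(这正是要求 \(F\left(0\right)\not\equiv 0\) 的原因;仅有 \(F\left(x\right)\not\equiv 0\) 并不足以消去,例如 \(F\left(x\right)=x\) 就没有逆元),两边同乘 \(G_0\left(x\right)\),得
\[G\left(x\right)-G_0\left(x\right)\equiv 0\left(\bmod x^{\lceil\frac{n}{2}\rceil}\right)\]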
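两边平方,得(左边是被 \(x^{\lceil\frac{n}{2}\rceil}\) 整除的多项式的平方,因此能被 \(x^{2\lceil\frac{n}{2}\rceil}\) 整除,而 \(2\lceil\frac{n}{2}\rceil\ge n\),所以模数可以换成 \(x^n\))
\[G^2\left(x\right)-2G\left(x\right)G_0\left(x\right)+G_0^2\left(x\right)\equiv 0\left(\bmod x^n\right)\]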
两边同乘 \(F\left(x\right)\),并利用 \(F\left(x\right)G\left(x\right)\equiv 1\left(\bmod x^n\right)\),得
\[G\left(x\right)-2G_0\left(x\right)+F\left(x\right)G_0^2\left(x\right)\equiv 0\left(\bmod x^n\right)\]
移项整理,得
\[G\left(x\right)\equiv 2G_0\left(x\right)-F\left(x\right)G_0^2\left(x\right)\left(\bmod x^n\right)\]
递归处理之后自下而上递推即可。递归边界为 \(n=1\):此时 \(G\left(x\right)\) 只有常数项 \(F\left(0\right)^{-1}\),用费马小定理求逆元即可。
代码:
// Bit-reversal permutation table used by ntt(); refreshed by Setrev for each transform length.
int rev[Maxn];

// Fills rev[1..len) so that rev[i] is i with its log2(len) low bits reversed.
// len is expected to be a power of two (polyinv only passes powers of two).
// rev[0] is never written and relies on the global array being zero-initialized.
void Setrev(int len) {
    const int highBit = len >> 1;
    for (int idx = 1; idx < len; ++idx)
        rev[idx] = (rev[idx >> 1] >> 1) | ((idx & 1) ? highBit : 0);
}
// In-place iterative number-theoretic transform of p[0..len) modulo Mod.
// Setrev(len) must have been called first so rev[] matches len.
// type == 0: forward transform using root g[0]; type == 1: inverse transform
// using root g[1], followed by multiplication by len^{-1}.
void ntt(LL p[], int len, int type) {
    // Move every element to its bit-reversed slot so the butterflies can run in place.
    for (int i = 0; i < len; ++i)
        if (i < rev[i]) swap(p[i], p[rev[i]]);
    for (int span = 2; span <= len; span <<= 1) {
        const int half = span >> 1;
        // Principal span-th root of unity for this stage (inverse root when type == 1).
        const LL step = qpow(g[type], (Mod - 1) / span);
        for (int base = 0; base < len; base += span) {
            LL w = 1;
            for (int k = 0; k < half; ++k) {
                LL &lo = p[base + k];
                LL &hi = p[base + k + half];
                const LL even = lo % Mod;
                const LL odd = w * hi % Mod;
                lo = (even + odd) % Mod;
                hi = ((even - odd) % Mod + Mod) % Mod;  // keep the difference non-negative
                w = w * step % Mod;
            }
        }
    }
    if (type == 1) {
        // Inverse transform: scale by len^{-1} (Fermat inverse, Mod is prime).
        const LL lenInv = qpow(len, Mod - 2);
        for (int i = 0; i < len; ++i) p[i] = p[i] * lenInv % Mod;
    }
}
// Scratch buffer holding the NTT of A (truncated to siz terms) inside polyinv.
LL tmp[Maxn];

// Computes B(x) = A(x)^{-1} mod x^siz, all coefficients mod Mod, via the Newton step
//   G = 2*G0 - A*G0^2 (mod x^siz),  where G0 = A^{-1} mod x^ceil(siz/2).
// Preconditions:
//   - A[0] must be invertible mod Mod; otherwise qpow yields 0 and the result is meaningless.
//   - B is expected to be all zero on entry (e.g. a fresh global array); the zero padding
//     used by the transforms relies on this.
//   - tmp, B and rev must hold at least len entries, where len is the smallest power of
//     two >= 2*siz (e.g. siz <= 131072 fits in 262144).
// On return B[0..siz) holds the inverse and B[siz..len) is zero again.
void polyinv(LL A[], LL B[], int siz) {
    if(siz <= 0) return;  // nothing to invert, and halving 0 would never reach the base case
    if(siz == 1) {B[0] = qpow(A[0], Mod - 2); return;}
    // B[0..ceil(siz/2)) <- inverse of A modulo x^ceil(siz/2); higher entries stay zero.
    polyinv(A, B, (siz + 1) >> 1);
    // A*B^2 has degree <= (siz-1) + 2*(ceil(siz/2)-1) <= 2*siz-2, so any power of two
    // >= 2*siz avoids cyclic wrap-around into the coefficients we keep.
    int len = 1;
    while(len < (siz << 1)) len <<= 1;
    for(int i = 0; i < siz; ++i) tmp[i] = A[i];  // only A mod x^siz matters
    for(int i = siz; i < len; ++i) tmp[i] = 0;
    Setrev(len); ntt(tmp, len, 0); ntt(B, len, 0);
    // Pointwise Newton step B <- 2B - A*B^2, kept in [0, Mod).
    for(int i = 0; i < len; ++i) B[i] = ((2ll * B[i] % Mod - B[i] * B[i] % Mod * tmp[i] % Mod) % Mod + Mod) % Mod;
    ntt(B, len, 1);
    // Truncate to x^siz so the caller (next level up) sees a zero-padded B.
    for(int i = siz; i < len; ++i) B[i] = 0;
}
实现上的一些小细节
-
注意多项式长度:只要 NTT 的长度不小于乘积的项数(避免循环卷积把高次项回绕到低次项),长度稍微长了些并不会影响多项式求逆的结果;但长度过长会浪费时间,也可能越过数组上界。
-
最后那一步记得把 B 数组无用的元素清空。
-
虽然看上去用了多次 NTT,但递归满足 \(T\left(n\right)=T\left(\frac{n}{2}\right)+\mathcal O\left(n\log n\right)\),根据主定理(如有需要请自行搜索),复杂度仍旧是 \(\mathcal O\left(n\log n\right)\) 的。
完整代码
#include <bits/stdc++.h>
#define LL long long
using namespace std;
// Reads one signed integer from stdin into res.
// Every non-digit before the number is skipped; any '-' seen while skipping makes
// the result negative. Reading stops at the first non-digit after the number.
// If EOF is reached before any digit, res is set to 0 (instead of looping forever).
template <typename Temp> inline void read(Temp & res) {
    Temp fh = 1; res = 0;
    // Keep getchar()'s int result: narrowing it to char loses EOF (so the skip loop
    // never ends) and can hand negative values to isdigit, which is undefined behavior.
    int ch = getchar();
    for(; ch != EOF && !isdigit(ch); ch = getchar()) if(ch == '-') fh = -1;
    for(; ch != EOF && isdigit(ch); ch = getchar()) res = (res << 3) + (res << 1) + (ch ^ '0');
    res = res * fh;
}
// Capacity of every coefficient / NTT buffer; should be at least the largest
// transform length used (2^18 = 262144 fits).
const int Maxn = 262200;
// NTT-friendly prime: 998244353 = 119 * 2^23 + 1.
const LL Mod = 998244353;
// g[0] = 3 is the primitive root used for forward transforms;
// g[1] = 332748118 is its inverse (3 * 332748118 = Mod + 1), used for inverse transforms.
const LL g[2] = {3, 332748118};

// Returns A^P mod Mod by binary exponentiation; P is expected to be non-negative.
LL qpow(LL A, LL P) {
    LL acc = 1;
    for(LL base = A; P != 0; P >>= 1) {
        if(P & 1) acc = acc * base % Mod;
        base = base * base % Mod;
    }
    return acc;
}
namespace Polynomial {
// Bit-reversal permutation table for the current transform length (filled by Setrev).
int rev[Maxn];

// Fills rev[1..len) so that rev[i] is i with its log2(len) low bits reversed.
// len is expected to be a power of two (every caller below passes one).
// rev[0] is never written and stays 0 from static zero-initialization.
void Setrev(int len) {
    for(int i = 1; i < len; ++i) {
        rev[i] = rev[i >> 1] >> 1;
        if(i & 1) rev[i] |= (len >> 1);
    }
}

// In-place iterative NTT of p[0..len) modulo Mod; Setrev(len) must be called first.
// type == 0: forward transform with root g[0]; type == 1: inverse transform with
// root g[1], followed by scaling by len^{-1}.
void ntt(LL p[], int len, int type) {
    // Bit-reversal reorder so the butterflies below can work in place.
    for(int i = 0; i < len; ++i) if(i < rev[i]) swap(p[i], p[rev[i]]);
    for(int h = 2; h <= len; h <<= 1) {
        // Principal h-th root of unity for this stage (inverse root when type == 1).
        LL gn = qpow(g[type], (Mod - 1) / h);
        for(int j = 0; j < len; j += h) {
            LL gk = 1;
            for(int k = j; k < j + h / 2; ++k) {
                LL e = p[k] % Mod, o = gk * p[k + h / 2] % Mod;
                p[k] = (e + o) % Mod; p[k + h / 2] = ((e - o) % Mod + Mod) % Mod;
                gk = gk * gn % Mod;
            }
        }
    }
    if(type == 1) {
        // Inverse transform: multiply by len^{-1} (Fermat inverse, Mod is prime).
        LL invl = qpow(len, Mod - 2);
        for(int i = 0; i < len; ++i) p[i] = p[i] * invl % Mod;
    }
}

// Multiplies A by B in place: on return A[0..len) holds the cyclic convolution of
// A[0..len) and B[0..len), where len is the smallest power of two strictly greater
// than siz. B is left overwritten with its forward transform.
// Both arrays must be zero-padded up to len. For the exact (non-wrapped) product,
// siz should be at least deg(A) + deg(B) — TODO confirm the intended meaning of siz,
// there is no caller in this file.
void polymul(LL A[], LL B[], int siz) {
    int len = 1;
    while(len <= siz) len <<= 1;  // smallest power of two > siz
    Setrev(len); ntt(A, len, 0); ntt(B, len, 0);
    for(int i = 0; i < len; ++i) A[i] = A[i] * B[i] % Mod;
    ntt(A, len, 1);
}

// Scratch buffer holding the NTT of A (truncated to siz terms) inside polyinv.
LL tmp[Maxn];

// Computes B(x) = A(x)^{-1} mod x^siz, all coefficients mod Mod, via the Newton step
//   G = 2*G0 - A*G0^2 (mod x^siz),  where G0 = A^{-1} mod x^ceil(siz/2).
// Preconditions:
//   - A[0] must be invertible mod Mod; otherwise qpow yields 0 and the result is meaningless.
//   - B is expected to be all zero on entry (e.g. a fresh global array); the zero padding
//     used by the transforms relies on this.
//   - tmp, B and rev must hold at least len entries, where len is the smallest power of
//     two >= 2*siz (e.g. siz <= 131072 fits in 262144).
// On return B[0..siz) holds the inverse and B[siz..len) is zero again.
void polyinv(LL A[], LL B[], int siz) {
    if(siz <= 0) return;  // nothing to invert, and halving 0 would never reach the base case
    if(siz == 1) {B[0] = qpow(A[0], Mod - 2); return;}
    // B[0..ceil(siz/2)) <- inverse of A modulo x^ceil(siz/2); higher entries stay zero.
    polyinv(A, B, (siz + 1) >> 1);
    // A*B^2 has degree <= (siz-1) + 2*(ceil(siz/2)-1) <= 2*siz-2, so any power of two
    // >= 2*siz avoids cyclic wrap-around into the coefficients we keep.
    int len = 1;
    while(len < (siz << 1)) len <<= 1;
    for(int i = 0; i < siz; ++i) tmp[i] = A[i];  // only A mod x^siz matters
    for(int i = siz; i < len; ++i) tmp[i] = 0;
    Setrev(len); ntt(tmp, len, 0); ntt(B, len, 0);
    // Pointwise Newton step B <- 2B - A*B^2, kept in [0, Mod).
    for(int i = 0; i < len; ++i) B[i] = ((2ll * B[i] % Mod - B[i] * B[i] % Mod * tmp[i] % Mod) % Mod + Mod) % Mod;
    ntt(B, len, 1);
    // Truncate to x^siz so the caller (next level up) sees a zero-padded B.
    for(int i = siz; i < len; ++i) B[i] = 0;
}
}
int n, m;
LL a[Maxn], b[Maxn];
int main() {
read(n);
for(int i = 0; i < n; ++i) read(a[i]);
Polynomial::polyinv(a, b, n);
for(int i = 0; i < n; ++i) printf("%lld ", b[i]);
return 0;
}