其实就是分治 FFT
如果已知一个多项式 (F) 的各项系数满足 (F_n=C_nsum_{i=0}^{n-1}F_iG_{n-i}) 那么就可以用半在线卷积求出 (F)。
一个最朴素的算法就是通过将问题分治成两个子问题去解决。考虑在每个区间先递归到左半边,然后处理左边对右边的贡献,最后递归到右边去处理。复杂度为 (T(n)=2T(frac n2)+mathcal O(nlog n)=mathcal O(nlog^2 n))。
考虑优化:我们每次不只把区间分成两份,而是分成 (B) 份,两两之间考虑贡献。即先转成点值,然后在通过它做贡献的时候再转回系数。这样的复杂度是 (T(n)=BT(frac nB)+mathcal O(nB+nlogfrac nB)) 当 (B=log n) 时复杂度为 (mathcal O(frac {nlog^2 n}{loglog n}))。
经典应用是求 (exp),即已知 (F) 的各项系数,求 (G=exp F) 的各项系数。
考虑求导:(G'=F'G),即得 (G_n=frac 1nsum_{k=1}^{n}kF_kG_{n-k}),于是直接计算即可。
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
#include<cmath>
#include<cstdlib>
#include<cassert>
#include<ctime>
typedef long long ll;
typedef std::vector<ll> vec;
const ll mod = 998244353, gen = 3;
const int maxn = 1E+5 + 5;
namespace IObuf {
const int LEN = 1 << 18;
char ibuf[LEN + 5], *p1 = ibuf, *p2 = ibuf;
char obuf[LEN + 5], *p3 = obuf;
inline char get() {
#ifdef ONLINE_JUDGE
return p1 == p2 && (p2 = (p1 = ibuf) + fread(ibuf, 1, LEN, stdin), p1 == p2) ? EOF : *p1++;
#endif
return getchar();
}
inline ll getll(char c = get(), ll x = 0) {
while(c < '0' || c > '9') c = get();
while(c >= '0' && c <= '9') x = x * 10 + c - 48, c = get();
return x;
}
inline char* flush() { fwrite(obuf, 1, p3 - obuf, stdout); return p3 = obuf; }
inline void put(char c) {
#ifdef ONLINE_JUDGE
p3 == obuf + LEN && flush(); *p3++ = c; return;
#endif
putchar(c);
}
char s[32]; int t = 0;
inline void putll(ll x, char suf = ' ') {
if(!x) put('0');
else {
while(x) s[++t] = x % 10 + 48, x /= 10;
while(t) put(s[t--]);
} put(suf);
}
}
using IObuf::getll;
using IObuf::putll;
inline ll fsp(ll a, ll b, ll res = 1) {
for(a %= mod, b %= mod - 1; b; a = a * a % mod, b >>= 1)
b & 1 ? res = res * a % mod : 0; return res;
}
struct poly {
vec f;
// tools
inline void redeg(int d) { f.resize(d + 1); }
inline int deg() { return f.size() - 1; }
inline void print(int d) const {
for(int i = 0; i <= d; ++i)
putll(f[i]);
IObuf::put('
');
}
};
int Len = -1, rev[maxn << 2];
unsigned long long rt[maxn << 3];
inline void NTTpre(int bit) {
for(int i = Len + 1; i <= bit; ++i) {
ll stp = fsp(gen, mod - 1 >> i);
rt[1 << i] = 1;
for(int j = (1 << i) + 1; j < (1 << i + 1); ++j)
rt[j] = rt[j - 1] * stp % mod;
} Len = bit;
}
unsigned long long tmp[maxn << 2];
inline void NTT(poly &f, int bit, int op) {
if(Len < bit) NTTpre(bit);
int N = 1 << bit; f.redeg(std::max(N - 1, f.deg()));
for(int i = 0; i < N; ++i) {
rev[i] = (rev[i >> 1] >> 1 | (i & 1) << bit - 1);
tmp[i] = f.f[rev[i]] + (f.f[rev[i]] >> 31 & mod); // magic
}
for(int len = 1; len < N; len <<= 1) {
for(int i = 0; i < N; i += len << 1) {
for(int k = i, x = len << 1; k < i + len; ++k, ++x) {
unsigned long long g = tmp[k], h = tmp[k + len] * rt[x] % mod;
tmp[k] = g + h, tmp[k + len] = mod + g - h;
}
}
}
for(int i = 0; i < N; ++i) f.f[i] = tmp[i] % mod;
if(op == -1) {
reverse(f.f.begin() + 1, f.f.begin() + N);
ll invN = fsp(N, mod - 2);
for(int i = 0; i < N; ++i)
f.f[i] = f.f[i] * invN % mod;
}
}
ll Inv[maxn << 1];
const int logB = 4, B = 1 << logB;
poly F, res, G[30][B];
inline void exp(int bit, int l, int r) {
if(r - l <= 128) {
r = std::min(r, F.deg() + 1);
for(int i = l; i < r; ++i) {
if(i == 0) res.f[i] = 1;
else res.f[i] = res.f[i] % mod * Inv[i] % mod;
for(int j = i + 1; j < r; ++j)
(res.f[j] += res.f[i] * F.f[j - i]) %= mod;
} return;
}
int mid = l + r >> 1, dif = (r - l) / B;
int N = 1 << bit, L = 0;
poly w[B];
while(L < B) {
if(l + L * dif > F.deg()) break;
w[L++].redeg(dif * 2 - 1);
}
for(int i = 0; i < L; ++i) {
if(i != 0) {
for(int j = 0; j < dif * 2; ++j) w[i].f[j] %= mod;
NTT(w[i], bit - logB + 1, -1);
for(int j = 0; j < dif; ++j)
res.f[l + i * dif + j] += w[i].f[j + dif];
}
exp(bit - logB, l + i * dif, l + (i + 1) * dif);
if(i != L - 1) {
poly H; H.redeg(dif * 2 - 1);
for(int j = 0; j < dif; ++j)
H.f[j] = res.f[j + l + i * dif];
NTT(H, bit - logB + 1, 1);
for(int j = i + 1; j < L; ++j)
for(int k = 0; k < dif * 2; ++k)
w[j].f[k] += H.f[k] * G[bit][j - i - 1].f[k];
}
}
}
int n;
int main() {
n = getll() - 1; F.redeg(n);
for(int i = 0; i <= n; ++i) F.f[i] = getll() * i % mod;
int bit = 0, N = 1; while(N <= n) ++bit, N <<= 1;
res.redeg(N - 1), NTTpre(bit);
for(int b = bit; b >= logB; b -= logB) {
if((1 << b) <= 128) break;
int dif = 1 << (b - logB);
for(int i = 0; i < B - 1; ++i) {
if(dif * i > F.deg()) break;
G[b][i].redeg(dif * 2 - 1);
for(int j = 0; j < dif * 2 && j + dif * i <= F.deg(); ++j)
G[b][i].f[j] = F.f[j + dif * i];
NTT(G[b][i], b - logB + 1, 1);
}
}
Inv[1] = 1;
for(int i = 2; i <= n; ++i)
Inv[i] = Inv[mod % i] * (mod - mod / i) % mod;
exp(bit, 0, N), res.print(n), IObuf::flush();
}
一些优化:
-
区间长度比较小的时候可以暴力算贡献
-
每次都是在原多项式上取一段区间做乘法,这部分可以预处理出点值
-
每次都是长为 (n) 的和长为 (2n) 的卷积,求 ([n,2n)) 项的系数,可以循环卷积优化
做了上述优化后这个算法已经拿到洛谷多项式 exp 的模板题的最快榜了。
多项式 ln 也可以类似地做,现在也已经拿到洛谷的最快榜了。
多项式乘法逆和多项式开根虽然也可以同样地做,但是实际效果并不如单 (log) 的速度快。