前置知识
-
\[\begin{aligned} (\ln x)' &= \frac{1}{x} \\ (\exp x)' &= \exp x \end{aligned}\]
-
复合函数的求导(链式法则)
\[(g \circ f)'(x) = g'(f(x)) f'(x)\] -
多项式求逆, 分治FFT.
多项式 ln
将 \(\ln f(x)\) 先求导, 再积分.
\[\begin{aligned}
\frac{\mathrm{d} \ln f(x)}{\mathrm{d} x} &\equiv \frac{f'(x)}{f(x)} \pmod{x^n} \\
\ln f(x) &\equiv \int \mathrm{d} \ln f(x) \equiv \int \frac{f'(x)}{f(x)} \mathrm{d} x \pmod{x^n}
\end{aligned}
\]
多项式求导, 积分都是 \(O(n)\) 的, 多项式求逆与多项式乘法均为 \(O(n \log n)\), 所以总复杂度为 \(O(n \log n)\).
多项式 exp
普通方法
和求 ln 一样, 也是求导再积分.
\[\begin{aligned}
(\exp f(x))' &\equiv f'(x) \exp f(x) \pmod{x^n} \\
\exp f(x) &\equiv \int f'(x) \exp f(x) \mathrm{d}x \pmod{x^n}
\end{aligned}
\]
可以用分治 FFT 解决, 时间复杂度为 \(O(n \log^2 n)\).
牛顿迭代
[[学习笔记] 牛顿迭代](https://www.cnblogs.com/BruceW/p/14079514.html)
时间复杂度为 \(O(n \log n)\),但由于每一轮迭代都需要求一次 \(\ln\),常数较大,所以实际上快不了多少(至少在洛谷的模板上跑得差不多)。
代码
\(\ln\)
#include <cstdio>
#include <iostream>
using namespace std;
// 64-bit integer used for intermediate products before reducing modulo `mod`.
typedef long long ll;
// Capacity of every coefficient and scratch array (2^18 plus a little slack).
const int _ = (1 << 18) + 7;
// Modulus for all coefficient arithmetic.
const int mod = 998244353;
// Base from which POLY::Init() derives the roots of unity used by the NTT.
const int rt = 3;
// n: number of coefficients read from input; f: the polynomial, replaced in place by its logarithm.
int n, f[_];
// Binary exponentiation: returns a^p reduced modulo `mod`.
// p is expected to be non-negative.
int Pw(int a, int p) {
    int acc = 1;
    for (; p; p >>= 1) {
        // Fold the current power of a into the result when the low bit of p is set.
        if (p & 1) acc = (ll)acc * a % mod;
        a = (ll)a * a % mod;
    }
    return acc;
}
// NTT-based polynomial routines over Z/mod used to compute ln f(x) (mod x^n).
// Every routine works on arrays of length `tot` and shares the scratch rows in
// `tmp`, so Init() must be called once `n` is known and before anything else.
namespace POLY {
    // tot  : transform length, the smallest power of two strictly greater than 2n
    // num  : bit-reversal permutation, rebuilt by every NTT call
    // inv  : inv[i] = modular inverse of i for 1 <= i <= tot
    // pwrt : pwrt[0][len] / pwrt[1][len] = forward / inverse root of unity of order len
    // tmp  : scratch rows (row 1 is used by Inv, rows 2-3 by Mul, row 4 by Ln)
    int tot, num[_], inv[_], pwrt[2][_], tmp[6][_];

    // Picks the transform length from the global `n` and fills `inv` and `pwrt`.
    void Init() {
        tot = 1;
        while (tot <= n + n) tot <<= 1;
        inv[1] = 1;
        for (int i = 2; i <= tot; ++i) inv[i] = (ll)inv[mod % i] * (mod - mod / i) % mod;
        pwrt[0][tot] = Pw(rt, (mod - 1) / tot);
        pwrt[1][tot] = Pw(pwrt[0][tot], mod - 2);
        // Squaring a root of order 2*len gives a root of order len.
        for (int len = (tot >> 1); len; len >>= 1) {
            pwrt[0][len] = (ll)pwrt[0][len << 1] * pwrt[0][len << 1] % mod;
            pwrt[1][len] = (ll)pwrt[1][len << 1] * pwrt[1][len << 1] % mod;
        }
    }

    // In-place number-theoretic transform of f[0..t-1]; t must be a power of two <= tot.
    // ty == false: forward transform. ty == true: inverse transform, including the 1/t scaling.
    void NTT(int *f, int t, bool ty) {
        // Reorder the input into bit-reversed order for the iterative butterflies.
        for (int i = 1; i < t; ++i) {
            num[i] = (num[i >> 1] >> 1) | ((i & 1) ? t >> 1 : 0);
            if (i < num[i]) swap(f[i], f[num[i]]);
        }
        for (int len = 2; len <= t; len <<= 1) {
            int gap = len >> 1, w1 = pwrt[ty][len];
            for (int i = 0; i < t; i += len) {
                int w = 1;
                for (int j = i; j < i + gap; ++j) {
                    int v = (ll)w * f[j + gap] % mod;
                    f[j + gap] = (f[j] - v + mod) % mod;
                    f[j] = (f[j] + v) % mod;
                    w = (ll)w * w1 % mod;
                }
            }
        }
        if (ty) for (int i = 0; i < t; ++i) f[i] = (ll)f[i] * inv[t] % mod;
    }

    // h = f * g as a cyclic convolution of length tot.
    // f and g are copied to scratch first, so h may alias either input.
    void Mul(int *f, int *g, int *h) {
        for (int i = 0; i < tot; ++i) tmp[2][i] = f[i], tmp[3][i] = g[i];
        NTT(tmp[2], tot, 0), NTT(tmp[3], tot, 0);
        for (int i = 0; i < tot; ++i) h[i] = (ll)tmp[2][i] * tmp[3][i] % mod;
        NTT(h, tot, 1);
    }

    // Writes into h the multiplicative inverse of f modulo x^(tot/2), using the
    // Newton step h <- h * (2 - f * h). After the pass for a given `len`, h is
    // correct modulo x^len and the entries from len upward are zero.
    // h must not alias f; f[0] must be invertible modulo `mod`.
    void Inv(int *f, int *h) {
        for (int i = 0; i < tot; ++i) h[i] = tmp[1][i] = 0;
        h[0] = Pw(f[0], mod - 2), tmp[1][0] = f[0], tmp[1][1] = f[1];
        for (int len = 2, t = 4; len < tot; len <<= 1, t = (len << 1)) {
            NTT(h, t, 0), NTT(tmp[1], t, 0);
            for (int i = 0; i < t; ++i) h[i] = (ll)h[i] * (2 - (ll)h[i] * tmp[1][i] % mod + mod) % mod;
            NTT(h, t, 1), NTT(tmp[1], t, 1);
            // Extend the working copy of f to the next precision and drop the
            // part of h that is beyond the precision reached so far.
            for (int i = len; i < t; ++i) tmp[1][i] = f[i], h[i] = 0;
        }
    }

    // h = f' (formal derivative) for indices 0..tot-2; h[tot-1] is left untouched.
    // Safe in place because index i is written only after index i+1 was read.
    void Deriv(int *f, int *h) {
        for (int i = 0; i < tot - 1; ++i) h[i] = (ll)f[i + 1] * (i + 1) % mod;
    }

    // h = formal antiderivative of f with zero constant term.
    // Uses the inverse table built by Init() instead of one modular
    // exponentiation per coefficient. Safe in place (iterates downwards).
    void Integ(int *f, int *h) {
        for (int i = tot - 1; i > 0; --i) h[i] = (ll)f[i - 1] * inv[i] % mod;
        h[0] = 0;
    }

    // h = integral of f' / f. The input array f is overwritten along the way;
    // h may alias f.
    // NOTE: the constant term of the result is forced to 0, which presumes
    // f[0] == 1; this is not checked.
    void Ln(int *f, int *h) {
        Inv(f, tmp[4]);     // tmp[4] = 1 / f
        Deriv(f, f);        // f      = f'
        Mul(f, tmp[4], f);  // f      = f' / f
        Integ(f, h);
    }
}
// Reads the next non-negative decimal integer from stdin, skipping any
// leading non-digit characters.
// Returns 0 if the stream ends before a digit is found.
int gi() {
    // Keep the result of getchar() in an int so EOF stays distinguishable from
    // every valid character; otherwise the skip loop never ends at end of input.
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) c = getchar();
    int x = 0;
    for (; c >= '0' && c <= '9'; c = getchar()) x = x * 10 + (c - '0');
    return x;
}
// Reads n followed by n coefficients of f, computes ln f(x) in place and
// prints its first n coefficients on one line.
int main() {
    n = gi();
    for (int i = 0; i < n; ++i) f[i] = gi();
    POLY::Init();
    POLY::Ln(f, f);
    for (int i = 0; i < n; ++i) printf("%d ", f[i]);
    putchar('\n');
    return 0;
}
\(\exp\)(普通方法)
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
// 64-bit integer used for intermediate products before reducing modulo `mod`.
typedef long long ll;
// Capacity of every coefficient and scratch array (2^18 plus a little slack).
const int _ = (1 << 18) + 7;
// mod: modulus for all coefficient arithmetic; rt: base from which POLY::Init() derives the NTT roots of unity.
const int mod = 998244353, rt = 3;
// n: number of coefficients read; g: input polynomial (replaced by its derivative in main); f: receives exp of the input.
int n, g[_], f[_];
int Pw(int a, int p) {
int res = 1;
while (p) {
if (p & 1) res = (ll)res * a % mod;
a = (ll)a * a % mod;
p >>= 1;
}
return res;
}
namespace POLY {
int tot, num[_], pwrt[2][_], inv[_], tmp[5][_];
void Init() {
tot = 1; while (tot <= n + n) tot <<= 1;
inv[1] = 1;
for (int i = 2; i <= tot; ++i) inv[i] = (ll)inv[mod % i] * (mod - mod / i) % mod;
pwrt[0][tot] = Pw(rt, (mod - 1) / tot);
pwrt[1][tot] = Pw(pwrt[0][tot], mod - 2);
for (int len = (tot >> 1); len; len >>= 1) {
pwrt[0][len] = (ll)pwrt[0][len << 1] * pwrt[0][len << 1] % mod;
pwrt[1][len] = (ll)pwrt[1][len << 1] * pwrt[1][len << 1] % mod;
}
}
void NTT(int *f, int t, bool ty) {
for (int i = 1; i < t; ++i) {
num[i] = (num[i >> 1] >> 1) | ((i & 1) ? t >> 1 : 0);
if (i < num[i]) swap(f[i], f[num[i]]);
}
for (int len = 2; len <= t; len <<= 1) {
int gap = len >> 1, w1 = pwrt[ty][len];
for (int i = 0, w = 1, tmp; i < t; i += len, w = 1)
for (int j = i; j < i + gap; ++j) {
tmp = (ll)w * f[j + gap] % mod;
f[j + gap] = (f[j] - tmp + mod) % mod;
f[j] = (f[j] + tmp) % mod;
w = (ll)w * w1 % mod;
}
}
if (ty) for (int i = 0; i < t; ++i) f[i] = (ll)f[i] * inv[t] % mod;
}
void Mul(int *f, int *g, int *h, int t) {
memcpy(tmp[1], f, t << 2);
memcpy(tmp[2], g, t << 2);
NTT(tmp[1], t, 0), NTT(tmp[2], t, 0);
for (int i = 0; i < (t << 1); ++i) h[i] = (ll)tmp[1][i] * tmp[2][i] % mod;
NTT(h, t, 1);
}
void dcNTT(int *f, int *g, int t, int l, int r) {
if (t == 1) { f[0] = l ? (ll)f[0] * inv[l] % mod : f[0]; return; }
dcNTT(f, g, t >> 1, l, (l + r) >> 1);
memset(tmp[0] + (t >> 1), 0, t << 1);
memcpy(tmp[0], f, t << 1);
Mul(tmp[0], g, tmp[0], t);
for (int i = (t >> 1); i < t; ++i) f[i] = (f[i] + tmp[0][i - 1]) % mod;
dcNTT(f + (t >> 1), g, t >> 1, (l + r) >> 1, r);
}
void Exp(int *f, int *g) { dcNTT(f, g, tot >> 1, 1, tot >> 1); }
void Deriv(int *f, int *h) { for (int i = 0; i < tot - 1; ++i) h[i] = (ll)f[i + 1] * (i + 1) % mod; }
}
// Reads n followed by n coefficients of g, computes exp g(x) into f via the
// divide-and-conquer recurrence and prints its first n coefficients on one line.
int main() {
    scanf("%d", &n);
    for (int i = 0; i < n; ++i) scanf("%d", &g[i]);
    POLY::Init();
    POLY::Deriv(g, g);  // the recurrence consumes g', not g
    f[0] = 1;           // seed: constant term of the result
    POLY::Exp(f, g);
    for (int i = 0; i < n; ++i) printf("%d ", f[i]);
    putchar('\n');
    return 0;
}
\(\exp\)(牛顿迭代)
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
// 64-bit integer used for intermediate products before reducing modulo `mod`.
typedef long long ll;
// Capacity of every coefficient and scratch array (2^18 plus a little slack).
const int _ = (1 << 18) + 7;
// mod: modulus for all coefficient arithmetic; rt: base from which POLY::Init() derives the NTT roots of unity.
const int mod = 998244353, rt = 3;
// n: number of coefficients read; g: input polynomial; f: receives exp of the input.
int n, f[_], g[_];
int Pw(int a, int p) {
int res = 1;
while (p) {
if (p & 1) res = (ll)res * a % mod;
a = (ll)a * a % mod;
p >>= 1;
}
return res;
}
// NTT-based routines over Z/mod that compute exp of a polynomial by Newton
// iteration. Init() must be called once `n` is known and before any other
// function. Inv, Ln and Exp keep their work buffers in static storage, so
// they are not reentrant.
namespace POLY {
    // tot  : the smallest power of two strictly greater than 2n
    // num  : bit-reversal permutation, rebuilt by every NTT call
    // pwrt : pwrt[0][len] / pwrt[1][len] = forward / inverse root of unity of order len
    // inv  : inv[i] = modular inverse of i for 1 <= i <= tot
    int tot, num[_], pwrt[2][_], inv[_];

    // Picks the transform length from the global `n` and fills `inv` and `pwrt`.
    void Init() {
        tot = 1;
        while (tot <= n + n) tot <<= 1;
        inv[1] = 1;
        for (int i = 2; i <= tot; ++i) inv[i] = (ll)inv[mod % i] * (mod - mod / i) % mod;
        pwrt[0][tot] = Pw(rt, (mod - 1) / tot);
        pwrt[1][tot] = Pw(pwrt[0][tot], mod - 2);
        // Squaring a root of order 2*len gives a root of order len.
        for (int len = (tot >> 1); len; len >>= 1) {
            pwrt[0][len] = (ll)pwrt[0][len << 1] * pwrt[0][len << 1] % mod;
            pwrt[1][len] = (ll)pwrt[1][len << 1] * pwrt[1][len << 1] % mod;
        }
    }

    // Zeroes the first 2*L entries of f (twice L, to cover the padded transform length).
    void Clear(int *f, int L) { memset(f, 0, 2 * L * sizeof(int)); }

    // In-place number-theoretic transform of f[0..L-1]; L must be a power of two <= tot.
    // ty == false: forward transform. ty == true: inverse transform, including the 1/L scaling.
    void NTT(int *f, int L, bool ty) {
        // Reorder the input into bit-reversed order for the iterative butterflies.
        for (int i = 1; i < L; ++i) {
            num[i] = (num[i >> 1] >> 1) | ((i & 1) ? L >> 1 : 0);
            if (i < num[i]) swap(f[i], f[num[i]]);
        }
        for (int len = 2; len <= L; len <<= 1) {
            int gap = len >> 1, w1 = pwrt[ty][len];
            for (int i = 0; i < L; i += len) {
                int w = 1;
                for (int j = i; j < i + gap; ++j) {
                    int v = (ll)w * f[j + gap] % mod;
                    f[j + gap] = (f[j] - v + mod) % mod;
                    f[j] = (f[j] + v) % mod;
                    w = (ll)w * w1 % mod;
                }
            }
        }
        if (ty) for (int i = 0; i < L; ++i) f[i] = (ll)f[i] * inv[L] % mod;
    }

    // Copies the first L entries of f into h.
    void Cpy(int *h, int *f, int L) { memcpy(h, f, L * sizeof(int)); }

    // h = inverse of f modulo x^L, via the Newton step h <- h * (2 - f * h).
    // The first 2*L entries of h are zeroed first, so h needs room for 2*L ints
    // and must not alias f. f[0] must be invertible modulo `mod`.
    void Inv(int *h, int *f, int L) {
        // Static rather than automatic: each buffer is about 1 MiB, and the
        // Exp -> Ln -> Inv call chain would otherwise put seven of them on the stack.
        static int a[_], b[_];
        Clear(h, L), Clear(a, L), Clear(b, L);
        h[0] = Pw(f[0], mod - 2);
        for (int len = 2, t = 4; len <= L; len <<= 1, t <<= 1) {
            Cpy(a, f, len), Cpy(b, h, len), NTT(b, t, 0), NTT(a, t, 0);
            for (int i = 0; i < t; ++i) b[i] = (ll)b[i] * (2 - (ll)a[i] * b[i] % mod + mod) % mod;
            NTT(b, t, 1);
            // Only the newly determined upper half is taken over.
            for (int i = (len >> 1); i < len; ++i) h[i] = b[i];
        }
    }

    // h[0..L-2] = coefficients of f'. Entries of h from L-1 upward are not written.
    void Deriv(int *h, int *f, int L) {
        for (int i = 0; i < L - 1; ++i) h[i] = (ll)f[i + 1] * (i + 1) % mod;
    }

    // h[0..L-1] = formal antiderivative of f with zero constant term.
    // Safe in place because it iterates downwards.
    void Integ(int *h, int *f, int L) {
        for (int i = L - 1; i; --i) h[i] = (ll)f[i - 1] * inv[i] % mod;
        h[0] = 0;
    }

    // h[0..L-1] = integral of f' / f, i.e. ln f modulo x^L.
    // h[L..2L-1] is left holding high-order product terms that callers must
    // ignore. h needs room for 2*L ints and must not alias f.
    // NOTE: the constant term is forced to 0, which presumes f[0] == 1; this is not checked.
    void Ln(int *h, int *f, int L) {
        static int a[_], b[_];  // static for the same stack-size reason as in Inv
        Clear(h, L), Clear(a, L), Clear(b, L);
        Deriv(a, f, L), Inv(b, f, L);
        NTT(a, L << 1, 0), NTT(b, L << 1, 0);
        for (int i = 0; i < (L << 1); ++i) h[i] = (ll)a[i] * b[i] % mod;
        NTT(h, L << 1, 1);
        Integ(h, h, L);
    }

    // h = exp f modulo x^L, via the Newton step h <- h * (1 - ln h + f).
    // h needs room for 2*L ints and must not alias f.
    void Exp(int *h, int *f, int L) {
        static int a[_], b[_], c[_];  // static for the same stack-size reason as in Inv
        Clear(h, L), Clear(a, L), Clear(b, L), Clear(c, L);
        h[0] = 1;
        for (int len = 2; len <= L; len <<= 1) {
            Cpy(c, h, len), Ln(b, h, len), Cpy(a, f, len);
            // A transform of length len (not 2*len) is enough here: the cyclic
            // wrap-around only pollutes indices below len/2, and only the
            // upper half [len/2, len) of the product is kept.
            NTT(c, len, 0), NTT(b, len, 0), NTT(a, len, 0);
            for (int i = 0; i < len; ++i) c[i] = (ll)c[i] * (1ll - b[i] + a[i] + mod) % mod;
            NTT(c, len, 1);
            for (int i = (len >> 1); i < len; ++i) h[i] = c[i];
        }
    }
}
// Reads n followed by n coefficients of g, computes exp g(x) into f by Newton
// iteration and prints its first n coefficients on one line.
int main() {
    scanf("%d", &n);
    for (int i = 0; i < n; ++i) scanf("%d", &g[i]);
    POLY::Init();
    POLY::Exp(f, g, POLY::tot >> 1);
    for (int i = 0; i < n; ++i) printf("%d ", f[i]);
    putchar('\n');
    return 0;
}