生成函数神题 QAQ,orz EI
我的多项式水平是外国人水平
题目链接
简要算法
概率与期望、容斥、生成函数、拉格朗日反演、牛顿迭代
(O(m^2))
(O(m^2)) 做法有很多,如 min-max 容斥,下面介绍一种看上去比较有优化空间的做法。
对于 (Ssubseteq{1,2,dots,m}),定义 (end_S=0/1) 表示 (S) 是否存在 (k) 个连续的数。
考虑每一轮的贡献,第 (x) 轮的贡献就是前 (x) 轮操作之后不会到达终态的概率:
对于 (x) 轮之后选过的数组成的集合恰好为 (S) 的概率,考虑容斥计算:
其中 (is_i) 表示是否存在编号为 (i) 的卡。
设 (w_i(x)=x^{is_i}-1),(G(x)=sum_{end_S=0}prod_{iin S}w_i(x)),则答案为 (sum_{i=0}^{m-1}frac m{m-i}[x^i]G(x))。
由于 (is_i=0) 时 (w_i(x)=0),(is_i=1) 时 (w_i(x)=x-1),故只要先 DP (f_{i,j}) 表示前 (i) 种编号选出 (j) 个均为 (is=1) 的方案数(转移可以容斥掉最后一段长为 (k) 的方案),则 (G(x)=sum_{ige 0}f_{max,i}(x-1)^i),可以直接计算。
Solution by EI
对于 (G(x)=sum_{ige 0}f_{max,i}(x-1)^i) 的每一项,注意到 ([x^i]G(x)=sum_{jge i}(-1)^{j-i}inom jif_{max,j}),可以一次卷积求出。
对于上面的 DP,实际上可以把输入的 (a) 数组排序之后分成一些值域连续段,求出每个连续段((is) 全为 (1))中选出 (0,1,dots) 个元素的方案数,最后用一次分治 NTT (O(mlog^2m)) 求出。
现在要解决的问题就是给定 (n),如何对于每个 (i=0,1,dots,n) 计算出在 (n) 个元素中选出 (i) 个使得没有任意连续的 (k) 个元素被选出的方案数。
可以转化成对于每个 (i=0,1,dots,n) 计算出把 (n+1) 拆分成 (n+1-i) 个不超过 (k) 的正整数之和的方案数。转化方法即为增加一个不能选的元素 (n+1),以所有不选的元素为右端点,把该元素左边有被选上的一段元素并起来作为一段。
也就是对于任意 (1le mle n+1) 求出:
也就是求二元生成函数:
的 (x^{n+1}) 次项。
对于只能求某一项的问题我们通常考虑拉格朗日反演,设 (G(x)=sum_{i=1}^kx_i),(G(x)) 的复合逆为 (G^{-1}(x)),我们有:
由于 (frac u{(1-ux)^2}) 中 (u) 的次数总是比 (x) 的次数多 (1),故如果求出了 ((frac x{G^{-1}(x)})^{n+1}),就能枚举 (frac u{(1-ux)^2}) 中 (x) 的次数计算这两个式子积的第 (n) 项了。
现在要求的就是 (F(x)=G^{-1}(x))。由于 (G(x)=sum_{i=1}^kx_i=frac{x-x^{k+1}}{1-x}),故我们有:
可以牛顿迭代。
总复杂度 (O(mlog^2m)),瓶颈在分治 NTT,但牛顿迭代部分的常数还不止一个 log
Code
#include <bits/stdc++.h>
template <class T>
inline void read(T &res)
{
res = 0; bool bo = 0; char c;
while (((c = getchar()) < '0' || c > '9') && c != '-');
if (c == '-') bo = 1; else res = c - 48;
while ((c = getchar()) >= '0' && c <= '9')
res = (res << 3) + (res << 1) + (c - 48);
if (bo) res = ~res + 1;
}
const int N = 3e6 + 5, djq = 998244353;
int n, m, rev[N], yg[N], a[N], b[N], ff, tot, t1[N], t2[N], t3[N], t4[N], inv[N], f[N],
t5[N], t6[N], t7[N], cnt[N], len, fac[N], invf[N], ans;
std::vector<int> A[N];
int qpow(int a, int b)
{
int res = 1;
while (b)
{
if (b & 1) res = 1ll * res * a % djq;
a = 1ll * a * a % djq;
b >>= 1;
}
return res;
}
inline void add(int &a, const int &b) {if ((a += b) >= djq) a -= djq;}
inline void sub(int &a, const int &b) {if ((a -= b) < 0) a += djq;}
void FFT(int n, int *a, int op)
{
for (int i = 0; i < n; i++) if (i < rev[i]) std::swap(a[i], a[rev[i]]);
yg[n] = qpow(1312005, (djq - 1) / n * ((n + op) % n));
for (int i = n >> 1; i; i >>= 1)
yg[i] = 1ll * yg[i << 1] * yg[i << 1] % djq;
for (int k = 1; k < n; k <<= 1)
{
int x = yg[k << 1];
for (int i = 0; i < n; i += k << 1)
{
int w = 1;
for (int j = 0, *f1 = a + i, *f2 = a + i + k; j < k; j++, f1++, f2++)
{
int u = *f1, v = 1ll * w * (*f2) % djq;
add(*f1 = u, v); sub(*f2 = u, v);
w = 1ll * w * x % djq;
}
}
}
if (op == -1)
{
int gg = qpow(n, djq - 2);
for (int i = 0; i < n; i++) a[i] = 1ll * a[i] * gg % djq;
}
}
void nealchen(int n)
{
ff = 1; tot = 0;
while (ff < n) ff <<= 1, tot++;
for (int i = 0; i < ff; i++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << tot - 1);
}
void getinv(int n, int *a, int *b)
{
b[0] = 1;
for (int k = 1; k <= n; k <<= 1)
{
nealchen(k << 2);
for (int i = k; i < ff; i++) b[i] = 0;
for (int i = 0; i < ff; i++) t1[i] = i <= n && i < (k << 1) ? a[i] : 0;
FFT(ff, b, 1); FFT(ff, t1, 1);
for (int i = 0; i < ff; i++) b[i] = (2ll - 1ll * t1[i] * b[i] % djq
+ djq) * b[i] % djq;
FFT(ff, b, -1);
}
}
void getln(int n, int *a, int *b)
{
getinv(n, a, t2); b[n] = 0; nealchen(n << 1 | 1);
for (int i = 1; i <= n; i++) b[i - 1] = 1ll * i * a[i] % djq;
for (int i = n + 1; i < ff; i++) t2[i] = b[i] = 0;
FFT(ff, b, 1); FFT(ff, t2, 1);
for (int i = 0; i < ff; i++) b[i] = 1ll * b[i] * t2[i] % djq;
FFT(ff, b, -1);
for (int i = n; i >= 1; i--) b[i] = 1ll * b[i - 1] * inv[i] % djq;
b[0] = 0;
}
void getexp(int n, int *a, int *b)
{
b[0] = 1;
for (int k = 1; k <= n; k <<= 1)
{
for (int i = k; i < (k << 2); i++) b[i] = 0;
getln((k << 1) - 1, b, t3); nealchen(k << 2);
for (int i = 0; i < ff; i++)
{
if (i >= (k << 1)) {t2[i] = 0; continue;}
t2[i] = i <= n ? a[i] : 0; sub(t2[i], t3[i]); if (!i) add(t2[i], 1);
}
FFT(ff, b, 1); FFT(ff, t2, 1);
for (int i = 0; i < ff; i++) b[i] = 1ll * b[i] * t2[i] % djq;
FFT(ff, b, -1);
}
}
void getpow(int n, int k, int *a, int *b)
{
getln(n, a, t4);
for (int i = 0; i <= n; i++) t4[i] = 1ll * t4[i] * k % djq;
getexp(n, t4, b);
}
void calc(int n, int *a)
{
for (int i = 0; i <= n; i++) t5[i] = f[i];
getpow(n, n + 1, t5, t6);
for (int i = 0; i <= n; i++)
a[n - i] = 1ll * inv[n + 1] * (i + 1) % djq * t6[n - i] % djq;
}
std::vector<int> polymul(std::vector<int> a, std::vector<int> b)
{
int n = a.size(), m = b.size(); nealchen(n + m - 1);
for (int i = 0; i < ff; i++) t1[i] = i < n ? a[i] : 0, t2[i] = i < m ? b[i] : 0;
FFT(ff, t1, 1); FFT(ff, t2, 1);
for (int i = 0; i < ff; i++) t1[i] = 1ll * t1[i] * t2[i] % djq;
FFT(ff, t1, -1); std::vector<int> res;
for (int i = 0; i < n + m - 1; i++) res.push_back(t1[i]);
return res;
}
std::vector<int> nealchen2003(int l, int r)
{
if (l == r) return A[l];
int mid = l + r >> 1;
return polymul(nealchen2003(l, mid), nealchen2003(mid + 1, r));
}
int main()
{
read(n); read(m); inv[1] = f[0] = fac[0] = invf[0] = 1;
for (int i = 2; i <= n + 1; i++)
inv[i] = 1ll * (djq - djq / i) * inv[djq % i] % djq;
for (int k = 1; k <= n; k <<= 1)
{
getpow((k << 1) - 1, m, f, t5); nealchen(k << 2);
for (int i = k << 1; i < ff; i++) t5[i] = 0;
for (int i = 0; i < ff; i++) t6[i] = f[i], t7[i] = t5[i];
FFT(ff, t6, 1); FFT(ff, t7, 1);
for (int i = 0; i < ff; i++) t6[i] = 1ll * t6[i] * t7[i] % djq;
FFT(ff, t6, -1);
for (int i = k << 1; i < ff; i++) t6[i] = 0;
for (int i = (k << 1) - 1; i >= 0; i--)
t6[i] = i >= m ? (djq - t6[i - m]) % djq : 0,
t5[i] = i >= m ? (1ll * djq * djq - 1ll * (m + 1) * t5[i - m]) % djq : 0;
add(t5[1], 1); add(t5[0], 1); sub(t6[0], 1);
for (int i = 0; i < k; i++) add(t6[i + 1], f[i]), add(t6[i], f[i]);
getinv((k << 1) - 1, t5, t7); nealchen(k << 2);
for (int i = k << 1; i < ff; i++) t6[i] = t7[i] = 0;
FFT(ff, t6, 1); FFT(ff, t7, 1);
for (int i = 0; i < ff; i++) t6[i] = 1ll * t6[i] * t7[i] % djq;
FFT(ff, t6, -1);
for (int i = 0; i < (k << 1); i++) sub(f[i], t6[i]);
}
getinv(n, f, t5); for (int i = 0; i <= n; i++) f[i] = t5[i];
for (int i = 1; i <= n; i++) read(a[i]); std::sort(a + 1, a + n + 1);
for (int i = 1; i <= n; i++)
{
if (i == 1 || a[i] > a[i - 1] + 1) len++;
cnt[len]++;
}
for (int i = 1; i <= len; i++)
{
calc(cnt[i], t7);
for (int j = 0; j <= cnt[i]; j++) A[i].push_back(t7[j]);
}
std::vector<int> nc = nealchen2003(1, len);
for (int i = 0; i <= n; i++) t1[i] = nc[i];
for (int i = 1; i <= n; i++) fac[i] = 1ll * fac[i - 1] * i % djq,
invf[i] = 1ll * invf[i - 1] * inv[i] % djq;
for (int i = 0; i <= n; i++)
{
t1[i] = 1ll * t1[i] * fac[i] % djq;
if (t2[n - i] = invf[i], i & 1) t2[n - i] = djq - t2[n - i];
}
nealchen(n << 1 | 1);
for (int i = n + 1; i < ff; i++) t1[i] = t2[i] = 0;
FFT(ff, t1, 1); FFT(ff, t2, 1);
for (int i = 0; i < ff; i++) t1[i] = 1ll * t1[i] * t2[i] % djq;
FFT(ff, t1, -1);
for (int i = 0; i < n; i++) ans = (1ll * invf[i] * t1[n + i]
% djq * n % djq * inv[n - i] + ans) % djq;
return std::cout << ans << std::endl, 0;
}