「集训队作业2018」喂鸽子
设 (F(n)) 表示有 (n) 只鸽子,每次等概率选一只喂,期望喂饱第一只鸽子的时间,(f_{n,m}) 表示有 (n) 只鸽子,已经喂了 (m) 次,此时这 (n) 只鸽子中没有鸽子被喂饱的概率。
[Ans = sum_{i=1}^n (-1)^{i+1}{n choose i} F(i) \
F(n)=sum_{igeq 0}sum_{j leq i}{ichoose j}f_{n,j} (frac{N-n}{N})^{i-j}\
=sum_{igeq 0}f_{n,i}sum_{j geq 0} {i+jchoose i}(frac{N-n}{N})^{j}
]
注意到有
[(dfrac{1}{1-x})^n=sum_{igeq 0} {n+i-1choose n-1}x^i
]
所以
[F(n)=sum_{igeq 0}f_{n,i}sum_{j geq 0} {i+jchoose i}(frac{N-n}{N})^{j} \
=sum_{i geq 0}f_{n,i}(frac{N}{n})^{i+1}
]
题外话:我们要求的实际上是恰有 (i) 次分配到这个集合的所有方案里分配到集合外的概率之和,不太好组合意义理解,如果有大爷会组合意义了教教我啊。
接下来求 (f_{n,i}) ,这里由于之前没有考虑分配 (n) 中玉米的概率,所以在这里最后一只鸽子分配到一个玉米的概率是 (frac{1}{N}) 。
[f_{n,i}=frac{1}{N^i}[x^i] (sum_{j=0}^{k-1}frac{x^i}{i!})^n
]
或者也可以递推
[f_{n,i} =sum_{j=0}^{i}{ichoose j}f_{n-1,i-j}frac{1}{N^j}
]
这样做复杂度是 (mathcal O(n^2klog(nk))) 的,还有复杂度为 (mathcal O(n^2k)) 的高妙做法,以后再补。
code
/*program by mangoyang*/
#include <bits/stdc++.h>
#define inf (0x7f7f7f7f)
#define Max(a, b) ((a) > (b) ? (a) : (b))
#define Min(a, b) ((a) < (b) ? (a) : (b))
typedef long long ll;
using namespace std;
template <class T>
inline void read(T &x){
int ch = 0, f = 0; x = 0;
for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = 1;
for(; isdigit(ch); ch = getchar()) x = x * 10 + ch - 48;
if(f) x = -x;
}
const int N = 55, K = 1005, mod = 998244353, G = 3;
int f[N][1<<17], js[1<<17], a[1<<17], inv[1<<17], n, k, ans;
inline void up(int &x, int y){
x = x + y >= mod ? x + y - mod : x + y;
}
inline int Pow(int a, int b){
int ans = 1;
for(; b; b >>= 1, a = 1ll * a * a % mod)
if(b & 1) ans = 1ll * ans * a % mod;
return ans;
}
namespace poly{
int rev[1<<17], len, lg;
inline void timesinit(int lenth){
for(len = 1, lg = 0; len <= lenth; len <<= 1, lg++);
for(int i = 0; i < len; i++)
rev[i] = (rev[i>>1] >> 1) | ((i & 1) << (lg - 1));
}
inline void dft(int *a, int sgn){
for(int i = 0; i < len; i++)
if(i < rev[i]) swap(a[i], a[rev[i]]);
for(int k = 2; k <= len; k <<= 1){
int w = Pow(G, (mod - 1) / k);
if(sgn == -1) w = Pow(w, mod - 2);
for(int i = 0; i < len; i += k){
int now = 1;
for(int j = i; j < i + (k >> 1); j++){
int x = a[j], y = 1ll * a[j+(k>>1)] * now % mod;
a[j] = x + y >= mod ? x + y - mod : x + y;
a[j+(k>>1)] = x - y < 0 ? x - y + mod : x - y;
now = 1ll * now * w % mod;
}
}
}
if(sgn == -1){
int Inv = Pow(len, mod - 2);
for(int i = 0; i < len; i++) a[i] = 1ll * a[i] * Inv % mod;
}
}
}
inline int C(int x, int y){
return 1ll * js[x] * inv[y] % mod * inv[x-y] % mod;
}
int main(){
read(n), read(k), js[0] = inv[0] = 1;
for(int i = 1; i <= n * k; i++){
js[i] = 1ll * js[i-1] * i % mod;
inv[i] = Pow(js[i], mod - 2);
}
f[0][0] = 1;
for(int i = 1; i <= n; i++){
for(int j = 0; j <= i * (k - 1); j++){
f[i-1][j] = 1ll * f[i-1][j] * inv[j] % mod;
if(j < k) a[j] = 1ll * inv[j] * Pow(Pow(n, mod - 2), j) % mod;
}
poly::timesinit(i * k);
poly::dft(f[i-1], 1), poly::dft(a, 1);
for(int j = 0; j < poly::len; j++)
f[i][j] = 1ll * f[i-1][j] * a[j] % mod;
poly::dft(f[i], -1);
for(int j = 0; j < poly::len; j++) a[j] = 0;
for(int j = i * (k - 1) + 1; j < poly::len; j++) f[i][j] = 0;
for(int j = 0; j <= i * (k - 1); j++){
f[i][j] = 1ll * f[i][j] * js[j] % mod;
int tmp = 1ll * f[i][j] * C(n, i) % mod;
tmp = 1ll * tmp * Pow(n, j + 1) % mod;
tmp = 1ll * tmp * Pow(Pow(i, mod - 2), j + 1) % mod;
if(i & 1) up(ans, tmp); else up(ans, mod - tmp);
}
}
cout << ans << endl;
return 0;
}