题面
解析
设出现$S$次的颜色至少有$i$种的方案数为$f_i$,钦定$i$种颜色出现$S$次,剩下的任选:$f_i=\binom{m}{i}\cdot\frac{n!}{(S!)^i(n-iS)!}\cdot(m-i)^{n-iS}$,其中$\frac{n!}{(S!)^i(n-iS)!}$表示在$n$个位置中选$iS$个位置填$i$种颜色,每种颜色填$S$次的方案数。
设$g_i$表示出现$S$次的颜色恰好有$i$种的方案数,然后会发现$f_i=\sum_{j\ge i}\binom{j}{i}g_j$
于是二项式反演可得:$$\begin{align*}g_i&=\sum_{j\ge i}(-1)^{j-i}\binom{j}{i}f_j\\&=\frac{1}{i!}\sum_{j\ge i}\frac{(-1)^{j-i}}{(j-i)!}\cdot j!\cdot f_j\end{align*}$$
卷积即可
$O(M \log M)$
代码:
![](https://images.cnblogs.com/OutliningIndicators/ContractedBlock.gif)
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
using namespace std;
typedef long long ll;

// maxn sizes the per-colour arrays (a, and twice that for f, h, rev).
// mod = 479 * 2^21 + 1 is the prime modulus for all arithmetic and g = 3 is
// the generator from which the NTT roots of unity are derived.
const int maxn = 200005, mod = 1004535809, g = 3;
// Capacity of the factorial / inverse-factorial tables.
const int maxfac = 10000005;

// Reads one decimal integer from stdin, skipping any non-digit characters.
// A '-' seen anywhere in the skipped prefix makes the result negative.
// Returns 0 when the stream ends before a digit is found; the character is
// kept in an int so that EOF is distinguishable from every real byte.
inline int read()
{
    int sign = 1;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    int value = 0;
    while (ch >= '0' && ch <= '9') {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value * sign;
}

// (x + y) mod `mod`, for x and y already reduced to [0, mod).
int add(int x, int y)
{
    return x + y < mod ? x + y : x + y - mod;
}

// (x - y) mod `mod`, for x and y already reduced to [0, mod).
int rdc(int x, int y)
{
    return x - y < 0 ? x - y + mod : x - y;
}

// Binary exponentiation: x^y mod `mod` for y >= 0 (x^0 is 1, including 0^0).
ll qpow(ll x, int y)
{
    ll ret = 1;
    while (y) {
        if (y & 1) ret = ret * x % mod;
        x = x * x % mod;
        y >>= 1;
    }
    return ret;
}

// n, m, s: the three leading input integers.
// lim / bit: current transform length (a power of two) and its log2.
// rev: bit-reversal permutation for the current transform length.
int n, m, s, lim, bit, rev[maxn<<1];
// fac[i] = i! mod `mod`, fnv[i] = (i!)^{-1} mod `mod`; a[] holds the m + 1
// coefficients read after n, m, s.
int fac[maxfac], fnv[maxfac], a[maxn];
// ginv: modular inverse of g; f, h: the two polynomials that get convolved.
ll ginv, f[maxn<<1], h[maxn<<1];

// Fills fac[] and fnv[] for indices 0..max(n, m) and computes ginv.
void init()
{
    int t = max(n, m);
    ginv = qpow(g, mod - 2);
    fac[0] = 1;
    for (int i = 1; i <= t; ++i)
        fac[i] = 1LL * fac[i-1] * i % mod;
    // One modular inversion at the top, then walk downwards:
    // 1/i! = (1/(i+1)!) * (i+1).
    fnv[t] = qpow(fac[t], mod - 2);
    for (int i = t - 1; i >= 0; --i)
        fnv[i] = 1LL * fnv[i+1] * (i + 1) % mod;
}

// Binomial coefficient C(x, y) mod `mod`; 0 when y is outside [0, x].
// Requires init() to have filled the tables up to x.
int comb(int x, int y)
{
    if (x < y || y < 0) return 0;
    return (1LL * fac[x] * fnv[y] % mod) * fnv[x-y] % mod;
}

// Chooses the smallest power of two strictly greater than x as the transform
// length (lim, with bit = log2(lim)) and builds the bit-reversal table.
void NTT_init(int x)
{
    lim = 1;
    bit = 0;
    while (lim <= x) {
        lim <<= 1;
        ++bit;
    }
    // rev[0] is never written and relies on the zero-initialised global.
    for (int i = 1; i < lim; ++i)
        rev[i] = (rev[i>>1] >> 1) | ((i & 1) << (bit - 1));
}

// In-place number-theoretic transform of x[0..lim).
// y == 1 runs the forward transform; y == -1 runs the inverse transform and
// additionally scales every entry by lim^{-1}.
void NTT(ll *x, int y)
{
    for (int i = 1; i < lim; ++i)
        if (i < rev[i]) swap(x[i], x[rev[i]]);
    for (int half = 1; half < lim; half <<= 1) {
        // Root of unity of order 2 * half (inverse root for y == -1).
        ll wn = qpow((y == 1) ? g : ginv, (mod - 1) / (half << 1));
        for (int j = 0; j < lim; j += (half << 1)) {
            ll w = 1;
            for (int k = 0; k < half; ++k) {
                ll u = x[j+k];
                ll v = x[j+k+half] * w % mod;
                x[j+k] = add(static_cast<int>(u), static_cast<int>(v));
                x[j+k+half] = rdc(static_cast<int>(u), static_cast<int>(v));
                w = w * wn % mod;
            }
        }
    }
    if (y == -1) {
        ll linv = qpow(lim, mod - 2);
        for (int i = 0; i < lim; ++i)
            x[i] = x[i] * linv % mod;
    }
}

// Reads n, m, s and a[0..m], then prints sum_i a[i] * g_i mod `mod`, where
//   f_i = C(m,i) * n! / ((s!)^i * (n - i*s)!) * (m - i)^(n - i*s)
//   g_i = (1/i!) * sum_{j >= i} (-1)^(j-i) / (j-i)! * j! * f_j
// and the inner sum is evaluated for all i at once as one convolution.
int main()
{
    n = read();
    m = read();
    s = read();
    // Empty or malformed input would otherwise divide by zero below or index
    // outside the fixed-size tables.
    if (n < 0 || m < 0 || s <= 0 || m >= maxn || n >= maxfac || s >= maxfac) {
        fprintf(stderr, "input out of supported range\n");
        return 1;
    }
    for (int i = 0; i <= m; ++i) a[i] = read();

    // Largest i for which i * s positions fit into n and i colours exist.
    int sj = min(m, n / s);
    // The transform length NTT_init will pick must fit in f, h and rev.
    int need = 1;
    while (need <= (sj << 1)) need <<= 1;
    if (need > (maxn << 1)) {
        fprintf(stderr, "input out of supported range\n");
        return 1;
    }
    init();

    // inv_s_pow tracks (1/s!)^i across iterations.
    ll inv_s_pow = 1;
    for (int i = 0; i <= sj; ++i) {
        ll ways = 1LL * comb(m, i) * fac[n] % mod;
        ways = ways * inv_s_pow % mod;
        ways = ways * fnv[n-i*s] % mod;
        ways = ways * qpow(m - i, n - i * s) % mod;
        f[i] = ways;
        inv_s_pow = inv_s_pow * fnv[s] % mod;
    }
    for (int i = 0; i <= sj; ++i) {
        f[i] = f[i] * fac[i] % mod;
        // h is stored reversed so that product coefficient i + sj collects
        // exactly the terms with j - i >= 0.
        h[sj-i] = ((i & 1) ? rdc(0, fnv[i]) : fnv[i]);
    }

    NTT_init(sj << 1);
    NTT(f, 1);
    NTT(h, 1);
    for (int i = 0; i < lim; ++i) f[i] = f[i] * h[i] % mod;
    NTT(f, -1);

    int ans = 0;
    for (int i = 0; i <= sj; ++i)
        ans = add(ans, (a[i] * f[i+sj] % mod) * fnv[i] % mod);
    printf("%d", ans);
    return 0;
}