题意
给定 (n) 个球,初始为白色,进行两次操作,每次选择 (m) 个球染黑,给定一个 (m) 次多项式 (F)(仅给出前 (m+1) 项点值),设编号最小的黑球为 (A)((1leqslant Aleqslant n)),求 (F(A-1)) 的期望,对 (998244353) 取模。
(1leqslant n<9988244353,1leqslant mleqslant 10^6)
分析
根据期望的线性性,我们把多项式拆成若干个单项式的和,那么我们设目前的次数为 (c),我们对于 ((A-1)^c) 构造出一个组合意义,找到一个最长的白球前缀,在其中任选 (c) 个球的方案数,列出式子:((f_i) 表示多项式第 (i) 项的系数,(i) 枚举在后面那一步实际上选到的白球数量,(j) 枚举实际上被染黑的球的数量)
[ans_c=f_csum_{i=0}^megin{Bmatrix}c\iend{Bmatrix}i!sum_{j=m}^{2m}{jchoose m}{mchoose m-(j-m)}{nchoose i+j}
]
推一推式子:
[sum_{c=0}^mf_csum_{i=0}^megin{Bmatrix}c\iend{Bmatrix}i!sum_{j=m}^{2m}{jchoose m}{mchoose m-(j-m)}{nchoose i+j}\=sum_{i=0}^mi!sum_{j=m}^{2m}{jchoose m}{mchoose 2m-j}{nchoose i+j}sum_{c=0}^mf_cegin{Bmatrix}c\iend{Bmatrix}
]
我们用通项公式展开斯特林数:
[sum_{i=0}^mi!sum_{j=m}^{2m}{jchoose m}{mchoose 2m-j}{nchoose i+j}sum_{c=0}^mf_cfrac{1}{i!}sum_{k=0}^i{ichoose k}(-1)^{i-k}k^c\=sum_{i=0}^msum_{j=m}^{2m}{jchoose m}{mchoose 2m-j}{nchoose i+j}sum_{k=0}^i{ichoose k}(-1)^{i-k}sum_{c=0}^m f_ck^c\=sum_{T=m}^{3m}{nchoose T}sum_{j=m}^{2m}{jchoose m}{mchoose 2m-j}sum_{k=0}^{T-j}{T-jchoose k}(-1)^{T-j-k}F(k)
]
其中 (F(k)) 表示 (x=k) 的时候的点值,容易发现后面枚举 (k) 的部分可以使用一次卷积算出来,前面枚举 (j) 的部分也可以使用一次卷积算出来,于是复杂度 (O(nlog n))。
代码
#include<stdio.h>
#include<vector>
using namespace std;
const int maxn=1<<22,mod=998244353,G=3,invG=(mod+1)/G;
typedef vector<int>poly;
int n,m,ans,lim;
int p[maxn],inv[maxn],fac[maxn],nfac[maxn],F[maxn];
poly f,g;
inline int read(){
int x=0;
char c=getchar();
for(;c<'0'||c>'9';c=getchar());
for(;c>='0'&&c<='9';c=getchar())
x=x*10+c-48;
return x;
}
inline int C(int a,int b){
return a<b? 0:1ll*fac[a]*nfac[b]%mod*nfac[a-b]%mod;
}
int ksm(int a,int b,int mod){
int res=1;
while(b){
if(b&1)
res=1ll*res*a%mod;
a=1ll*a*a%mod,b>>=1;
}
return res;
}
int getlen(int n){
int lim=1,r=0;
for(;lim<n;lim<<=1,r++);
for(int i=0;i<lim;i++)
p[i]=(p[i>>1]>>1)|((i&1)<<(r-1));
return lim;
}
void NTT(poly &x,int opt){
x.resize(lim);
for(int i=0;i<lim;i++)
if(i<p[i])
swap(x[i],x[p[i]]);
for(int len=2,now=1,p=1;len<=lim;len<<=1,now<<=1,p++){
int w=ksm(opt==1? G:invG,(mod-1)/len,mod);
for(int i=0;i<lim;i+=len)
for(int j=0,mul=1;j<now;j++,mul=1ll*mul*w%mod){
int a=x[i+j],b=1ll*x[now+i+j]*mul%mod;
x[i+j]=(a+b)%mod,x[now+i+j]=(a-b+mod)%mod;
}
}
if(opt==0)
for(int i=0;i<lim;i++)
x[i]=1ll*x[i]*inv[lim]%mod;
}
int main(){
fac[0]=fac[1]=nfac[0]=nfac[1]=inv[1]=1;
for(int i=2;i<maxn;i++)
fac[i]=1ll*fac[i-1]*i%mod,inv[i]=mod-1ll*(mod/i)*inv[mod%i]%mod,nfac[i]=1ll*nfac[i-1]*inv[i]%mod;
scanf("%d%d",&n,&m);
for(int i=0;i<=m;i++)
f.push_back(1ll*read()*nfac[i]%mod),g.push_back((i&1)? (mod-nfac[i]):nfac[i]);
lim=getlen(2*m+2),NTT(f,1),NTT(g,1);
for(int i=0;i<lim;i++)
f[i]=1ll*f[i]*g[i]%mod;
NTT(f,0),f.resize(m+1),g.clear();
for(int i=0;i<=m;i++)
f[i]=1ll*f[i]*fac[i]%mod,g.push_back(1ll*C(m+i,m)*C(m,m+m-(m+i))%mod);
NTT(f,1),NTT(g,1);
for(int i=0;i<lim;i++)
f[i]=1ll*f[i]*g[i]%mod;
NTT(f,0);
for(int i=0,C=1;i<=3*m;i++,C=1ll*C*(n-i+1)%mod*inv[i]%mod)
if(i>=m)
ans=(ans+1ll*C*f[i-m])%mod;
printf("%d
",ans);
return 0;
}