有个结论,原问题可以转化为每次开枪的概率中的分母不变,当射到一个已经死掉的猎人时,就继续开枪,不难发现这样射中第 (i) 个人的概率和原问题一样,设 (W=sumlimits_{i=1}^n w_i),(T) 为已经死掉的猎人的 (w_i) 的和,得:
[largeleft( sum_{j=0}^infty left(frac{T}{W}
ight)^j
ight)frac{w_i}{W}=frac{w_i}{W-T}
]
考虑容斥,枚举一个猎人集合 (S),集合内的猎人都要在 (1) 之后被射死,得答案为:
[largeegin{aligned}
&sum_S(-1)^{|S|}sum_{i=0}^inftyleft( 1-frac{sumlimits_{jin S}w_j+w_1}{W}
ight)frac{w_1}{W}\
=&w_1sum_Sfrac{(-1)^{|S|}}{sumlimits_{jin S}w_j+w_1}\
=&w_1sum_{i=0}^{W-w_1}frac{left[x^i
ight]prodlimits_{i=2}^nleft(1-x^{w_i}
ight)}{i+w_1}\
end{aligned}
]
(prodlimits_{i=2}^nleft(1-x^{w_i} ight)) 用分治和 (NTT) 即可计算。
#include<bits/stdc++.h>
#define maxn 400010
#define p 998244353
#define ls (x<<1)
#define rs (x<<1|1)
using namespace std;
typedef long long ll;
template<typename T> inline void read(T &x)
{
x=0;char c=getchar();bool flag=false;
while(!isdigit(c)){if(c=='-')flag=true;c=getchar();}
while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
if(flag)x=-x;
}
int n,m,w,sum,ans;
ll rev[maxn],inv[maxn],len[maxn],v[maxn],f[maxn];
int mod(int x)
{
return x>=p?x-p:x;
}
ll qp(ll x,ll y)
{
ll v=1;
while(y)
{
if(y&1) v=v*x%p;
x=x*x%p,y>>=1;
}
return v;
}
int calc(int n)
{
int lim=1;
while(lim<=n) lim<<=1;
for(int i=0;i<lim;++i)
rev[i]=(rev[i>>1]>>1)|((i&1)?lim>>1:0);
return lim;
}
void NTT(ll *a,int lim,int type)
{
for(int i=0;i<lim;++i)
if(i<rev[i])
swap(a[i],a[rev[i]]);
for(int len=1;len<lim;len<<=1)
{
ll wn=qp(3,(p-1)/(len<<1));
for(int i=0;i<lim;i+=len<<1)
{
ll w=1;
for(int j=i;j<i+len;++j,w=w*wn%p)
{
ll x=a[j],y=w*a[j+len]%p;
a[j]=mod(x+y),a[j+len]=(x-y+p)%p;
}
}
}
if(type==1) return;
for(int i=0;i<lim;++i) a[i]=a[i]*inv[lim]%p;
reverse(a+1,a+lim);
}
void solve(int l,int r,ll *a,int x)
{
if(l==r)
{
for(int i=0;i<=v[l];++i) a[i]=0;
a[0]=1,a[v[l]]=p-1,len[x]=v[l];
return;
}
int mid=(l+r)>>1,lim;
ll f[maxn],g[maxn];
solve(l,mid,f,ls),solve(mid+1,r,g,rs),lim=calc(len[x]=len[ls]+len[rs]);
for(int i=len[ls]+1;i<lim;++i) f[i]=0;
for(int i=len[rs]+1;i<lim;++i) g[i]=0;
NTT(f,lim,1),NTT(g,lim,1);
for(int i=0;i<lim;++i) a[i]=f[i]*g[i]%p;
NTT(a,lim,-1);
for(int i=len[x]+1;i<lim;++i) a[i]=0;
}
int main()
{
read(n),read(w),sum=w;
for(int i=1;i<n;++i) read(v[i]),sum+=v[i];
inv[1]=1;
for(int i=2;i<=2*sum;++i) inv[i]=(p-p/i)*inv[p%i]%p;
solve(1,n-1,f,1);
for(int i=0;i<=sum-w;++i) ans=mod(ans+f[i]*inv[i+w]%p);
printf("%d",(ll)ans*w%p);
return 0;
}