Description
猎人杀是一款风靡一时的游戏“狼人杀”的民间版本,他的规则是这样的:
一开始有 n个猎人,第 i 个猎人有仇恨度 wi。每个猎人只有一个固定的技能:死亡后必须开一枪,且被射中的人也会死亡。
然而向谁开枪也是有讲究的,假设当前还活着的猎人有([i_1...i_m]),那么有(w_{i_k}over sumlimits_{j=1}^{m} w_{i_j})的概率是向猎人(i_k) 开枪
一开始第一枪由你打响,目标的选择方法和猎人一样(即有(w_{i}over sumlimits_{j=1}^{m} w_{j})的概率射中第i个猎人)。由于开枪导致的连锁反应,所有猎人最终都会死亡,现在1号猎人想知道它是最后一个死的的概率。
对998244353取模
(w_i>0,sum w_ileq 100000)
Solution
首先有结论,我们假设可以对已经死亡的猎人开枪,对已经死亡猎人开枪之后继续开枪,那么问题是等价的。
这样就好做不少,因为每个人中枪的概率就固定了。
根据这个结论,我们来推一波式子。
我们可以将整个开枪过程看做是一个序列,每个数可以出现多次,每个数出现有概率,题目问的是1出现时其他所有数都已经出现过的概率。
考虑指数型生成函数,设(t=sum w_k),容易得出除1号外i号猎人的EGF是$$sumlimits_{j>0}{w_i^jx^jover t^ji!}=e^{w_ixover t}-1$$
那么将这些猎人拼接,总的式子就是$$prodlimits_{k=2}^{n}(e^{w_kxover t}-1)$$
假设有3个猎人,2,3号猎人拼在一起就是(e^{(w_2+w_3)xover t}-e^{w_2xover t}-e^{w_3xover t}+1)
对于每个EGF,它对总概率的贡献就是其系数之和
对于(e^{px}),将其系数求和(不考虑阶乘),就是等比数列求和的形式,可以得出和就是(1over 1-p)
那么对于上面的式子,一样计算和,然后加到一起,最后再乘上(w_1/t)(最后一次要选上1号)
现在问题的关键就是要算上面的乘积的每一项(e^{px},pin[0,t])的系数
我们可以把每个(e^{px})也看做多项式的一项,因为同是指数相加,可以构造多项式(x^{w_kover t}-1),那么$$prodlimits_{k=2}^{n}(x^{w_kover t}-1)$$
的每一项(x^{p})前的系数就是原式中每一个(e^{px})的系数
可以先不看t,用分治NTT做,最后再算上。
总复杂度(O(nlog^2 n))
Code
#include <cstdio>
#include <iostream>
#include <cmath>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#define fo(i,a,b) for(int i=a;i<=b;++i)
#define fod(i,a,b) for(int i=a;i>=b;--i)
#define M 262144
#define L 18
#define mo 998244353
#define LL long long
#define N 100005
using namespace std;
LL wi[M+1],wg[M+1],a[M+1],b[M+1],c[M+1],ny,w[N];
int a1[N],bit[M+1],sz[N],n1,n,sm[N],l2[M+1],cf[L+1],sum;
LL ksm(LL k,LL n)
{
LL s=1;
for(;n;n>>=1,k=k*k%mo) if(n&1) s=s*k%mo;
return s;
}
void prp(int num)
{
fo(i,0,num-1) bit[i]=(bit[i>>1]>>1)|((i&1)<<(l2[num]-1));
fo(i,0,num) wi[i]=wg[M/num*i];
ny=ksm(num,mo-2);
}
void NTT(LL *a,bool pd,int num)
{
LL v,w;
fo(i,0,num-1) if(i<bit[i]) swap(a[i],a[bit[i]]);
for(int m=2,lim=num>>1,half=1;m<=num;half=m,m<<=1,lim>>=1)
{
fo(i,0,half-1)
{
w=(!pd)?wi[i*lim]:wi[num-i*lim];
for(int j=i;j<num;j+=m)
{
v=a[j+half]*w%mo;
a[j+half]=(a[j]-v+mo)%mo;
a[j]=(a[j]+v)%mo;
}
}
}
if(pd) fo(i,0,num-1) a[i]=a[i]*ny%mo;
}
void doit(int l,int r)
{
if(l==r) return;
int mi=sm[n],mid=l;
fo(j,l,r-1) if(max(sm[j]-sm[l-1],sm[r]-sm[j])<mi) mi=max(sm[j]-sm[l-1],sm[r]-sm[j]),mid=j;
doit(l,mid),doit(mid+1,r);
int num=cf[l2[sz[mid+1]+sz[l]+1]];
prp(num);
fo(i,0,num-1) b[i]=c[i]=0;
fo(i,0,sz[l]) b[i]=a[a1[l]+i];
fo(i,0,sz[mid+1]) c[i]=a[a1[mid+1]+i];
NTT(b,0,num),NTT(c,0,num);
fo(i,0,num-1) b[i]=b[i]*c[i]%mo;
NTT(b,1,num);
sz[l]+=sz[mid+1];
fo(i,0,sz[l]) a[a1[l]+i]=b[i];
}
int main()
{
cin>>n;
int l=-1;
cf[0]=1;
fo(i,1,18) cf[i]=(cf[i-1]<<1),l2[cf[i]]=i;
fod(i,M-1,2) if(!l2[i]) l2[i]=l2[i+1];
fo(i,1,n)
{
int c;
scanf("%d",&w[i]);
c=w[i],sum+=c;
if(i!=1)
{
a1[i]=++l;
a[l]=mo-1;
l+=c;
a[l]=1,sz[i]=c,sm[i]=sz[i]+sm[i-1];
}
}
wg[0]=1;
LL v=ksm(3,(mo-1)/M);
fo(i,1,M) wg[i]=wg[i-1]*v%mo;
doit(2,n);
LL ans=0;
fo(i,0,sm[n])
ans=(ans+a[i]*(LL)sum%mo*ksm(sum-i,mo-2)%mo+mo)%mo;
printf("%lld
",ans*w[1]%mo*(LL)ksm(sum,mo-2)%mo);
}