思路太清奇了……
考虑容斥,即枚举至少有哪几个是在(1)号之后被杀的。设(A=sum_{i=1}^nw_i),(S)为那几个在(1)号之后被杀的人的(w)之和。关于杀了人之后分母的变化,我们可以假设这个人被杀之后还活着(说好的人被杀就会死呢),不过如果选到了它要再选一次,这个和之前的是等价的。于是这几个人在(1)之后被杀的概率为$$P=sum_{i=0}^infty (1-frac{S+w_1}{A})^ifrac{w_1}{A}$$
[P=frac{w_1}{A}sum_{i=0}^infty (1-frac{S+w_1}{A})^i
]
[P=frac{w_1}{A} imes frac{1}{1-1+frac{S+w_1}{A}}
]
[P=frac{w_1}{S+w_1}
]
直接暴力枚举不现实,由于(sum w_ileq 10^5),所以我们可以把每个(S)的系数,即每个(S)出现了多少次给求出来,然后直接计算
由于(S)是一堆(w_i)乘起来的,而且因为容斥系数所以每多乘上一个(w_i)就要变一次号,所以我们可以把每一个(w_i)写成生成函数的形式(1-x^{w_i}),然后求出(prod_{i=2}^n(1-x^{w_i})),那么(S)的系数就是(x^S)
然后它这个分治(NTT)的意思大概就是……如果我们直接把这几个多项式乘起来复杂度是(O(nmlog m))的(其中(m)为最大的次数),因为所有多项式的次数之和为(m),我们可以把多项式两两合并,那么每一次多项式的个数都会减少一半,于是总的层数为(O(log n)),又因为每一层多项式的次数之和为(m),于是每一层的时间复杂度都是(O(mlog m)),那么总的时间复杂度就是(O(mlog nlog m))
据说还有用生成函数的乱七八糟的姿势以及多项式(exp)做到(O(mlog m))的,然而我不会啊2333
//minamoto
#include<bits/stdc++.h>
#define R register
#define fp(i,a,b) for(R int i=a,I=b+1;i<I;++i)
#define fd(i,a,b) for(R int i=a,I=b-1;i>I;--i)
#define go(u) for(int i=head[u],v=e[i].v;i;i=e[i].nx,v=e[i].v)
using namespace std;
char buf[1<<21],*p1=buf,*p2=buf;
inline char getc(){return p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++;}
int read(){
R int res,f=1;R char ch;
while((ch=getc())>'9'||ch<'0')(ch=='-')&&(f=-1);
for(res=ch-'0';(ch=getc())>='0'&&ch<='9';res=res*10+ch-'0');
return res*f;
}
double readdb()
{
R double x=0,y=0.1,f=1;R char ch;
while((ch=getc())>'9'||ch<'0')(ch=='-')&&(f=-1);
for(x=ch-'0';(ch=getc())>='0'&&ch<='9';x=x*10+ch-'0');
for(ch=='.'&&(ch=getc());ch>='0'&&ch<='9';x+=(ch-'0')*y,y*=0.1,ch=getc());
return x*f;
}
char sr[1<<21],z[20];int C=-1,Z=0;
inline void Ot(){fwrite(sr,1,C+1,stdout),C=-1;}
void print(R int x){
if(C>1<<20)Ot();if(x<0)sr[++C]='-',x=-x;
while(z[++Z]=x%10+48,x/=10);
while(sr[++C]=z[Z],--Z);sr[++C]='
';
}
const int N=5e5+5,P=998244353,Gi=332748118;
inline int add(R int x,R int y){return x+y>=P?x+y-P:x+y;}
inline int dec(R int x,R int y){return x-y<0?x-y+P:x-y;}
inline int mul(R int x,R int y){return 1ll*x*y-1ll*x*y/P*P;}
int ksm(R int x,R int y){
R int res=1;
for(;y;y>>=1,x=mul(x,x))if(y&1)res=mul(res,x);
return res;
}
int A[35][N],O[N],r[N],w[N],deg[N];
int n,m,tot,sum,ans;
void NTT(int *A,int ty,int lim){
fp(i,0,lim-1)if(i<r[i])swap(A[i],A[r[i]]);
for(R int mid=1;mid<lim;mid<<=1){
R int I=(mid<<1),Wn=ksm(ty==1?3:Gi,(P-1)/I);O[0]=1;
fp(i,1,mid-1)O[i]=mul(O[i-1],Wn);
for(R int j=0;j<lim;j+=I)for(R int k=0;k<mid;++k){
int x=A[j+k],y=mul(O[k],A[j+k+mid]);
A[j+k]=add(x,y),A[j+k+mid]=dec(x,y);
}
}if(ty==-1)for(R int i=0,inv=ksm(lim,P-2);i<lim;++i)A[i]=mul(A[i],inv);
}
void solve(int ql,int qr){
if(ql==qr){
++tot,A[tot][0]=1,A[tot][w[ql]]=-1,deg[tot]=w[ql];
fp(i,1,w[ql]-1)A[tot][i]=0;
return;
}int mid=(ql+qr)>>1;solve(ql,mid),solve(mid+1,qr);
int lim=1,l=0,x=tot-1,y=tot,m=deg[x]+deg[y];
while(lim<=m)lim<<=1,++l;
fp(i,0,lim-1)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
fp(i,deg[x]+1,lim-1)A[x][i]=0;
fp(i,deg[y]+1,lim-1)A[y][i]=0;
NTT(A[x],1,lim),NTT(A[y],1,lim);
fp(i,0,lim-1)A[x][i]=mul(A[x][i],A[y][i]);
NTT(A[x],-1,lim);
--tot,deg[tot]=m;
}
int main(){
// freopen("testdata.in","r",stdin);
n=read();fp(i,1,n)w[i]=read(),sum+=w[i];
sum-=w[1];
if(n==1)return puts("1"),0;
solve(2,n);
fp(i,0,sum)ans=add(ans,mul(w[1],mul(A[1][i],ksm(i+w[1],P-2))));
printf("%d
",ans);return 0;
}