「CTS2019 | CTSC2019」氪金手游
传送门
题解
首先我们可以对这个((u_i,v_i))分析,不难发现这是一个树形结构(如果不考虑顺序).
我们先假定他是一个外向树,设(S)表示(sum_{i}W_i),(s_u)表示(u)的子树的(sum W),那么对于这个(u),合法的概率为:
[frac{W_u}{S}sum_{i=0}^{infty}(frac{S-s_u}{S})^i
]
即我们考虑(u)这张卡第几次出.
这个东西用无穷等比数列求和化简可以得到:(frac{W_u}{s_u})
所以此时每一个子树只和子树内的(W)相关,又有(W in {1,2,3}),所以我们可以:
设(f_{u,i})表示以(u)为根的子树,(s_u=i)的合法的概率,合并就是一个卷积.
这个时候如果有树上的反向边,我们考虑容斥,若有(j)个反向边不合法,即(sum_{j,k}-1^{j}f_{1,k,j})
这个反向边不合法其实对应的就是变成正向边,然后删除这条边.
即
[T_u>T_v,T_{fa}<T_{u}
rightarrow T_{fa}<T_{v}
]
也就是这个子树和上面的节点无关,所以要删除这个子树,即只把这个子树合并到当前根节点.
代码
#include<stdio.h>
#include<stdlib.h>
#include<string.h>
#include<math.h>
#include<algorithm>
#include<queue>
#include<set>
#include<map>
#include<iostream>
using namespace std;
#define ll long long
#define REP(a,b,c) for(int a=b;a<=c;a++)
#define re register
#define file(a) freopen(a".in","r",stdin);freopen(a".out","w",stdout)
typedef pair<int,int> pii;
#define mp make_pair
inline int gi()
{
int f=1,sum=0;char ch=getchar();
while(ch>'9' || ch<'0'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0' && ch<='9'){sum=(sum<<3)+(sum<<1)+ch-'0';ch=getchar();}
return f*sum;
}
const int N=3010,Mod=998244353;
int f[N][N],front[N],cnt,siz[N],tmp[N],n;
struct node{int to,nxt;}e[N<<1];
void Add(int u,int v){e[++cnt]=(node){v,front[u]};front[u]=cnt;}
int qpow(int a,int b){int ret=1;while(b){if(b&1)ret=1ll*ret*a%Mod;b>>=1;a=1ll*a*a%Mod;}return ret;}
void dfs(int u,int ff)
{
siz[u]=1;
for(int i=front[u];i;i=e[i].nxt)
{
int v=e[i].to;if(v==ff)continue;
dfs(v,u);
for(int j=0;j<=siz[u]*3;j++)
for(int k=0;k<=siz[v]*3;k++)
{
int val=1ll*f[u][j]*f[v][k]%Mod;
if(i&1)tmp[j+k]=(tmp[j+k]+val)%Mod;
else tmp[j+k]=(tmp[j+k]-val+Mod)%Mod,tmp[j]=(tmp[j]+val)%Mod;
}
siz[u]+=siz[v];
for(int j=0;j<=siz[u]*3;j++)f[u][j]=tmp[j],tmp[j]=0;
}
for(int i=0;i<=siz[u]*3;i++)
f[u][i]=1ll*f[u][i]*qpow(i,Mod-2)%Mod;
}
int main()
{
n=gi();
for(int i=1;i<=n;i++)
{
int a=gi(),b=gi(),c=gi();
int s=qpow(a+b+c,Mod-2);
f[i][1]=1ll*a*s%Mod;
f[i][2]=2ll*b*s%Mod;
f[i][3]=3ll*c*s%Mod;
}
for(int i=1;i<n;i++)
{
int u=gi(),v=gi();
Add(u,v);Add(v,u);
}
dfs(1,1);
int ans=0;
for(int i=0;i<=n*3;i++)ans=(ans+f[1][i])%Mod;
printf("%d
",ans);
return 0;
}