这道题是一道树形dp,设f[i][j]表示到第i个点,此时权值和为j的期望值。
一开始打的是n^3的暴力,但是因为数据水所以拿了60分。
这道题的正解就是暴力优化,树形依赖dp。
因为选每一个点的条件是选了这个点到根节点的所有点,所以可以在dp到这个点的儿子之前,将这个点的dp值传下去,因为此时相当于是只多增加了一个点,所以可以O(k)做。
做完这个点以后,再将这个点的dp值传回去,同时计算断掉这条边的情况,这个也是O(k)的。
#include<cstdio> #include<cstring> #include<algorithm> #define N 5010 #define K 5010 #define ll long long #define mo 998244353 #define er 499122177 using namespace std; int n,k,i,f[N][K],q[N],sum[N],x,y,tot,g[N][K]; ll temp; struct edge{ int to,next; }e[N*2]; void insert(int x,int y){ tot++; e[tot].to=y; e[tot].next=q[x]; q[x]=tot; } void dfs(int x,int father){ int i,j,l,y; for (i=q[x];i;i=e[i].next){ y=e[i].to; if (y!=father){ for (j=0;j<=k;j++){ if (j+sum[y]<=k) temp=1ll*f[x][j]*er%mo,f[y][j+sum[y]]=temp; } dfs(y,x); for (j=0;j<=k;j++) temp=1ll*(f[y][j]+1ll*f[x][j]*er%mo)%mo,f[x][j]=temp; } } } int main(){ freopen("luge.in","r",stdin); freopen("luge.out","w",stdout); scanf("%d%d",&n,&k); for (i=1;i<=n;i++) scanf("%d",&sum[i]); for (i=1;i<n;i++){ scanf("%d%d",&x,&y); insert(x,y); insert(y,x); } f[1][sum[1]]=1; dfs(1,0); printf("%d ",f[1][k]); return 0; }