题目:
分析:
对于t=0的点,显然暴力会重复计算很多。设dis为一个点以下,所有儿子到它的val*dis和,sum为子树权值和,利用父亲已知的b与计算出的dis和sum来计算儿子的值。(必须要用bfs 保证更新的顺序)
(当然我是做麻烦了的)
而t=1,要求知道b反推a,那么对一个点的b值进行分析,联立列方程去解a。
这里用到了一个重要的性质:根的b为∑i=2~n sum[i] (仔细想想可知)
然后一下就是关于如何解a的过程:
#include<bits/stdc++.h> using namespace std; #define N 100005 #define ll long long int to[N<<1],nex[N<<1],tot=0,head[N],fa[N],vis[N]; ll sumer=0,sum[N],dis[N],a[N],b[N],total=0; ll read() { ll x=0; int fl=1; char ch=getchar(); while(ch>'9'||ch<'0') { if(ch=='-') fl=-1; ch=getchar(); } while(ch<='9'&&ch>='0') { x=x*10+ch-'0'; ch=getchar(); } return x*fl; } void add(int a,int b) { to[++tot]=b; nex[tot]=head[a]; head[a]=tot; } void dfs1(int u) { sum[u]=a[u]; for(int i=head[u];i;i=nex[i]){ int v=to[i]; if(v==fa[u]) continue; fa[v]=u; dfs1(v); sum[u]+=sum[v]; dis[u]+=dis[v]+sum[v]; } } void bfs() { queue<int> q; memset(vis,0,sizeof(vis)); q.push(1); for(int i=head[1];i;i=nex[i]){ int v=to[i]; b[1]+=dis[v]+sum[v]; } while(!q.empty()){ int u=q.front(); q.pop(); if(vis[u]) continue; vis[u]=1; for(int i=head[u];i;i=nex[i]){ int v=to[i]; if(v==fa[u]) continue; b[v]=dis[v]+b[u]-(dis[v]+sum[v])+(sum[1]-sum[v]); q.push(v); } } } void dfs2(int u) { for(int i=head[u];i;i=nex[i]){ int v=to[i]; if(v==fa[u]) continue; fa[v]=u; sumer+=b[u]-b[v]; dfs2(v); } } void dfs3(int u) { ll tmp=0; for(int i=head[u];i;i=nex[i]){ int v=to[i]; if(v==fa[u]) continue; sum[v]=(b[u]-b[v]+total)/2; tmp+=sum[v]; dfs3(v); } a[u]=sum[u]-tmp; } void clear(int n) { tot=0; for(int i=1;i<=n*2;i++) head[i]=0,sum[i]=0,dis[i]=0,nex[i]=0,to[i]=0,b[i]=0,a[i]=0; } int main() { freopen("single.in","r",stdin); freopen("single.out","w",stdout); int T,aa,bb,t,n=0; T=read(); while(T--){ clear(n); n=read(); for(int i=1;i<=n-1;i++){ aa=read(); bb=read(); add(aa,bb); add(bb,aa); } t=read(); if(t==0){ for(int i=1;i<=n;++i) a[i]=read(); dfs1(1); bfs(); for(int i=1;i<=n;++i) printf("%lld ",b[i]); printf(" "); } else{ sumer=0; for(int i=1;i<=n;++i) b[i]=read(); dfs2(1);//printf("%lld ",sumer); total=(2*b[1]-sumer)/(n-1); sum[1]=total;//printf("%lld ",total); dfs3(1);// for(int i=1;i<=n;++i) printf("%lld ",a[i]); printf(" "); } } } /* 100 8 1 2 1 3 2 4 2 5 2 6 4 7 4 8 1 16 17 21 28 30 30 0 3 3 5 2 1 1 1 7000002008 6000001005 10000003011 9000000004 9000002008 11000002006 12000001007 14000001003 0 1000000000 1000000000 1000000000 1000 1000000000 1 1000000000 2 6 1 2 1 3 2 4 2 5 2 6 1 13 10 22 21 21 21 0 3 5 2 1 1 1 3 4 0 3 5 2 1 1 1 */