昨天早知道先做这题的。。
随便写写就能过
/* 贪心:结点按a升序排序依次处理 s[u]表示u子树b=1结点个数 t[u]表示u子树c=1节点个数 dif[u]表示u子树bc不同的个数 vis[u]表示u所有子孙被访问了 每次遍历到u,先求一次s,t,dif 如果s[u]=t[u],那么u的贡献就是a[u]*dif[u] 反之就是a[u]*(dif[u]-abs(s[u]-t[u])) 然后给u子树的所有子孙标记 */ #include<bits/stdc++.h> using namespace std; #define N 200005 #define ll long long vector<int>G[N]; ll n,a[N],b[N],c[N],s[N],t[N],dif[N],vis[N],id[N],fa[N]; int cmp(int x,int y){return a[x]<a[y];} void dfs(int u,int pre){ if(b[u])s[u]++; if(c[u])t[u]++; if(b[u]!=c[u])dif[u]++; fa[u]=pre; for(auto v:G[u]) if(v!=pre){ dfs(v,u); s[u]+=s[v]; t[u]+=t[v]; dif[u]+=dif[v]; } } int getdif(int u,int pre){ if(vis[u])return dif[u]; int res=(b[u]!=c[u]); for(auto v:G[u]) if(v!=pre) res+=getdif(v,u); return res; } void dfs2(int u,int pre){ vis[u]=1; for(auto v:G[u])if(!vis[v]) if(v!=pre)dfs2(v,u); } ll sum; int main(){ int n;cin>>n; for(int i=1;i<=n;i++)scanf("%lld%lld%lld",&a[i],&b[i],&c[i]),id[i]=i; for(int i=1;i<n;i++){ int u,v;scanf("%lld%lld",&u,&v); G[u].push_back(v); G[v].push_back(u); } sort(id+1,id+1+n,cmp); dfs(1,1); if(s[1]!=t[1]){puts("-1");return 0;} for(int j=1;j<=n;j++){ int i=id[j]; if(vis[i])continue; int diff=getdif(i,fa[i]); sum+=a[i]*(diff-abs(s[i]-t[i])); dif[i]=abs(s[i]-t[i]); dfs2(i,fa[i]); } cout<<sum<<' '; }