本题前面的操作别的博客里都有。难点在于颜色ci的贡献,如何一次dfs求出答案
先来考虑如何在一次dfs中单独对颜色i进行计算
用遍历dfs序的方式,在深搜过程中,碰到带有颜色 i 的点 u,u每个颜色不为i的子节点v都会贡献一个联通块,
v的贡献的联通块大小是size[v]-sum{v中层次最高的以颜色i的结点为根的子树大小}
那我们要先求出v中层次最高的以颜色i的结点为根的子树大小,所以用sum表示目前为止颜色i的所有子树的大小,用last存下v进入dfs前的sum,即不算v下颜色i的子树时的sum
到v中去dfs,然后把这些子树的大小加到sum中
当遍历完v的所有子树后,我们会发现 sum-last 就是v下层次最高的以颜色i的结点为根的子树大小的总和
那么v贡献出的联通块大小就是 k = size[v]-(sum-last)
然后去u的下一棵子树v'进行同样的操作,直到遍历完u的所有子树,此时sum已经变成了到u为止(不包含u)的以颜色i为根的子树大小之和
为了计算u的颜色为i的祖先,需要把u也并入sum里,那么只需要在sum里加入u自己,再加上所有v贡献的联通块即可
现在已经维护完所有的信息,可以推出u的dfs,往上回滚求出u的祖先结点的信息
再来考虑如何在一次dfs中对所有出现的颜色进行计算,
我们可以在上面的递归中发现,每中颜色在求贡献只用到了size[],还有每种颜色对应的sum,那么用sum[c]数组来维护颜色c代表的sum即可,就可以在一次dfs中维护多种颜色的贡献
本题和虚树有些类似的地方,首先把每种颜色当成是一个询问,就类似虚树的询问了
然后是回滚dfs的过程,自叶子往上求(其实是按照dfs序)的方式:在初次碰到u时记录进入dfs前的状态,然后dfs处理完其所有子节点的状态后再来计算u的状态
#include<bits/stdc++.h> #include<vector> using namespace std; #define maxn 200005 #define ll long long ll ans,color[maxn],size[maxn],sum[maxn]; vector<int>G[maxn]; void dfs1(int u,int pre){ size[u]=1; for(int i=0;i<G[u].size();i++){ int v=G[u][i]; if(v==pre)continue; dfs1(v,u); size[u]+=size[v]; } } //这个树形dp最重要的是理解sum[]数组的含义,sum[x]的更新像虚树的加边一样是自叶子节点往上回滚的 void dfs2(int u,int pre){ ll other=0;//other表示为size[u]减去u下所有最高的以color[u]为根的大小 for(int i=0;i<G[u].size();i++){ int v=G[u][i]; if(v==pre)continue; ll last=sum[color[u]];//记录前前面子树里颜色u的子树(虚树)里的值 dfs2(v,u); ll diff=sum[color[u]]-last;//v的子树里颜色为color[u]的个数 //v树下不包含color[u]的联通块的大小 ans+=(size[v]-diff-1)*(size[v]-diff)/2; other+=size[v]-diff; } sum[color[u]]+=other+1;//+1是因为u本身也是color[u] } int f[maxn],tot; int main(){ ll n,t=0; while(cin>>n){ ++t; for(int i=1;i<=n;i++)G[i].clear(); memset(f,0,sizeof f); tot=0; for(int i=1;i<=n;i++){ scanf("%d",&color[i]); f[color[i]]=1; } for(int i=1;i<=n;i++)tot+=f[i]; for(int i=1;i<n;i++){ int u,v; scanf("%d%d",&u,&v); G[u].push_back(v); G[v].push_back(u); } if(tot==1){ printf("Case #%d: %lld ",t,n*(n-1)/2); continue; } memset(size,0,sizeof size); memset(sum,0,sizeof sum); ans=0; dfs1(1,1);dfs2(1,1); for(int i=1;i<=n;i++) if(f[i]) ans+=(n-sum[i])*(n-sum[i]-1)/2; ll tmp=(n-1)*n/2*tot; printf("Case #%d: %lld ",t,tmp-ans); } }