点分治真是一个好东西。可惜我不会
这种要求所有路经的题很可能是点分治。
然后我就不会了。。
既然要用点分治,就想,点分治有哪些优点?它可以(O(nlogn))遍历分治树的所有子树。
那么现在的问题就是,如可快速((O(n))或O((nlogn)))求以一个点为根的时候,子树之间的贡献(当然还有根节点的)。
我们注意到一件事,就是一棵子树中一个点对其他子树的点产生贡献当且仅当这个点的颜色在它到根的路径上第一次出现(或者说只算上这些贡献答案正确),且贡献为以这个点为根的子树大小。(不考虑其它子树的颜色)
这个有什么用,我们可以遍历两遍子树,第一遍预处理出所有子树对其它子树的贡献(如上边一段所说把贡献统计),第二次遍历每一颗子树先把这颗树的贡献去掉,统计所有其它的树对这颗树的贡献。
那么具体该怎么做?
void calc(int u){
dfs1(u,0);
ans[u]+=sum;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(vis[v])continue;
cnt[a[u]]++;
sum-=size[v];color[a[u]]-=size[v];
change(v,u,-1);
cnt[a[u]]--;
tot=size[u]-size[v];
dfs2(v,u);
cnt[a[u]]++;
sum+=size[v];color[a[u]]+=size[v];
change(v,u,1);
cnt[a[u]]--;
}
clear(u,0);
}
首先dfs1是统计贡献的,用sum记录贡献和,color[i]记录第i种颜色的贡献。
然后根的答案就可以累加了。
那么如可判断一个颜色第一次出现?可以记录一个cnt[i]记录第i种颜色在到根的路径上出现多少次。当cnt[i]等于1的时候统计贡献。
然后
cnt[a[u]]++;
sum-=size[v];color[a[u]]-=size[v];
change(v,u,-1);
cnt[a[u]]--;
用来消除子树贡献。dfs2统计其它子树对这颗子树的贡献。
void dfs2(int u,int f){
cnt[a[u]]++;
if(cnt[a[u]]==1){
sum-=color[a[u]];
num++;
}
ans[u]+=sum+num*tot;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==f||vis[v])continue;
dfs2(v,u);
}
if(cnt[a[u]]==1){
sum+=color[a[u]];
num--;
}
cnt[a[u]]--;
}
如果这颗子树中出现一个颜色,并且它是第一次出现,那么减去所有子树的color[a[u]],加上其它子树的节点总数,因为每一条到其它子树的路径都会产生贡献,这也是我们一开始不考虑贡献对其他子树影响的原因,因为遍历子树的时候会把这些重复的贡献减去。
更具体还是看代码。
// luogu-judger-enable-o2
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<algorithm>
using namespace std;
#define int long long
const int N=101000;
int Cnt,head[N];
int g[N],size[N],cnt[N],a[N],sum,color[N],tot,num,root,all,vis[N],ans[N],n;
struct edge{
int to,nxt;
}e[N*2];
void add_edge(int u,int v){
Cnt++;
e[Cnt].nxt=head[u];
e[Cnt].to=v;
head[u]=Cnt;
}
int read(){
int sum=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){sum=sum*10+ch-'0';ch=getchar();}
return sum*f;
}
void getroot(int u,int f){
g[u]=0;size[u]=1;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==f||vis[v])continue;
getroot(v,u);
size[u]+=size[v];
g[u]=max(g[u],size[v]);
}
g[u]=max(g[u],all-size[u]);
if(g[u]<g[root])root=u;
}
void dfs1(int u,int f){
cnt[a[u]]++;
size[u]=1;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==f||vis[v])continue;
dfs1(v,u);
size[u]+=size[v];
}
if(cnt[a[u]]==1){
sum+=size[u];
color[a[u]]+=size[u];
}
cnt[a[u]]--;
}
void clear(int u,int f){
cnt[a[u]]++;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==f||vis[v])continue;
clear(v,u);
}
if(cnt[a[u]]==1){
sum-=size[u];
color[a[u]]-=size[u];
}
cnt[a[u]]--;
}
void dfs2(int u,int f){
cnt[a[u]]++;
if(cnt[a[u]]==1){
sum-=color[a[u]];
num++;
}
ans[u]+=sum+num*tot;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==f||vis[v])continue;
dfs2(v,u);
}
if(cnt[a[u]]==1){
sum+=color[a[u]];
num--;
}
cnt[a[u]]--;
}
void change(int u,int f,int k){
cnt[a[u]]++;
if(cnt[a[u]]==1){
sum+=k*size[u];color[a[u]]+=k*size[u];
}
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==f||vis[v])continue;
change(v,u,k);
}
cnt[a[u]]--;
}
void calc(int u){
dfs1(u,0);
ans[u]+=sum;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(vis[v])continue;
cnt[a[u]]++;
sum-=size[v];color[a[u]]-=size[v];
change(v,u,-1);
cnt[a[u]]--;
tot=size[u]-size[v];
dfs2(v,u);
cnt[a[u]]++;
sum+=size[v];color[a[u]]+=size[v];
change(v,u,1);
cnt[a[u]]--;
}
clear(u,0);
}
void work(int u){
calc(u);
vis[u]=1;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(vis[v])continue;
root=0,all=size[v];
getroot(v,0);
work(root);
}
}
signed main(){
n=read();
for(int i=1;i<=n;i++)a[i]=read();
for(int i=1;i<n;i++){
int u=read(),v=read();
add_edge(u,v);add_edge(v,u);
}
g[0]=n+10;root=0;all=n;
getroot(1,0);work(root);
for(int i=1;i<=n;i++)printf("%lld
",ans[i]);
return 0;
}