sol
点分。我们面临的最主要一个问题,就是如何在(O(n))的时间内算出所有LCA为根的点对的贡献,还要分别累加到它们自己的答案中去。
(num_i):每一种颜色的数量。你可以认为这就是一个桶。从根到叶子遍历,相当于每次都只维护一条链上的颜色情况。以便于得到(tot_i)
(fst_i):(i)号点上的颜色是不是从根往下第一次出现。如果是,就会加到(col_i)里面取算贡献
(col_i):每一种颜色的贡献
(tot_i):每个点到根的路径上有多少种颜色
鉴于点对之间计算答案不太现实,我们考虑计算每种颜色对答案的贡献。
如果一个节点(i),它的颜色是从根往下第一次出现的(即(fst_i=1)),那么这种颜色就一定会给其他子树中的每个节点贡献(sz_i)的答案。这个答案就累加在(col_i)中。然后在对这个(col_i)求和,就是总贡献。
一个节点(i)的答案的初始值应该是(tot_i*(sz_u-sz_v))(就是总(sz)除去自己所在的子树外的部分),然后还要加上一些(col_i)的值,但是要保证加上的(col_i)不能是自己到根已经有过的颜色(不然就重复计算了)。
多做几遍dfs,维护以上提到的东西就行了。
复杂度是(O(nlog_2n))的,带点小常数
code
我之前干了件非常傻逼的事情
我没写(clear)在(solve)函数里面写了个memset
然后复杂度变成了严格(O(n^2))
然后就只有暴力分。。。
#include<cstdio>
#include<algorithm>
using namespace std;
#define ll long long
const int N = 100005;
int gi()
{
int x=0,w=1;char ch=getchar();
while ((ch<'0'||ch>'9')&&ch!='-') ch=getchar();
if (ch=='-') w=0,ch=getchar();
while (ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
return w?x:-x;
}
struct edge{int to,next;}a[N<<1];
int n,c[N],head[N],cnt,sz[N],w[N],vis[N],sum,root;
int num[N],fst[N],col[N],tot[N];
ll sigma,ans[N];
void getroot(int u,int f)
{
sz[u]=1;w[u]=0;
for (int e=head[u];e;e=a[e].next)
{
int v=a[e].to;if (v==f||vis[v]) continue;
getroot(v,u);
sz[u]+=sz[v];w[u]=max(w[u],sz[v]);
}
w[u]=max(w[u],sum-sz[u]);
if (w[u]<w[root]) root=u;
}
void dfs(int u,int f,ll &Ans)
{
sz[u]=1;num[c[u]]++;
if (num[c[u]]==1) fst[u]=1,cnt++;else fst[u]=0;
tot[u]=cnt;Ans+=tot[u];
for (int e=head[u];e;e=a[e].next)
{
int v=a[e].to;if (v==f||vis[v]) continue;
dfs(v,u,Ans);sz[u]+=sz[v];
}
if (fst[u]) col[c[u]]+=sz[u],sigma+=sz[u],cnt--;
num[c[u]]--;
}
void change(int u,int f,int b)
{
if (fst[u]) col[c[u]]+=b*sz[u],sigma+=b*sz[u];
for (int e=head[u];e;e=a[e].next)
{
int v=a[e].to;if (v==f||vis[v]) continue;
change(v,u,b);
}
}
void calc(int u,int f,int k)
{
if (fst[u]) sigma-=col[c[u]];
ans[u]+=1ll*tot[u]*k+sigma;
for (int e=head[u];e;e=a[e].next)
{
int v=a[e].to;if (v==f||vis[v]) continue;
calc(v,u,k);
}
if (fst[u]) sigma+=col[c[u]];
}
void clear(int u,int f)
{
col[c[u]]=0;
for (int e=head[u];e;e=a[e].next)
{
int v=a[e].to;if (v==f||vis[v]) continue;
clear(v,u);
}
}
void solve(int u)
{
vis[u]=1;
dfs(u,0,ans[u]);
col[c[u]]-=sz[u];sigma-=sz[u];
for (int e=head[u];e;e=a[e].next)
{
int v=a[e].to;if (vis[v]) continue;
change(v,0,-1);
calc(v,0,sz[u]-sz[v]);
change(v,0,1);
}
clear(u,0);sigma=0;
for (int e=head[u];e;e=a[e].next)
{
int v=a[e].to;if (vis[v]) continue;
sum=sz[v];
root=0;
getroot(v,0);
solve(root);
}
}
int main()
{
n=gi();
for (int i=1;i<=n;i++) c[i]=gi();
for (int i=1;i<n;i++)
{
int u=gi(),v=gi();
a[++cnt]=(edge){v,head[u]};head[u]=cnt;
a[++cnt]=(edge){u,head[v]};head[v]=cnt;
}
sum=w[0]=n;cnt=0;
getroot(1,0);
solve(root);
for (int i=1;i<=n;i++) printf("%lld
",ans[i]);
return 0;
}