题目链接:
对于每个点,它的答案最大就是与它距离最远的点的距离。
而如果与它距离为$x$的点有大于等于两个,那么与它距离小于等于$x$的点都不会被计入答案。
所以我们需要找到对于每个点$u$距离它最远的点及最小的距离$x$满足距离$u$的距离大于等于$x$的点都只有一个。
那么怎么找距离每个点最远的点?
这个点自然就是树的直径的一个端点了!
我们将树的直径先找到,然后讨论一下对于每个点,有哪些点可能会被计入答案:
如图所示,我们以点$x$为例,假设它距离直径两端点中的$S$较近($y$为$x$距离直径上最近的点),设$dis$代表两点距离:
对于$y$左边点所有点,显然$S$与$y$的距离最远,但$dis(S,y)<dis(T,y)$,所以$y$左边的所有点都不会被计入答案。
对于在$x$子树中的点,他们与$x$的距离要小于$dis(y,T)$,也就小于$dis(x,T)$,所以不会被计入答案。
对于在$y$子树中但不在$x$子树中的点(例如$b$),因为$dis(y,b)le dis(y,S)$,所以$dis(b,d)<dis(S,d)$,不会被计入答案。
对于$y$与$T$之间的点的子树中的点(例如$c$),显然$dis(y,c)le dis(y,T)$,所以这类点不会被计入答案。
那么综上所述对于靠近$S$的点,只有$x$到$T$之间的点才有可能被计入答案,对于靠近$T$的点同理。
所以我们只需要分别以$S$和$T$为根遍历整棵树,用一个单调栈保存每个点到根的这条链上能被计入答案的点即可。
求不同权值个数,再开一个桶记录栈中每种权值的个数,每次进栈或弹栈时对应加减。
因为答案与深度有关,我们将原树长链剖分。
对于每个点,当走重儿子时,求出所有轻儿子的子树中的最长链长度$len$,将当前栈中与$x$距离小于等于$len$的点弹出;当遍历轻儿子时,将当前栈中与$x$距离小于等于$x$往下最长链长度的点弹出。
注意要在弹栈之后再把$x$压入栈中,而且遍历每个儿子前都要重新将$x$压入栈中。
最后将以$S$为根时的答案与以$T$为根时的答案取最大值即可。
#include<set> #include<map> #include<queue> #include<stack> #include<cmath> #include<cstdio> #include<vector> #include<bitset> #include<cstring> #include<iostream> #include<algorithm> using namespace std; int head[200010]; int to[400010]; int dep[200010]; int next[400010]; int mx[200010]; int son[200010]; int st[200010]; int top; int tot; int S,T; int res; int ans[200010]; int n,m; int x,y; int col[200010]; int cnt[200010]; void add(int x,int y) { next[++tot]=head[x]; head[x]=tot; to[tot]=y; } void pop() { cnt[col[st[top]]]--; res-=(cnt[col[st[top]]]==0); top--; } void push(int x) { st[++top]=x; cnt[col[x]]++; res+=(cnt[col[x]]==1); } void dfs(int x,int fa) { dep[x]=dep[fa]+1; for(int i=head[x];i;i=next[i]) { if(to[i]!=fa) { dfs(to[i],x); } } } void dfs1(int x,int fa) { son[x]=0; mx[x]=0; dep[x]=dep[fa]+1; for(int i=head[x];i;i=next[i]) { if(to[i]!=fa) { dfs1(to[i],x); if(mx[to[i]]>mx[son[x]]) { son[x]=to[i]; } } } mx[x]=mx[son[x]]+1; } void dfs2(int x,int fa) { if(!son[x]) { ans[x]=max(ans[x],res); return ; } int len=0; for(int i=head[x];i;i=next[i]) { if(to[i]!=fa&&to[i]!=son[x]) { len=max(len,mx[to[i]]); } } while(top&&dep[st[top]]>=dep[x]-len) { pop(); } push(x); dfs2(son[x],x); for(int i=head[x];i;i=next[i]) { if(to[i]!=fa&&to[i]!=son[x]) { while(top&&dep[st[top]]>=dep[x]-mx[son[x]]) { pop(); } push(x); dfs2(to[i],x); } } while(top&&dep[st[top]]>=dep[x]-mx[son[x]]) { pop(); } ans[x]=max(ans[x],res); } int main() { scanf("%d%d",&n,&m); for(int i=1;i<n;i++) { scanf("%d%d",&x,&y); add(x,y); add(y,x); } for(int i=1;i<=n;i++) { scanf("%d",&col[i]); } dfs(1,0); for(int i=1;i<=n;i++) { S=dep[i]>dep[S]?i:S; } dfs(S,0); for(int i=1;i<=n;i++) { T=dep[i]>dep[T]?i:T; } dfs1(S,0); dfs2(S,0); dfs1(T,0); dfs2(T,0); for(int i=1;i<=n;i++) { printf("%d ",ans[i]); } }