约定:一棵树的深度定义为其中到根最远的点到根的距离
考虑节点$x$的答案:
任取一条直径,根据直径的性质,到$x$较远的直径端点一定是到$x$最远的点之一
由此,不难证明对于$x$独特的点,一定在$x$到"到$x$较远的直径端点"的路径上
分别以直径的两个端点为根(做两次),每一次求出$x$到根路径上所有对$x$独特的点的特产数,根据前面所述,两次的最大值即为最终答案
关于这件事情,维护一个"目前独特"的集合$S$,具体定义如下:
$k$到根路径上的节点$x$中(不包括$k$),满足不存在$k$子树外的节点$y e x$,使得$k$到$x$和$y$的距离相等("$k$子树"指"以$k$为根的子树")
当递归到节点$k$时,令$mx$为$k$子树深度,将$S$中到$k$距离不超过$mx$的节点都删除,此时$S$即$k$到根路径上所有与对$k$独特的点(所构成的集合),求出其中特产数即可
接下来,考虑递归儿子$son$,令$mx$为以$k$的其余儿子子树深度的最大值+1,将$S$中到$k$距离不超过$mx$的节点都删除,然后在加入$k$(显然满足条件),并在递归完$son$后还原$S$
显然,这样做的复杂度过高,我们需要优化实现:
将$S$用一个栈维护(栈顶深度最大),"将$S$中到$k$距离不超过$mx$的节点都删除"即不断弹出栈顶
将树长链剖分,(对于节点$k$)令$mx$为$k$子树深度,$cmx$为轻儿子子树深度的最大值+1,接下来交换之前操作的顺序,即先递归重儿子,再递归轻儿子,最后统计答案
考虑递归重儿子,即将$S$中到$k$距离不超过$cmx$的节点都删除,而递归轻儿子和统计答案都是将$S$中到$k$距离不超过$mx$的节点都删除,也就不需要还原,直接操作即可
统计答案后,也不需要还原$S$,假设以此法删除的点$x$,由于删去的是到$k$距离不超过$mx$的节点,即$k$子树内必然存在节点$y$使得$k$到$x$和$y$的距离相等
那么对$x$子树内、$k$子树外的点$z$,对于$z$来说$x$一定不"独特",只需要令$y'$为将$y$向上爬$k$到$lca(k,z)$的距离步后的节点,此时$z$到$x$和$y'$的距离显然相等
由此,发现最多在$S$中加入$o(n)$个元素,那么也即至多删除$o(n)$次,复杂度为$o(n)$
1 #include<bits/stdc++.h> 2 using namespace std; 3 #define N 200005 4 struct Edge{ 5 int nex,to; 6 }edge[N<<1]; 7 stack<int>st; 8 int E,n,m,x,y,a[N],head[N],dep[N],l[N],cl[N],mx[N],tot[N],ans[N]; 9 void add(int x,int y){ 10 edge[E].nex=head[x]; 11 edge[E].to=y; 12 head[x]=E++; 13 } 14 void dfs(int k,int fa,int s){ 15 dep[k]=s; 16 mx[k]=l[k]=cl[k]=0; 17 for(int i=head[k];i!=-1;i=edge[i].nex) 18 if (edge[i].to!=fa){ 19 dfs(edge[i].to,k,s+1); 20 int x=l[edge[i].to]+1; 21 if (l[k]<x){ 22 mx[k]=edge[i].to; 23 swap(l[k],x); 24 } 25 cl[k]=max(cl[k],x); 26 } 27 } 28 void add(int k){ 29 st.push(k); 30 if (++tot[a[k]]==1)ans[0]++; 31 } 32 void del(){ 33 if (--tot[a[st.top()]]==0)ans[0]--; 34 st.pop(); 35 } 36 void calc(int k,int fa){ 37 if (fa)add(fa); 38 while ((!st.empty())&&(dep[k]-dep[st.top()]<=cl[k]))del(); 39 if (mx[k])calc(mx[k],k); 40 while ((!st.empty())&&(dep[k]-dep[st.top()]<=l[k]))del(); 41 ans[k]=max(ans[k],ans[0]); 42 for(int i=head[k];i!=-1;i=edge[i].nex) 43 if ((edge[i].to!=fa)&&(edge[i].to!=mx[k]))calc(edge[i].to,k); 44 if ((!st.empty())&&(st.top()==fa))del(); 45 } 46 int main(){ 47 scanf("%d%d",&n,&m); 48 memset(head,-1,sizeof(head)); 49 for(int i=1;i<n;i++){ 50 scanf("%d%d",&x,&y); 51 add(x,y); 52 add(y,x); 53 } 54 for(int i=1;i<=n;i++)scanf("%d",&a[i]); 55 dfs(1,0,0); 56 x=1; 57 for(int i=2;i<=n;i++) 58 if (dep[x]<dep[i])x=i; 59 dfs(x,0,0); 60 calc(x,0); 61 x=1; 62 for(int i=2;i<=n;i++) 63 if (dep[x]<dep[i])x=i; 64 dfs(x,0,0); 65 calc(x,0); 66 for(int i=1;i<=n;i++)printf("%d ",ans[i]); 67 }