比较好的一道虚树题.
建出虚树,然后计算虚树中距离点 $x$ 最近的关键点,这个来一次树形dp+换根即可实现.
难点在于计算 $x$ 到 $x$ 父亲这一段所有节点归属于谁(肯定属于 $x$ 的最近点或 $x$ 父亲最近点).
这里的话肯定可以二分出拐点(拐点以前属于 $x$,拐点以后属于 $y$),然后根据虚树的性质,$x$ 到父亲之间节点的儿子上肯定都没有关键节点.
很多地方都需要用到倍增,dfs 处理倍增的时候要注意先处理父亲的倍增数组再处理子树的倍增数组.
code:
#include <cstdio> #include <vector> #include <cstring> #include <algorithm> #define N 300009 #define ll long long #define setIO(s) freopen(s".in","r",stdin) using namespace std; int edges,n,tim,m,top; int hd[N],to[N<<1],nex[N<<1],fa[20][N],dep[N],dfn[N]; int a[N],sta[N],size[N],mk[N],b[N],ans[N]; vector<int>G[N]; struct node { int x,y; node(int t1=N,int t2=N){x=t1,y=t2;} node operator+(const node b) const { node c; if(y!=b.y) { if(y<b.y) c.x=x,c.y=y; else c=b; } else { c.y=y,c.x=min(x,b.x); } return c; } }mn[N]; bool cmp(int i,int j) { return dfn[i]<dfn[j]; } void add(int u,int v) { nex[++edges]=hd[u],hd[u]=edges,to[edges]=v; } void addvir(int x,int y) { G[x].push_back(y); } void dfs(int x,int ff) { fa[0][x]=ff; dep[x]=dep[ff]+1; dfn[x]=++tim,size[x]=1; for(int i=1;i<20;++i) fa[i][x]=fa[i-1][fa[i-1][x]]; for(int i=hd[x];i;i=nex[i]) if(to[i]!=ff) { dfs(to[i],x); size[x]+=size[to[i]]; } } int get_lca(int x,int y) { if(dep[x]!=dep[y]) { if(dep[x]>dep[y]) swap(x,y); for(int i=19;i>=0;--i) if(dep[fa[i][y]]>=dep[x]) y=fa[i][y]; } if(x==y) return x; for(int i=19;i>=0;--i) if(fa[i][x]!=fa[i][y]) { x=fa[i][x],y=fa[i][y]; } return fa[0][x]; } int go_kth(int x,int k) { for(int i=19;i>=0;--i) { if(dep[x]-dep[fa[i][x]]<=k) { k-=(dep[x]-dep[fa[i][x]]); x=fa[i][x]; } } return x; } void ins(int x) { if(top<=1) sta[++top]=x; else { int lca=get_lca(x,sta[top]); if(lca==sta[top]) sta[++top]=x; else { while(top>1&&dep[sta[top-1]]>=dep[lca]) { addvir(sta[top-1],sta[top]); --top; } if(sta[top]!=lca) addvir(lca,sta[top]),sta[top]=lca; sta[++top]=x; } } } void dfs1(int x) { mn[x]=node(); if(mk[x]) mn[x]=node(x,0); for(int i=0;i<G[x].size();++i) { dfs1(G[x][i]); mn[x]=mn[x]+node(mn[G[x][i]].x,dep[mn[G[x][i]].x]-dep[x]); } } void dfs2(int x,int ff) { if(ff) { if(mn[ff].x!=mn[x].x) { mn[x]=mn[x]+node(mn[ff].x,mn[ff].y+dep[x]-dep[ff]); } } for(int i=0;i<G[x].size();++i) dfs2(G[x][i],x); } void dfs3(int x,int ff) { int sz=0,y; for(int i=0;i<G[x].size();++i) { dfs3(G[x][i],x); int tmp=G[x][i]; y=G[x][i]; y=go_kth(y,dep[y]-dep[x]-1); sz+=size[y]; } ans[mn[x].x]+=size[x]-sz; if(ff){ int len=dep[x]-dep[ff]; int s=fa[0][x]; int t=go_kth(x,len-1); if(mn[x].x==mn[ff].x) { ans[mn[x].x]+=size[t]-size[x]; } else { int X=mn[x].y; int Y=mn[ff].y; int re=0,l=1,r=len-1; while(l<=r) { int mid=(l+r)>>1; if(2*mid<Y-X+len) re=mid,l=mid+1; else r=mid-1; } int p=go_kth(x,re); ans[mn[x].x]+=size[p]-size[x]; ans[mn[ff].x]+=size[t]-size[p]; if(re+1<len&&re+1+X==len-re-1+Y&&mn[x].x<mn[ff].x) { ans[mn[x].x]+=size[fa[0][p]]-size[p]; ans[mn[ff].x]-=size[fa[0][p]]-size[p]; } } } } void dfs4(int x) { mk[x]=0; for(int i=0;i<G[x].size();++i) dfs4(G[x][i]); G[x].clear(),mn[x]=node(); } void solve() { scanf("%d",&m); for(int i=1;i<=m;++i) { scanf("%d",&a[i]); mk[a[i]]=1,b[i]=a[i]; } sort(a+1,a+1+m,cmp); top=0; if(a[1]!=1) ins(1); for(int i=1;i<=m;++i) ins(a[i]); while(top>1) addvir(sta[top-1],sta[top]),--top; dfs1(1),dfs2(1,0),dfs3(1,0),dfs4(1); for(int i=1;i<=m;++i) { printf("%d",ans[b[i]]),ans[b[i]]=0; if(i<m) printf(" "); } printf(" "); } int main() { // setIO("input"); int x,y,z; scanf("%d",&n); for(int i=1;i<n;++i) { scanf("%d%d",&x,&y); add(x,y),add(y,x); } dfs(1,0); int Q; scanf("%d",&Q); for(int i=1;i<=Q;++i) solve(); return 0; }