题面
https://www.luogu.org/problem/P3233
题解
我构思和实现这道题写了很长世界,事实上,一开始就想好的话,不需要花费如此长的时间的。
只对关键点和关键点的$lca$建一个虚树。
定义虚树上的每个点的控制范围是只走不经过虚树路径上的点,能够到达的点集,这个可以建出虚树之后用儿子更新父亲,把儿子的贡献减去即可。
从上向下$dp$一遍,再从下到上$dp$一遍,对于虚树上的每条边,找分界点。
之前想的方法是分界点在两相邻关键点之间寻找,但是并不是两相邻关键点之间的点总是被二者之一控制(可能还向上)
#include<queue> #include<stack> #include<cstdio> #include<cstring> #include<iostream> #include<algorithm> #define ri register int #define N 300050 #define INF 500000 using namespace std; int n,m,q,vt[N],vt2[N<<1],in[N],TIM; int col[N]; int dfin[N],dfou[N],tmp[N],vis[N]; int ev[N],pv[N],ffa[N],ans[N],dis[N]; int down[N],imp[N]; stack<int> s; vector<int> son[N]; struct Basic_tree { vector<int> to[N]; int fa[N],dep[N],pa[N][20],siz[N]; void add_edge(int u,int v) { to[u].push_back(v); to[v].push_back(u); } void maketree(int x,int ff) { fa[x]=ff; dep[x]=dep[ff]+1; pa[x][0]=ff; for (ri i=1;i<=19;i++) pa[x][i]=pa[pa[x][i-1]][i-1]; dfin[x]=++TIM; siz[x]=1; for (ri i=0;i<to[x].size();i++) { int y=to[x][i]; if (y==ff) continue; maketree(y,x); siz[x]+=siz[y]; } dfou[x]=++TIM; } int lca(int x,int y) { if (dep[x]<dep[y]) swap(x,y); for (ri i=19;i>=0;i--) if (dep[pa[x][i]]>=dep[y]) x=pa[x][i]; if (x==y) return x; for (ri i=19;i>=0;i--) if (pa[x][i]!=pa[y][i]) x=pa[x][i],y=pa[y][i]; return fa[x]; } int getf(int x,int t) { for (ri i=19;i>=0;i--) if (t&(1<<i)) x=pa[x][i]; return x; } int find(int x,int y,int d1,int d2,int x0,int y0) { int tl=d1+d2+dep[x]-dep[y]; if (tl%2==0) { int m0=getf(x,tl/2-d1); if (x0<y0) return m0; else return getf(x,tl/2-d1-1); } else { return getf(x,tl/2-d1); } } int jump(int x,int y) { for (ri i=19;i>=0;i--) if (dep[pa[x][i]]>dep[y]) x=pa[x][i]; return x; } int distance(int u,int v) { int w=lca(u,v); return dep[u]+dep[v]-2*dep[w]; } } T; bool cmp(int a,int b) { return dfin[a]<dfin[b]; } bool cmp2(int a,int b) { int k1,k2; if (a<0) k1=dfou[-a]; else k1=dfin[a]; if (b<0) k2=dfou[-b]; else k2=dfin[b]; return k1<k2; } void dp(int x) { if (imp[x]) { dis[x]=0; down[x]=x; } for (ri i=0;i<son[x].size();i++) { int y=son[x][i]; dp(y); if (T.distance(down[y],x)<dis[x] || T.distance(down[y],x)==dis[x] && down[y]<down[x]) { dis[x]=T.distance(down[y],x); down[x]=down[y]; } } } void redp(int x) { for (ri i=0;i<son[x].size();i++) { int y=son[x][i]; if (T.distance(down[x],y)<dis[y] || T.distance(down[x],y)==dis[y] && down[x]<down[y]) { down[y]=down[x]; dis[y]=T.distance(down[x],y); } redp(y); } } int main() { scanf("%d",&n); for (ri i=1,u,v;i<n;i++) { scanf("%d %d",&u,&v); T.add_edge(u,v); } memset(dis,0x3f,sizeof(dis)); T.maketree(1,1); scanf("%d",&q); while (q--) { int k,t; scanf("%d",&k); int cnt=0,cc=0; for (ri j=1;j<=k;j++) { scanf("%d",&t); tmp[++cnt]=t; vt[cnt]=t; in[t]=1; } sort(vt+1,vt+cnt+1,cmp); for (ri i=1;i<cnt;i++) { int r=T.lca(vt[i],vt[i+1]); if (!in[r]) { in[r]=1; vt2[++cc]=r; vt2[++cc]=-r; } } for (ri i=1;i<=cnt;i++) { vt2[++cc]=vt[i]; vt2[++cc]=-vt[i]; } sort(vt2+1,vt2+cc+1,cmp2); for (ri i=1;i<=cc;i++) { if (vt2[i]>0) { s.push(vt2[i]); pv[vt2[i]]+=T.siz[vt2[i]]; continue; } int x=s.top(),y; s.pop(); if (s.empty()) { pv[x]+=n-T.siz[x]; ev[x]=0; } else { y=s.top(); son[y].push_back(x); ffa[x]=y; int z=T.jump(x,y); ev[x]=T.siz[z]-T.siz[x]; pv[y]-=T.siz[z]; } } for (ri i=1;i<=cnt;i++) imp[tmp[i]]=1; dp(vt2[1]); redp(vt2[1]); for (ri i=1;i<=cc;i++) if (vt2[i]>0) { ans[down[vt2[i]]]+=pv[vt2[i]]; } while (!s.empty()) s.pop(); for (ri i=1;i<=cc;i++) { if (vt2[i]>0) { s.push(vt2[i]); continue; } int x=s.top(),y; s.pop(); if (!s.empty()) { y=s.top(); int z=T.jump(x,y); if (down[x]==down[y]) { ans[down[x]]+=T.siz[z]-T.siz[x]; } else { int m0=T.find(x,y,dis[x],dis[y],down[x],down[y]); ans[down[x]]+=T.siz[m0]-T.siz[x]; ans[down[y]]+=T.siz[z]-T.siz[m0]; } } } for (ri i=1;i<=cnt;i++) printf("%d ",ans[tmp[i]]); puts(""); for (ri i=1;i<=cc;i++) if (vt2[i]>0) { in[vt2[i]]=0; imp[vt2[i]]=0; dis[vt2[i]]=INF; ans[vt2[i]]=0; ev[vt2[i]]=0; pv[vt2[i]]=0; ffa[vt2[i]]=0; ans[vt2[i]]=0; son[vt2[i]].clear(); } } return 0; }