题意
还是一个n个结点的树,每次询问还是选定k个点。规定每个点会给距离它最近的标记点(距离相同,编号最小)贡献1的权值,每次询问即是标记k个点,然后问这k个点的权值。
(N leqslant 300000, q leqslant 300000,k_1+k_2+...+k_q leqslant 300000)
题解
算法还是虚树。
然后有一个显然的性质。
对于每个点x,它的所有儿子的子树,内部有标记点的子树除外,所贡献的点一定与点x相同。
然后我们先只考虑虚树上的结点。
在虚树上,很容易可以算出每个点距离最近的标记点。
那么我们也可以连带把每个点连带的子树的贡献也统计好。
那么还剩下哪些点没统计呢?
虚树边上的点以及它们所连带的子树。
而对于边<u,v>上的点,一定贡献给u或v所贡献的点。
我们画个图考虑一下。
点X和点Y为虚树上的点(非标记点)。距离点X最近的标记点的编号为5,到点X的距离为3,到点Y最近的标记点的编号为2,距离为4。首先虚树边<X,Y>上与点x最近的点a所连带的1号子树,一定是贡献给5号点,同时得到5号点到点a的距离为4,与到2号点到点Y的距离相等。那么这之间的点一定被2号点与5号点平分,上面一半归5号点,下面一般归2号点,但此时存在点c,到2号点和5号点的距离均相等,采取距离相等,编号最小的原则,点c及2号子树,归2号点所有,也就是说点c及以下贡献2号点,点b及以上贡献5号点,到此分配完成。
简单来说,我们对于每个虚树边<u,v>,我们根据两端点贡献点的信息,算出边上的断点,断点以上给点u的贡献点,断点以下给点v的贡献点。具体详见代码。
#include<cmath>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define inf 0x7f7f7f
using namespace std;
const int maxn=3e5;
int n,m,tot,root,Time;
int Lg[maxn+8],a[maxn+8];
int pre[maxn*2+8],now[maxn+8],son[maxn*2+8];
int dep[maxn+8],f[maxn+8][20],siz[maxn+8],dfn[maxn+8];
int st[maxn+8],ans[maxn+8];
int read()
{
int x=0,f=1;char ch=getchar();
for (;ch<'0'||ch>'9';ch=getchar()) if (ch=='-') f=-1;
for (;ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0';
return x*f;
}
bool cmp(int x,int y){return dfn[x]<dfn[y];}
void add(int u,int v)
{
pre[++tot]=now[u];
now[u]=tot;
son[tot]=v;
}
void dfs(int x,int fa)
{
dfn[x]=++Time;
dep[x]=dep[fa]+1;
f[x][0]=fa;
siz[x]=1;
for (int i=1;i<=log(dep[x])/log(2);i++) f[x][i]=f[f[x][i-1]][i-1];
for (int p=now[x];p;p=pre[p])
{
int child=son[p];
if (child==fa) continue;
dfs(child,x);
siz[x]+=siz[child];
}
}
int jump(int x,int d){if (d<0) return 0;for (;d;d-=d&(-d)) x=f[x][Lg[d&(-d)]];return x;}
int Get_Lca(int x,int y)
{
if (dep[x]>dep[y]) swap(x,y);
y=jump(y,dep[y]-dep[x]);
if (x==y) return x;
for (int i=log(dep[x])/log(2);~i&&f[x][0]!=f[y][0];i--)
if (f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
return f[x][0];
}
struct Pnt
{
int x,dis;
};
bool operator <(Pnt a,Pnt b){return a.dis!=b.dis?a.dis<b.dis:a.x<b.x;}
Pnt operator +(Pnt a,int b){return (Pnt){a.x,a.dis+b};}
struct Virtual_Tree
{
int tot,tail;
int st[maxn+8];
int pre[maxn*2+8],now[maxn+8],son[maxn*2+8],val[maxn*2+8];
int color[maxn+8],cnt[maxn+8];
Pnt f[maxn+8];
void clear()
{
tot=0;
while(tail) now[st[tail--]]=0;
for (int i=1;i<=m;i++) color[a[i]]=0,ans[i]=0;
}
void add(int u,int v,int w)
{
if (!now[u]) st[++tail]=u;
pre[++tot]=now[u];
now[u]=tot;
son[tot]=v;
val[tot]=w;
}
void insert(int u,int v)
{
if (dep[u]>dep[v]) swap(u,v);
add(u,v,dep[v]-dep[u]);
add(v,u,dep[v]-dep[u]);
}
void dfs1(int x,int fa)
{
f[x]=(Pnt){x,color[x]?0:inf};
cnt[x]=siz[x];
for (int p=now[x];p;p=pre[p])
{
int child=son[p];
if (child==fa) continue;
dfs1(child,x);
f[x]=min(f[x],f[child]+val[p]);
cnt[x]-=siz[jump(child,val[p]-1)];
}
}
void dfs2(int x,int fa)
{
for (int p=now[x];p;p=pre[p])
{
int child=son[p];
if (child==fa) continue;
f[child]=min(f[child],f[x]+val[p]);
dfs2(child,x);
}
for (int p=now[x];p;p=pre[p])
{
int child=son[p];
if (child==fa) continue;
int d=f[x].dis-f[child].dis+val[p]-1,tmp1=siz[jump(child,val[p]-1)]+cnt[child]-siz[child],tmp2=siz[jump(child,d/2+((d>0)&(d&1)&(f[child].x<f[x].x)))]+cnt[child]-siz[child];
tmp2=max(tmp2,0);tmp1-=tmp2;
ans[color[f[x].x]]+=tmp1;
ans[color[f[child].x]]+=tmp2;
}
}
}VT;
void solve()
{
m=read();
for (int i=1;i<=m;i++) VT.color[a[i]=read()]=i;
sort(a+1,a+m+1,cmp);
int tail=1;
st[tail]=root;
for (int i=1;i<=m;i++)
{
int Lca=Get_Lca(st[tail],a[i]),lst=0;
while(dep[Lca]<dep[st[tail]])
{
if (lst) VT.insert(lst,st[tail]);
lst=st[tail--];
}
if (lst) VT.insert(lst,Lca);
if (dep[Lca]!=dep[st[tail]]) st[++tail]=Lca;
st[++tail]=a[i];
}
while(tail!=1) VT.insert(st[tail],st[tail-1]),tail--;
VT.dfs1(root,0);
VT.dfs2(root,0);
for (int i=1;i<=m;i++) printf("%d ",ans[i]);puts("");
VT.clear();
}
int main()
{
n=read();
for (int i=0;i<=20;i++) Lg[1<<i]=i;
for (int i=1;i<n;i++)
{
int u=read(),v=read();
add(u,v),add(v,u);
}
root=n+1;
add(root,1),add(1,root);
dfs(root,0);
int Q=read();
while(Q--) solve();
return 0;
}