题意:
分析:
虚树板子题
首先有一个 (O(qn)) 的暴力,就是对于每一次询问, (O(n)) 的树上 DP ,我们统计一下每一个点,它的子树内离它最近/远的关键点的距离,已经关键点的个数
对于第一个询问等价于 (sum dep(x)+dep(y)-sum2 imes dep(lca))
我们 (dp) 的时候顺便统计一下每一个点作为 (lca) 出现了多少次,这个直接扫一下儿子就能得到
第二个询问按照我们 (dp) 数组记下的状态枚举一下两个子树就可以得到
我们发现这种 树上(DP) 多次询问每次给定点集(点集总和与 (n) 同阶) 的问题直接建出虚树这样每次 (dp) 的复杂度降低到和点数同阶,总的复杂度不超过 (O(nlog))
代码:
#include<bits/stdc++.h>
using namespace std;
namespace zzc
{
int read()
{
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch)){x=x*10+ch-48;ch=getchar();}
return x*f;
}
const int maxn = 1e6+5;
const int inf = 0x3f3f3f3f;
int n,idx,num,top,qt,ans2,ans3,mx[maxn],mn[maxn];
long long ans1;
int dfn[maxn],fa[maxn][22],st[maxn],p[maxn],dep[maxn],sum[maxn];
bool vis[maxn];
struct tree
{
int cnt,head[maxn];
struct edge
{
int to,nxt;
}e[maxn<<1];
void add(int u,int v)
{
e[++cnt].to=v;
e[cnt].nxt=head[u];
head[u]=cnt;
e[++cnt].to=u;
e[cnt].nxt=head[v];
head[v]=cnt;
}
}t1,t2;
bool cmp(int x,int y)
{
return dfn[x]<dfn[y];
}
void dfs1(int u,int ff)
{
dfn[u]=++idx;fa[u][0]=ff;dep[u]=dep[ff]+1;
for(int i=t1.head[u];i;i=t1.e[i].nxt)
{
int v=t1.e[i].to;
if(v!=ff) dfs1(v,u);
}
}
inline int lca(int x,int y)
{
if(dep[x]<dep[y]) swap(x,y);
for(int i=21;i>=0;i--) if(dep[fa[x][i]]>=dep[y]) x=fa[x][i];
if(x==y) return x;
for(int i=21;i>=0;i--)
{
if(fa[x][i]!=fa[y][i])
{
x=fa[x][i];
y=fa[y][i];
}
}
return fa[x][0];
}
inline void build()
{
sort(p+1,p+num+1,cmp);t2.cnt=0;
st[top=1]=1;t2.head[1]=0;
for(int i=1;i<=num;i++)
{
if(p[i]!=1)
{
int x=lca(p[i],st[top]);
if(x!=st[top])
{
while(top>1&&dfn[x]<dfn[st[top-1]]) t2.add(st[top-1],st[top]),top--;
if(top>1&&dfn[x]!=dfn[st[top-1]])
{
t2.head[x]=0;
t2.add(x,st[top]);
st[top]=x;
}
else t2.add(x,st[top--]);
}
st[++top]=p[i];t2.head[p[i]]=0;
}
}
while(top>1) t2.add(st[top-1],st[top]),top--;
}
void dfs2(int u,int ff)
{
sum[u]=0;mx[u]=-inf;mn[u]=inf;
if(vis[u]) mn[u]=0,mx[u]=0,sum[u]++;
for(int i=t2.head[u];i;i=t2.e[i].nxt)
{
int v=t2.e[i].to;
if(v!=ff)
{
dfs2(v,u);
ans1-=1ll*sum[u]*sum[v]*2*dep[u];
ans2=min(ans2,mn[u]+mn[v]-dep[u]+dep[v]);
ans3=max(ans3,mx[u]+mx[v]-dep[u]+dep[v]);
mx[u]=max(mx[u],mx[v]-dep[u]+dep[v]);
mn[u]=min(mn[u],mn[v]-dep[u]+dep[v]);
sum[u]+=sum[v];
}
}
}
inline void solve()
{
for(int i=1;i<=num;i++) ans1+=1ll*(num-1)*dep[p[i]];
dfs2(1,0);
printf("%lld %d %d
",ans1,ans2,ans3);
}
void work()
{
int a,b;
n=read();
for(int i=1;i<n;i++)
{
a=read();b=read();
t1.add(a,b);
}
dep[0]=-1;dfs1(1,0);
for(int j=1;j<=21;j++)
{
for(int i=1;i<=n;i++)
{
fa[i][j]=fa[fa[i][j-1]][j-1];
}
}
qt=read();
while(qt--)
{
ans1=0;ans2=inf;ans3=-inf;
num=read();
for(int i=1;i<=num;i++) p[i]=read(),vis[p[i]]=true;
build();
solve();
for(int i=1;i<=num;i++) vis[p[i]]=false;
}
}
}
int main()
{
zzc::work();
return 0;
}