sol
很显然的虚树DP呀。
树上任意两点距离之和?其实只要考虑每一条边被计算了多少次即可,若这条边下方的关键点(也就是选出的那些点)数量为(i),那么这条边的计算次数就是(i*(k-i))。
然后最大最小值,直接对每个点记子树中所有关键点到它的最长/最短距离即可。注意初值与这个点是不是关键点有关。
我一开始很傻很天真地以为最短距离一定是dfs序相邻的两个点然后就。。。
总体上还是挺好写的。
code
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
#define ll long long
int gi()
{
int x=0,w=1;char ch=getchar();
while ((ch<'0'||ch>'9')&&ch!='-') ch=getchar();
if (ch=='-') w=0,ch=getchar();
while (ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
return w?x:-x;
}
const int N = 4e6+5;
int n,Q,to[N],nxt[N],val[N],head[N],cnt;
int fa[N],dep[N],sz[N],son[N],top[N],dfn[N],low[N];
int k,len,tp,s[N],q[N],mark[N],f[N],g[N];
int Max,Min;ll Sum;
void link(int u,int v,int w){to[++cnt]=v;nxt[cnt]=head[u];val[cnt]=w;head[u]=cnt;}
void dfs1(int u,int f)
{
fa[u]=f;dep[u]=dep[f]+1;sz[u]=1;
for (int e=head[u];e;e=nxt[e])
{
int v=to[e];if (v==f) continue;
dfs1(v,u);
sz[u]+=sz[v];if (sz[v]>sz[son[u]]) son[u]=v;
}
}
void dfs2(int u,int up)
{
top[u]=up;dfn[u]=++cnt;
if (son[u]) dfs2(son[u],up);
for (int e=head[u];e;e=nxt[e])
if (to[e]!=fa[u]&&to[e]!=son[u])
dfs2(to[e],to[e]);
low[u]=cnt;
}
int getlca(int u,int v)
{
while (top[u]^top[v])
{
if (dep[top[u]]<dep[top[v]]) swap(u,v);
u=fa[top[u]];
}
return dep[u]<dep[v]?u:v;
}
bool cmp_dfn(int u,int v){return dfn[u]<dfn[v];}
void dp(int u)
{
sz[u]=mark[u]?1:0;
f[u]=mark[u]?0:-1e9;
g[u]=mark[u]?0:1e9;
for (int e=head[u];e;e=nxt[e])
{
int v=to[e];dp(v);
sz[u]+=sz[v];Sum+=1ll*sz[v]*(k-sz[v])*val[e];
Max=max(Max,f[u]+f[v]+val[e]);
f[u]=max(f[u],f[v]+val[e]);
Min=min(Min,g[u]+g[v]+val[e]);
g[u]=min(g[u],g[v]+val[e]);
}
}
int main()
{
n=gi();
for (int i=1;i<n;++i)
{
int u=gi(),v=gi();
link(u,v,0);link(v,u,0);
}
dfs1(1,0);cnt=0;dfs2(1,1);
Q=gi();memset(head,0,sizeof(head));
while (Q--)
{
k=len=gi();tp=cnt=0;
for (int i=1;i<=k;++i) mark[s[i]=gi()]=1;
sort(s+1,s+k+1,cmp_dfn);
for (int i=1;i<k;++i) s[++len]=getlca(s[i],s[i+1]);
sort(s+1,s+len+1,cmp_dfn);len=unique(s+1,s+len+1)-s-1;
for (int i=1;i<=len;++i)
{
while (tp&&low[q[tp]]<dfn[s[i]]) --tp;
link(q[tp],s[i],dep[s[i]]-dep[q[tp]]);
q[++tp]=s[i];
}
Sum=Max=0;Min=1e9;
dp(s[1]);
printf("%lld %d %d
",Sum,Min,Max);
for (int i=1;i<=len;++i) mark[s[i]]=head[s[i]]=0;
}
return 0;
}