又是一道虚树好题啊
我们建出来虚树,然后考虑dp过程,我们分别令(sum[x],mndis[x],mxdis[x],size[x])为子树内的路径长度和,最短链,最长链,子树内关键点个数。
对于一个关键点,首先他的(size=1,mndis=0)
我们考虑怎么合并,首先我们可以直接维护三个值表示最终的答案。如果说当前的点(size[x]>0),那么我们就可以每次用他和新的子树进行更新ans,然后合并
QWQ
其实合并就差不多类似的方式
主要是(sum)合并的时候,你要用(size[x]*size[p]*val[i])当前这条边会被算的贡献,(size[x]*sum[p]),每条(p)的路径,都会到(x)中的所有关键点,(size[p]*sum[x])这个也是同理
其他的应该也就差不多了
直接上代码
// luogu-judger-enable-o2
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<queue>
#include<map>
#include<set>
#define mk makr_pair
#define ll long long
#define int long long
using namespace std;
inline int read()
{
int x=0,f=1;char ch=getchar();
while (!isdigit(ch)) {if (ch=='-') f=-1;ch=getchar();}
while (isdigit(ch)) {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
return x*f;
}
const int maxn = 3e6+1e2;
const int maxm = 4e6+1e2;
const int inf = 1e18;
int point[maxn],nxt[maxm],to[maxm];
int deep[maxn],dfn[maxn];
int sum[maxn];
int n,m,tot,top,cnt;
int f[1010000][21];
int k,size[maxn],tag[maxn];
int s[maxn];
int mxdis[maxn],mndis[maxn];
int mx,mn,ymh;
int a[maxn];
int val[maxn];
void addedge(int x,int y,int w)
{
//cout<<"***"<<" "<<x<<" "<<y<<endl;
nxt[++cnt]=point[x];
to[cnt]=y;
val[cnt]=w;
point[x]=cnt;
}
void dfs(int x,int fa,int dep)
{
deep[x]=dep;
dfn[x]=++tot;
for (int i=point[x];i;i=nxt[i])
{
int p = to[i];
if (p==fa) continue;
f[p][0]=x;
dfs(p,x,dep+1);
}
}
void init()
{
for (int j=1;j<=20;j++)
for (int i=1;i<=n;i++)
f[i][j]=f[f[i][j-1]][j-1];
}
int go_up(int x,int d)
{
for (int i=0;i<=20;i++)
{
if (d&(1<<i))
{
x=f[x][i];
}
}
return x;
}
int lca(int x,int y)
{
if (deep[x]>deep[y]) x=go_up(x,deep[x]-deep[y]);
else y=go_up(y,deep[y]-deep[x]);
if (x==y) return x;
for (int i=20;i>=0;i--)
{
if (f[x][i]!=f[y][i])
{
x=f[x][i];
y=f[y][i];
}
}
return f[x][0];
}
bool cmp(int a,int b)
{
return dfn[a]<dfn[b];
}
void solve()
{
// memset(point,0,sizeof(point));
sort(a+1,a+1+k,cmp);
cnt=0;
top=1;
s[top]=1;
for (int i=1;i<=k;i++)
{
int l = lca(s[top],a[i]);
if (l!=s[top])
{
while (top>1)
{
if (dfn[s[top-1]]>dfn[l])
{
addedge(s[top-1],s[top],deep[s[top]]-deep[s[top-1]]);
top--;
}
else
{
if (dfn[s[top-1]]==dfn[l])
{
addedge(s[top-1],s[top],deep[s[top]]-deep[s[top-1]]);
top--;
break;
}
else
{
addedge(l,s[top],deep[s[top]]-deep[l]);
s[top]=l;
break;
}
}
}
}
if (s[top]!=a[i]) s[++top]=a[i];
}
while (top>1)
{
addedge(s[top-1],s[top],deep[s[top]]-deep[s[top-1]]);
top--;
}
}
void dp(int x,int flag)
{
mndis[x]=inf;
mxdis[x]=0;
size[x]=0;
sum[x]=0;
if (tag[x]==flag)
{
size[x]=1;
mndis[x]=0;
}
for (int &i=point[x];i;i=nxt[i])
{
int p = to[i];
int now = val[i];
dp(p,flag);
if (size[x]>0)
{
ymh=ymh+size[p]*size[x]*now+size[p]*sum[x]+size[x]*sum[p];
mx=max(mx,mxdis[x]+mxdis[p]+now);
mn=min(mn,mndis[x]+mndis[p]+now);
}
mndis[x]=min(mndis[x],mndis[p]+now);
mxdis[x]=max(mxdis[p]+now,mxdis[x]);
sum[x]+=sum[p]+now*size[p];
size[x]+=size[p];
}
//cout<<mndis[x]<<" "<<mxdis[x]<<" "<<size[x]<<" "<<sum[x]<<endl;
}
signed main()
{
n=read();
for (int i=1;i<n;i++)
{
int x=read(),y=read();
addedge(x,y,0);
addedge(y,x,0);
}
dfs(1,0,1);
init();
memset(point,0,sizeof(point));
m=read();
for (int i=1;i<=m;i++)
{
ymh=0;
mn=inf;
mx=0;
k=read();
for (int j=1;j<=k;j++) a[j]=read(),tag[a[j]]=i;
solve();
dp(1,i);
cout<<ymh<<" "<<mn<<" "<<mx<<"
";
}
return 0;
}