题解:
首先建虚树是显然的
然后在上面dp
注意到有的点可能是真实不存在的
所以要认真的搞一下dp
首先分析一波
最长链是任意的
因为假如a->b b->c (a,c真实存在,b是假的)
那么a->c 一定大于b->a or c
所以这个跟求直径一样max_len1 max_len2就可以了
最短链要求是要在两个真实点之间的
所以令min_len[i]表示i的子树中真实存在的点到它的最短距离
(注意到叶子节点一定是)
求路径和就记录子树中到当前点的路径和,子树中真实点的数目就可以了
令max_len1表示
代码:
#include <bits/stdc++.h> using namespace std; #define ll long long #define N 2000100 struct re{ int a,b; ll c; }a[N],a2[N]; int head[N],bz[N][20],bz2[N][20],dep[N],dfn[N],cnt,l,l2,head2[N],n,m,ansmin,ansmax; int max_len1[N],max_len2[N],min_len[N],st[N],b[N]; ll sum[N],sum2[N],num[N]; bool ff[N]; #define INF 1e9 void arr(ll x,ll y) { a[++l].a=head[x]; a[l].b=y; head[x]=l; } void dfs(ll x,ll father) { dfn[x]=++cnt; dep[x]=dep[father]+1; bz[x][0]=father; bz2[x][0]=1; ll u=head[x]; while (u) { ll v=a[u].b; if (v!=father) dfs(v,x); u=a[u].a; } } ll lca(ll x,ll y) { if (dep[x]<dep[y]) swap(x,y); for (ll i=19;i>=0;i--) if (dep[bz[x][i]]>=dep[y]) x=bz[x][i]; if (x==y) return(x); for (ll i=19;i>=0;i--) if (bz[x][i]!=bz[y][i]) { x=bz[x][i]; y=bz[y][i]; } return(bz[x][0]); } ll query(ll x,ll y) { if (dep[x]<dep[y]) swap(x,y); ll ans=0; for (ll i=19;i>=0;i--) if (dep[bz[x][i]]>=dep[y]) ans+=bz2[x][i],x=bz[x][i]; if (x==y) return(ans); for (ll i=19;i>=0;i--) if (bz[x][i]!=bz[y][i]) { ans+=bz2[x][i]+bz2[y][i]; x=bz[x][i]; y=bz[y][i]; } return(ans+bz2[x][0]+bz2[y][0]); } bool cmp(ll x,ll y) { return(dfn[x]<dfn[y]); } ll k; queue<ll> q; void arr2(ll x,ll y) { q.push(x); a2[++l2].a=head2[x]; a2[l2].b=y; if (x==n+1||y==n+1) a2[l2].c=0; else a2[l2].c=query(x,y); head2[x]=l2; } void js(ll x,ll fa) { ll u=head2[x]; if (ff[x]) num[x]=1,min_len[x]=0; while (u) { ll v=a2[u].b; if (v!=fa) { js(v,x); num[x]+=num[v]; if (max_len1[v]+a2[u].c>=max_len1[x]) { max_len2[x]=max_len1[x]; max_len1[x]=max_len1[v]+a2[u].c; } else if (max_len1[v]+a2[u].c>max_len2[x]) max_len2[x]=max_len1[v]+a2[u].c; if (num[v]>0) { ansmin=min(ansmin,min_len[x]+min_len[v]+int(a2[u].c)); min_len[x]=min(min_len[x],min_len[v]+int(a2[u].c)); } } u=a2[u].a; } ansmax=max(ansmax,max_len1[x]+max_len2[x]); u=head2[x]; while (u) { ll v=a2[u].b; if (v!=fa) { sum[x]+=(num[x]-num[v])*num[v]*a2[u].c+sum2[v]*(num[x]-num[v])+sum[v]; sum2[x]+=sum2[v]+a2[u].c*num[v]; } u=a2[u].a; } } void solve() { ll top=0; st[++top]=n+1; while (!q.empty()) { ll x=q.front(); head2[x]=0,ff[x]=0,sum[x]=0,num[x]=0,sum2[x]=0; max_len1[x]=0,max_len2[x]=0,min_len[x]=INF,q.pop(); } l2=0; ansmax=0; ansmin=INF; for (ll i=1;i<=k;i++) { ff[b[i]]=1; ll tmp=lca(b[i],st[top]); while (true) { if (dfn[tmp]>=dfn[st[top-1]]) { if (tmp!=st[top]) arr2(st[top],tmp),arr2(tmp,st[top]); top--; if (tmp!=st[top]) st[++top]=tmp; break; } else { arr2(st[top-1],st[top]); arr2(st[top],st[top-1]); top--; } } if (st[top]!=b[i]) st[++top]=b[i]; } while (top>1) { arr2(st[top-1],st[top]); arr2(st[top],st[top-1]); top--; } js(n+1,0); cout<<sum[n+1]<<" "<<ansmin<<" "<<ansmax<<endl; } int main() { cin>>n; arr(n+1,1); ll x,y; for (ll i=1;i<=n-1;i++) cin>>x>>y,arr(x,y),arr(y,x); dfs(n+1,0); for (ll i=1;i<=19;i++) for (ll j=1;j<=n+1;j++) { bz[j][i]=bz[bz[j][i-1]][i-1]; bz2[j][i]=bz2[j][i-1]+bz2[bz[j][i-1]][i-1]; } ll q; cin>>q; for (ll i=1;i<=n+10;i++) min_len[i]=INF; for (ll i=1;i<=q;i++) { cin>>k; ll len=0; for (ll j=1;j<=k;j++) { cin>>x; b[++len]=x; } sort(b+1,b+1+k,cmp); solve(); } return 0; }