Description
国家有一个大工程,要给一个非常大的交通网络里建一些新的通道。
我们这个国家位置非常特殊,可以看成是一个单位边权的树,城市位于顶点上。
在 2 个国家 a,b 之间建一条新通道需要的代价为树上 a,b 的最短路径。
现在国家有很多个计划,每个计划都是这样,我们选中了 k 个点,然后在它们两两之间 新建 C(k,2)条 新通道。
现在对于每个计划,我们想知道:
1.这些新通道的代价和
2.这些新通道中代价最小的是多少
3.这些新通道中代价最大的是多少
Input
第一行 n 表示点数。
接下来 n-1 行,每行两个数 a,b 表示 a 和 b 之间有一条边。
点从 1 开始标号。 接下来一行 q 表示计划数。
对每个计划有 2 行,第一行 k 表示这个计划选中了几个点。
第二行用空格隔开的 k 个互不相同的数表示选了哪 k 个点。
Output
输出 q 行,每行三个数分别表示代价和,最小代价,最大代价。
题解:建出来虚树后就不是很难了
#include<bits/stdc++.h> #define setIO(s) freopen(s".in","r",stdin), freopen(s".out","w",stdout) #define maxn 2000001 #define inf 1000000000 #define ll long long using namespace std; vector<int>G[maxn]; int edges,tim,root,top; int hd[maxn], to[maxn<<1], val[maxn<<1], nex[maxn<<1]; int dep[maxn],Top[maxn],hson[maxn],siz[maxn],dfn[maxn],fa[maxn],arr[maxn],S[maxn],mk[maxn]; inline void add(int u,int v) { nex[++edges]=hd[u],hd[u]=edges,to[edges]=v,val[edges]=1; } void dfs1(int u,int ff) { fa[u]=ff,siz[u]=1,dep[u]=dep[ff]+1,dfn[u]=++tim; for(int i=hd[u];i;i=nex[i]) { int v=to[i]; if(v==ff) continue; dfs1(v,u); siz[u]+=siz[v]; if(siz[v]>siz[hson[u]]) hson[u]=v; } } void dfs2(int u,int tp) { Top[u]=tp; if(hson[u]) dfs2(hson[u],tp); for(int i=hd[u];i;i=nex[i]) { int v=to[i]; if(v==fa[u]||v==hson[u]) continue; dfs2(v,v); } } inline int LCA(int x,int y) { while(Top[x]!=Top[y]) { dep[Top[x]] > dep[Top[y]] ? x = fa[Top[x]] : y = fa[Top[y]]; } return dep[x] < dep[y] ? x : y; } inline int getdis(int x,int y) { return dep[x] + dep[y] - (dep[LCA(x,y)] << 1); } inline void addvir(int u,int v) { G[u].push_back(v); } void insert(int x) { if(top<=1) { S[++top]=x; return; } int lca=LCA(x, S[top]); if(lca == S[top]) { S[++top] = x; return; } while(top > 1 && dep[S[top - 1]] >= dep[lca]) addvir(S[top - 1], S[top]), --top; if(lca != S[top]) addvir(lca, S[top]), S[top] = lca; S[++top] = x; } bool cmp(int i,int j) { return dfn[i] < dfn[j]; } ll ans=0,a1,a2; int size[maxn],d1[maxn],d2[maxn],dmin1[maxn],k,dmin2[maxn]; void DP(int x) { size[x]=mk[x]; d1[x]=d2[x]=0; if(!mk[x]) d1[x]=d2[x]=-inf; dmin1[x]=dmin2[x]=inf; if(mk[x]) dmin1[x]=0; for(int i=0;i<G[x].size();++i) { int v = G[x][i],w = dep[G[x][i]] - dep[x]; DP(v); if(mk[v]) { if(w <= dmin1[x]) dmin2[x]=dmin1[x], dmin1[x]=w; else if(w < dmin2[x]) dmin2[x]=w; } else { if(w + dmin1[v] <= dmin1[x]) dmin2[x]=dmin1[x], dmin1[x]=w + dmin1[v]; else if(w + dmin1[v] < dmin2[x]) dmin2[x] = w + dmin1[v]; } int curd=w+d1[v]; if(curd >= d1[x]) { d2[x]=d1[x], d1[x]=curd; } else if(curd > d2[x]) { d2[x] = curd; } ans+=1ll*size[v]*w*(k-size[v]),size[x]+=size[v]; } a1=max(a1, 1ll*(d1[x] + d2[x])); a2=min(a2, 1ll*(dmin1[x] + dmin2[x])); } void init(int x) { d1[x]=d2[x]=0; dmin1[x]=dmin2[x]=inf; for(int i=0;i<G[x].size();++i) init(G[x][i]); G[x].clear(); } inline void work() { scanf("%d",&k); for(int i=1;i<=k;++i) scanf("%d",&arr[i]); for(int i=1;i<=k;++i) mk[arr[i]] = 1; sort(arr+1,arr+1+k,cmp); top=S[0]=root=ans=0; if(arr[1]!=1) S[top=1]=1; for(int i=1;i<=k;++i) insert(arr[i]); while(top > 1) addvir(S[top-1], S[top]),--top; a1=-inf, a2=inf, DP(1); printf("%lld %lld %lld ",ans,a2,a1); init(1); for(int i=1;i<=k;++i) mk[arr[i]]=0; } int main() { // setIO("input"); int n; scanf("%d",&n); for(int i=1;i<n;++i) { int a,b; scanf("%d%d",&a,&b); add(a,b), add(b,a); } dfs1(1,0),dfs2(1,1); int Q; scanf("%d",&Q); for(int i=1;i<=Q;++i) work(); return 0; }