// 虚树dp — virtual-tree (auxiliary-tree) DP
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <iostream>
#include <algorithm>

// Virtual-tree (auxiliary-tree) DP.
//
// Input: n, then n-1 weighted edges of a tree rooted at node 1, then m queries.
// Each query gives k key nodes. The program builds the virtual tree over the
// key nodes, the LCAs of dfn-adjacent key nodes, and the root 1. Each virtual
// edge (p, c) gets weight w(p, c) = minimum original edge weight on the tree
// path p..c. Then it evaluates
//     f[x] = (x is a key node ? INF : 0) + sum over virtual children c of min(w(x, c), f[c])
// and prints f[1], one answer per line.

const int maxn = 1000500;
const int LOG = 20;               // 2^LOG > maxn, enough bits for any depth difference
typedef long long ll;
const ll INF = 10000000000000LL;  // 1e13: initial f value of a key node

int n, m, k;

struct T { int to, nxt; ll v; } way[maxn << 1];
int h[maxn], num;

// Adds the undirected edge (x, y) with weight v to the adjacency lists.
inline void adde(int x, int y, ll v) {
    way[++num] = {y, h[x], v}, h[x] = num;
    way[++num] = {x, h[y], v}, h[y] = num;
}

int dep[maxn], siz[maxn], fa[maxn], dfn[maxn], vis[maxn], tot;
// bz[x][i]: the 2^i-th ancestor of x (0 once the jump passes the root).
// mn[x][i]: the minimum edge weight over the 2^i edges directly above x.
int bz[maxn][LOG], mn[maxn][LOG];

// DFS from `root` filling dep, dfn (preorder), siz (subtree size), fa, bz, mn.
// Uses an explicit stack instead of recursion so that a path-shaped tree with
// ~maxn nodes cannot overflow the call stack. Visits children in the same
// adjacency-list order as a recursive DFS would.
void dfs(int root) {
    std::vector<int> path;               // current root-to-node chain
    std::vector<int> cur(h, h + maxn);   // next adjacency edge to examine per node
    auto enter = [&](int x) {
        siz[x] = vis[x] = 1;
        dep[x] = dep[fa[x]] + 1;
        dfn[x] = ++tot;
        for (int i = 1; i < LOG; ++i) {
            bz[x][i] = bz[bz[x][i - 1]][i - 1];
            mn[x][i] = std::min(mn[x][i - 1], mn[bz[x][i - 1]][i - 1]);
        }
        path.push_back(x);
    };
    enter(root);
    while (!path.empty()) {
        const int x = path.back();
        if (const int e = cur[x]) {
            cur[x] = way[e].nxt;
            const int y = way[e].to;
            if (!vis[y]) {                   // skip the parent (still marked)
                bz[y][0] = fa[y] = x;
                mn[y][0] = static_cast<int>(way[e].v);
                enter(y);
            }
        } else {                             // all edges of x examined: leave x
            vis[x] = 0;
            path.pop_back();
            if (!path.empty()) siz[path.back()] += siz[x];
        }
    }
}

// Lowest common ancestor of x and y via binary lifting.
int lca(int x, int y) {
    if (dep[x] > dep[y]) std::swap(x, y);
    for (int d = dep[y] - dep[x]; d; d &= d - 1) y = bz[y][__builtin_ctz(d)];
    if (x == y) return x;
    for (int i = LOG - 1; i >= 0; --i)
        if (bz[x][i] != bz[y][i]) x = bz[x][i], y = bz[y][i];
    return bz[x][0];
}

// x = min(x, y).
inline void down(int& x, int y) { if (x > y) x = y; }

// Minimum edge weight on the tree path between x and y. One of them must be
// an ancestor of the other; otherwise the program aborts with exit code 1.
int get(int x, int y) {
    if (dep[x] > dep[y]) std::swap(x, y);
    int ans = 1000000000;
    for (int d = dep[y] - dep[x]; d; d &= d - 1) {
        const int j = __builtin_ctz(d);
        down(ans, mn[y][j]);
        y = bz[y][j];
    }
    if (x != y) std::exit(1);  // not an ancestor pair: virtual-tree construction bug
    return ans;
}

ll f[maxn];
int st[maxn], tp;

// Reads one query (k, then k key nodes) and prints the DP value f[1].
void solve() {
    const auto byDfn = [](int x, int y) { return dfn[x] < dfn[y]; };
    std::cin >> k;
    std::vector<int> v;
    for (int i = 0, x; i < k; ++i) std::cin >> x, v.push_back(x), f[x] = INF;

    // Virtual-tree node set: keys + LCAs of dfn-adjacent keys + root, in dfn order.
    std::sort(v.begin(), v.end(), byDfn);
    for (int i = 0; i + 1 < k; ++i) v.push_back(lca(v[i], v[i + 1]));
    v.push_back(1);
    std::sort(v.begin(), v.end(), byDfn);
    v.erase(std::unique(v.begin(), v.end()), v.end());

    // Monotonic stack over dfn order: the top that still contains v[j] in its
    // subtree (dfn interval) is v[j]'s virtual parent.
    std::vector<int> par(v.size(), 0);
    std::vector<ll> wt(v.size(), 0);
    tp = 0;
    for (std::size_t j = 0; j < v.size(); ++j) {
        const int x = v[j];
        while (tp && dfn[st[tp]] + siz[st[tp]] - 1 < dfn[x]) --tp;
        if (tp) par[j] = st[tp], wt[j] = get(x, st[tp]);
        st[++tp] = x;
    }

    // Reverse preorder handles every virtual child before its parent, so f[v[j]]
    // is final when it is folded into its parent. There is no recursion, so a
    // deep virtual tree cannot overflow the stack. v[0] is the root 1.
    for (std::size_t j = v.size(); j-- > 1;)
        f[par[j]] += std::min(wt[j], f[v[j]]);

    std::cout << f[1] << '\n';
    for (int x : v) f[x] = 0;  // reset only the touched entries for the next query
}

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cin >> n;
    for (int i = 1, x, y, v; i < n; ++i) std::cin >> x >> y >> v, adde(x, y, v);
    dfs(1);
    std::cin >> m;
    while (m--) solve();
}