虚树的模板题:
虚树的思想是只保留有用的点(在这道题目里面显然是标记点和lca),然后重新构建一棵树,从而使节点大大减少,优化复杂度
我们维护一条链(以1号点为根),这条链左边的所有在虚树上的位置都已经处理完毕;而这条链右边的和下面的都未处理;
这条链我们用栈来维护;
对于要新加的询问点now,对于虚树的影响有四种情况:(lc表示x与st[top]的LCA)
1.lc==st[top] : 在虚树上连接st[top]与now
.
2.lc在st[top]与st[top-1]之间;在虚树上连接lc与st[top],--top,然后now进栈;
3.lc==st[top-1] :--top,然后now进栈;
4.lc在st[top-1]之上 :在虚树上连接st[top-1]与st[top],然后退栈,重复以上步骤知道出现情况1、2、3;
在最后,我们把这条链加入到虚树中,这样一颗完美的虚树就建成了;
在每次建立虚树的时候,我们要实时清空虚树,否则时间复杂度会退化成O(n^2);
然后在这个虚树上跑dp,就可以了;
#include <bits/stdc++.h> #define int long long #define inc(i,a,b) for(register int i=a;i<=b;i++) #define dec(i,a,b) for(register int i=a;i>=b;i--) using namespace std; int head1[500010],cnt1,head2[500010],cnt2; class littlestar{ public: int to,nxt; long long w; void add1(int u,int v,long long gg){ to=v; nxt=head1[u]; head1[u]=cnt1; w=gg; } void add2(int u,int v){ to=v; nxt=head2[u]; head2[u]=cnt2; } }star1[500010*2],star2[500010*2]; int n,f[500010][25],dfn[500010],cur,dep[500010]; long long minv[500010]; void dfs(int u) { inc(i,0,19){ f[u][i+1]=f[f[u][i]][i]; } dfn[u]=++cur; for(int i=head1[u];i;i=star1[i].nxt){ int v=star1[i].to; if(!dfn[v]){ f[v][0]=u; dep[v]=dep[u]+1; minv[v]=min(minv[u],star1[i].w); dfs(v); } } } int lca(int x,int y) { if(dep[x]<dep[y]) swap(x,y); dec(i,20,0){ if(dep[f[x][i]]>=dep[y]) x=f[x][i]; } if(x==y) return x; dec(i,20,0){ if(f[x][i]!=f[y][i]){ x=f[x][i]; y=f[y][i]; } } return f[x][0]; } bool cmp(int x,int y){return dfn[x]<dfn[y];} int num; int judge[500001],h[500010]; int st[500010],top; int dp(int u) { long long summ=0; for(int i=head2[u];i;i=star2[i].nxt){ int v=star2[i].to; summ+=dp(v); } long long ans=0; if(judge[u]) ans=minv[u]; else ans=min(minv[u],summ); judge[u]=0; head2[u]=0; return ans; } signed main() { cin>>n; inc(i,1,n-1){ int u,v; long long w; scanf("%lld%lld%lld",&u,&v,&w); star1[++cnt1].add1(u,v,w); star1[++cnt1].add1(v,u,w); } minv[1]=1e17+21; dfs(1); int q; scanf("%lld",&q); while(q--){ scanf("%lld",&num); inc(i,1,num){ scanf("%lld",&h[i]); judge[h[i]]=1; } sort(h+1,h+1+num,cmp); top=1; st[top]=h[1]; inc(i,2,num){ int LCA=lca(h[i],st[top]); while(true){ if(dep[LCA]>=dep[st[top-1]]){ if(LCA!=st[top]){ star2[++cnt2].add2(LCA,st[top]); if(LCA!=st[top-1]){ st[top]=LCA; } else{ --top; } } else{ break; } } else{ star2[++cnt2].add2(st[top-1],st[top]); --top; } } st[++top]=h[i]; } while(--top){ star2[++cnt2].add2(st[top],st[top+1]); } cout<<dp(st[1])<<endl; cnt2=0; } }