虚树学习笔记
[SDOI2011]消耗战
题意
给一棵(n)个点,带边权的树。
给(m)组询问,每组有(k_i)个关键点,你需要切断一些边,使得每个点都到不了根节点,求最小代价。
(n <= 2.5 cdot 10^5, m <= 5 cdot 10^5,sum k_i <= 5 cdot 10^5)
Solve 1
对于每组询问,做一个(dp),设(f[x])表示切断(x)和他的子树所需最小代价,转移分两种
- (x)是关键点,答案为(x)到根路径最小值
- (x)不是关键点,答案为切断所有儿子的值和第(1)种取(min)
复杂度(O(nm))
(can we do better?)
Solve 2
要用到所讲的虚树。
我们发现转移过程中,对于转移有贡献的只有关键点以及他们之间的祖先,于是我们可以简化树的结构。
把关键点按(dfs)序排序,相邻两个求出(lca)并建边。最后在虚树上做(dp),复杂度(O(n log n + sum k_i log n))
具体实现用一个栈维护一条树链,排序后一次加入点。
设当前加入的点(u)
- 如果(top <=1) ,(stk[++top] = u)
- 设(l = lca(u,stk[top])),如果(l == stk[top]),那么(u)应该接在(stk[top])底下,(stk[++top] = u)
- 否则说明(u)已经是一个新的子树,持续弹栈直到(dfn[stk[top-1]] < dfn[l] <= dfn[stk[top]]),如果(l != stk[top]),把(stk[top])接在(l)后面,(stk[top] = l),最后(stk[++top] = u)
void insert(int u){
if(top <= 1) return stk[++top] = u,void();
int l = lca(u,stk[top]);
if(l == stk[top]) return stk[++top] = u,void();
while(top > 1 && dfn[l] <= dfn[stk[top-1]]){
add(stk[top-1],stk[top]); top--;
}
if(l != stk[top]) add(l,stk[top]),stk[top] = l;
stk[++top] = u;
return ;
}
Code
#include<bits/stdc++.h>
#define int long long
#define N 1000015
#define rep(i,a,n) for (int i=a;i<=n;i++)
#define per(i,a,n) for (int i=n;i>=a;i--)
#define inf 0x3f3f3f3f3f3f3f3f
#define pb push_back
#define mp make_pair
#define pii pair<int,int>
#define fi first
#define se second
#define lowbit(i) ((i)&(-i))
#define VI vector<int>
#define all(x) x.begin(),x.end()
using namespace std;
int n,m,a[N],Min[N],k,dfn[N],clk;
vector<pii> e[N];
VI g[N];
void dfs(int u,int fa){
dfn[u] = ++clk;
for(auto I:e[u]){
int v = I.fi,w = I.se;
if(v == fa) continue;
Min[v] = min(Min[u],w);
dfs(v,u);
}
}
bool cmp(int u,int v){
return dfn[u] < dfn[v];
}
namespace LCA{
int fa[N][24],dep[N];
void Dfs(int u,int f){
fa[u][0] = f; dep[u] = dep[f]+1;
for(auto I:e[u]){
int v = I.fi;
if(v == f) continue;
Dfs(v,u);
}
}
void init(){
rep(j,1,21){
rep(i,1,n){
fa[i][j] = fa[fa[i][j-1]][j-1];
}
}
}
int lca(int u,int v){
if(dep[u] < dep[v]) swap(u,v);
int t = dep[u] - dep[v];
per(i,0,21){
if((1<<i)&t) u = fa[u][i];
}
if(u == v) return u;
per(i,0,21){
if(fa[u][i] != fa[v][i]) u = fa[u][i],v = fa[v][i];
}
return fa[u][0];
}
}
using namespace LCA;
int stk[N],top;
void add(int u,int v){
//printf("%lld -> %lld
",u,v);
g[u].pb(v);
}
void insert(int u){
if(top <= 1) return stk[++top] = u,void();
int l = lca(u,stk[top]);
if(l == stk[top]) return stk[++top] = u,void();
while(top > 1 && dfn[l] <= dfn[stk[top-1]]){
add(stk[top-1],stk[top]); top--;
}
if(l != stk[top]) add(l,stk[top]),stk[top] = l;
stk[++top] = u;
return ;
}
void build(){
top = 0;
stk[++top] = 1;
rep(i,1,k) insert(a[i]);
while(top > 1) add(stk[top-1],stk[top]),top--;
}
bool gkp[N];
int dp(int u){
int res = 0;
if(g[u].size() == 0){
//printf("u: %lld val: %lld
",u,Min[u]);
return Min[u];
}
for(auto v:g[u]){
res += dp(v);
}
g[u].clear();
if(!gkp[u]) return min(res,Min[u]);
//printf("u: %lld val: %lld
",u,res);
return Min[u];
}
signed main(){
//freopen(".in","r",stdin);
//freopen(".out","w",stdout);
scanf("%lld",&n);
memset(Min,0x3f,sizeof Min);
rep(i,2,n){
int u,v,w; scanf("%lld%lld%lld",&u,&v,&w);
e[u].pb(mp(v,w)); e[v].pb(mp(u,w));
}
dfs(1,0);
// rep(i,1,n) printf("%lld ", Min[i]);
// printf("
");
Dfs(1,0); init();
// rep(i,1,n){
// rep(j,i+1,n){
// printf("(i,j): (%lld,%lld) lca: %lld
",i,j,lca(i,j));
// }
// }
scanf("%lld",&m);
rep(_,1,m){
scanf("%lld",&k); rep(i,1,k) scanf("%lld",&a[i]),gkp[a[i]] = 1;
sort(a+1,a+k+1,cmp); //puts("sort finished");
build(); //puts("build finished");
printf("%lld
",dp(1));
rep(i,1,k) gkp[a[i]] = 0;
}
return 0;
}