虚树学习笔记
以消耗战为例
显然可以树形dp, 但时间复杂度爆炸
观察发现(sum k)的值不是很大,假设只有两个点x, y,它们的公共祖先lca, 树形dp就像分别枚举割断它们到lca的每一条边,事实上我们一下子(ans = min(mn[lca], mn[x] + mn[y]))就可以算出来,这是因为他们之间有大量的无用的点
所以建一棵虚树来保留对答案可能有影响的关键点,询问点和一些lca
具体来说就是类似维护极右链似的, 每次把lca搞到栈里,还是看代码吧
还需要多做些题理解一下
#include<iostream>
#include<cstdlib>
#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
#define ll long long
using namespace std;
const int N = 1005000;
template <typename T>
void read(T &x) {
x = 0; bool f = 0;
char c = getchar();
for (;!isdigit(c);c=getchar()) if (c=='-') f=1;
for (;isdigit(c);c=getchar()) x=(x<<3)+(x<<1)+(c^48);
if (f) x=-x;
}
int n, m;
int h[N], ne[N], to[N];
int w[N], dep[N], tot;
inline void add(int x, int y, int z) {
ne[++tot] = h[x], to[tot] = y;
w[tot] = z, h[x] = tot;
}
int siz[N], f[N], son[N];
int Top[N], a[N], s[N], top;
ll mn[N];
void dfs1(int x, int fa) {
siz[x] = 1, f[x] = fa, dep[x] = dep[fa] + 1;
for (int i = h[x]; i; i = ne[i]) {
int y = to[i]; if (y == fa) continue;
mn[y] = min(mn[x], (ll)w[i]);
dfs1(y, x), siz[x] += siz[y];
if (siz[y] > siz[son[x]]) son[x] = y;
}
}
int dfn[N], num;
void dfs2(int x, int topf) {
Top[x] = topf, dfn[x] = ++num;
if (!son[x]) return;
dfs2(son[x], topf);
for (int i = h[x]; i; i = ne[i])
if (!dfn[to[i]]) dfs2(to[i], to[i]);
}
int Lca(int x, int y) {
while (Top[x] != Top[y]) {
if (dep[Top[x]] < dep[Top[y]]) swap(x, y);
x = f[Top[x]];
}
return dep[x] < dep[y] ? x : y;
}
bool cmp(int a, int b) {
return dfn[a] < dfn[b];
}
vector <int> v[N];
inline void add_e(int x,int y) {
v[x].push_back(y);
}
void ins(int x) {
if (top == 1) return (void)(s[++top] = x);
int lca = Lca(x, s[top]);
if (lca == s[top]) return;
while (top > 1 && dfn[s[top-1]] >= dfn[lca]) add_e(s[top-1], s[top]), top--;
if (lca != s[top]) add_e(lca, s[top]), s[top] = lca;
s[++top] = x;
}
ll dp(int x) {
if (!v[x].size()) return mn[x];
ll sum = 0;
for (int i = 0;i < v[x].size(); i++) sum += dp(v[x][i]);
v[x].clear(); return min(sum, mn[x]);
}
int main() {
read(n);
for (int i = 1;i < n; i++) {
int x, y, z; read(x), read(y), read(z);
add(x, y, z); add(y, x, z);
}
mn[1] = 1ll << 50, dfs1(1, 0), dfs2(1, 1);
read(m);
while (m--) {
int k; read(k);
for (int i = 1;i <= k; i++) read(a[i]);
sort(a + 1, a + k + 1, cmp);
s[top = 1] = 1;
for (int i = 1;i <= k; i++) ins(a[i]);
while (top > 0) add_e(s[top - 1], s[top]), top--;
printf ("%lld
", dp(1));
}
return 0;
}