给定一棵 (n) 个点的以 (1) 为根的有根树,在 (m) 条给出的带权祖-孙路径中选出若干条使之覆盖所有边,最小化权值之和。
(n, m leq 3 imes 10^5)。
线段树合并
看到题先想到了 [NOI2020]命运,就想了类似的做法。
(O(n^2)) 的 dp 很好想:设 (f[x][d]) 表示覆盖了 (x) 的子树且选择的路径中端点深度最小值为 (d) 的最小权值之和,这个是个 min 卷积,像NOI题那样拿线段树合并优化即可。
(不过写法不对会MLE,为此贡献了一页的提交记录)
#include <bits/stdc++.h>
#define perr(a...) fprintf(stderr, a)
#define dbg(a...) 42 //perr(" 33[32;1m"), perr(a), perr(" 33[0m")
template <class T, class U>
inline bool smin(T &x, const U &y) {
return y < x ? x = y, 1 : 0;
}
template <class T, class U>
inline bool smax(T &x, const U &y) {
return x < y ? x = y, 1 : 0;
}
using LL = long long;
using PII = std::pair<int, int>;
constexpr int N(3e5 + 5);
int n, m;
std::vector<int> g[N];
std::vector<PII> p[N];
struct Node {
Node *ls, *rs;
LL min, tag;
Node() : ls(nullptr), rs(nullptr), min(LLONG_MAX), tag(0) {}
void add(LL x) { min += x, tag += x; }
void pushup() {
min = LLONG_MAX;
if (ls) smin(min, ls->min);
if (rs) smin(min, rs->min);
}
void pushdown() {
if (tag) {
if (ls) ls->add(tag);
if (rs) rs->add(tag);
tag = 0;
}
}
} *root[N];
void ins(Node *&o, int l, int r, int x, LL y) {
if (!o) o = new Node;
smin(o->min, y);
if (l == r) return;
o->pushdown();
int m = l + r >> 1;
x <= m ? ins(o->ls, l, m, x, y) : ins(o->rs, m + 1, r, x, y);
}
void trash(Node *o) {
if (!o) return;
trash(o->ls), trash(o->rs);
delete o;
}
void del(Node *&o, int l, int r, int x, int y) {
if (!o || x > r || y < l) return;
if (x <= l && r <= y) {
trash(o);
o = nullptr;
return;
}
o->pushdown();
int m = l + r >> 1;
del(o->ls, l, m, x, y);
del(o->rs, m + 1, r, x, y);
if (o->ls || o->rs) {
o->pushup();
} else {
delete o;
o = nullptr;
}
}
Node *merge(Node *x, Node *y, int l, int r, LL xr = LLONG_MAX, LL yr = LLONG_MAX) {
if (!x) {
if (!y) return nullptr;
if (xr == LLONG_MAX) return trash(y), nullptr;
y->add(xr);
return y;
}
if (!y) {
if (yr == LLONG_MAX) return trash(x), nullptr;
x->add(yr);
return x;
}
if (l == r) {
assert(l == r);
x->min += std::min(y->min, yr);
if (xr < LLONG_MAX) smin(x->min, y->min + xr);
delete y;
return x;
}
int m = l + r >> 1;
x->pushdown(), y->pushdown();
LL nxr = xr, nyr = yr;
if (x->rs) smin(nxr, x->rs->min);
if (y->rs) smin(nyr, y->rs->min);
x->ls = merge(x->ls, y->ls, l, m, nxr, nyr);
x->rs = merge(x->rs, y->rs, m + 1, r, xr, yr);
x->pushup();
delete y;
return x;
}
LL ask(Node *o, int l, int r, int x, int y) {
if (!o || x > r || y < l) return LLONG_MAX;
if (x <= l && r <= y) return o->min;
int m = l + r >> 1;
o->pushdown();
return std::min(ask(o->ls, l, m, x, y), ask(o->rs, m + 1, r, x, y));
}
int dep[N], max_dep;
void dfs0(int x, int fa) {
dep[x] = dep[fa] + 1;
smax(max_dep, dep[x]);
for (int y : g[x]) {
if (y == fa) continue;
dfs0(y, x);
}
}
void dfs(int x, int fa) {
if (x > 1 && g[x].size() == 1) {
for (auto &v : p[x]) {
ins(root[x], 1, max_dep, dep[v.first], v.second);
}
if (!root[x]) {
std::cout << "-1
";
exit(0);
}
return;
}
for (int y : g[x]) {
if (y == fa) continue;
dfs(y, x);
root[x] = root[x] ? merge(root[x], root[y], 1, max_dep) : root[y];
}
for (auto &v : p[x]) {
LL s = ask(root[x], 1, max_dep, dep[v.first], dep[x]);
if (s < LLONG_MAX) ins(root[x], 1, max_dep, dep[v.first], s + v.second);
}
del(root[x], 1, max_dep, dep[x] + (x == 1), max_dep);
if (!root[x]) {
std::cout << "-1
";
exit(0);
}
}
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cin >> n >> m;
if (n == 1) return puts("0"), 0;
for (int i = 1, x, y; i < n; i++) {
std::cin >> x >> y;
g[x].push_back(y), g[y].push_back(x);
}
while (m--) {
int x, y, z;
std::cin >> x >> y >> z;
if (x == y) continue;
p[x].emplace_back(y, z);
}
dfs0(1, 0);
dfs(1, 0);
std::cout << root[1]->min << "
";
return 0;
}
可并堆
仔细考虑一下发现我们并不需要知道覆盖完子树后向上伸出来具体多少。
用小根堆存一下覆盖 (x) 子树的权值和,合并两个堆分别整体加上另一个堆的最小值就行了。
加入一个路径的时候也只要取最小的相加。
#include <bits/stdc++.h>
template <class T, class U>
inline bool smin(T &x, const U &y) {
return y < x ? x = y, 1 : 0;
}
template <class T, class U>
inline bool smax(T &x, const U &y) {
return x < y ? x = y, 1 : 0;
}
using LL = long long;
using PII = std::pair<int, int>;
constexpr int N(3e5 + 5);
int n;
std::vector<int> g[N];
std::vector<PII> p[N];
struct Node {
Node *ls, *rs;
LL val, tag;
int dep;
Node(LL v, int d) : ls(nullptr), rs(nullptr), val(v), tag(0), dep(d) {}
void add(LL x) {
val += x;
tag += x;
}
void pushdown() {
if (tag) {
if (ls) ls->add(tag);
if (rs) rs->add(tag);
tag = 0;
}
}
} *root[N];
Node *merge(Node *x, Node *y) {
if (!x) return y;
if (!y) return x;
static std::mt19937 rnd(std::chrono::high_resolution_clock::now().time_since_epoch().count());
if (x->val > y->val) std::swap(x, y);
x->pushdown(), y->pushdown();
if (rnd() & 1) {
x->rs = merge(x->rs, y);
} else {
x->ls = merge(x->ls, y);
}
return x;
}
int dep[N];
void update(int x) {
while (root[x] && root[x]->dep >= dep[x] + (x == 1)) {
root[x]->pushdown();
root[x] = merge(root[x]->ls, root[x]->rs);
}
if (!root[x]) {
puts("-1");
exit(0);
}
}
void dfs(int x, int fa) {
dep[x] = dep[fa] + 1;
if (x > 1 && g[x].size() == 1) {
for (auto [f, v] : p[x]) {
root[x] = merge(root[x], new Node(v, dep[f]));
}
update(x);
return;
}
for (int y : g[x]) {
if (y == fa) continue;
dfs(y, x);
if (root[x]) {
LL v = root[x]->val;
root[x]->add(root[y]->val);
root[y]->add(v);
}
root[x] = merge(root[x], root[y]);
}
for (auto [f, v] : p[x]) {
root[x] = merge(root[x], new Node(v + root[x]->val, dep[f]));
}
update(x);
}
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
int m;
std::cin >> n >> m;
if (n == 1) return puts("0"), 0;
for (int i = 1, x, y; i < n; i++) {
std::cin >> x >> y;
g[x].push_back(y), g[y].push_back(x);
}
while (m--) {
int x, y, z;
std::cin >> x >> y >> z;
if (x == y) continue;
p[x].emplace_back(y, z);
}
dfs(1, 0);
std::cout << root[1]->val << "
";
return 0;
}