Description
给定一颗 (n) 个结点的树,每个点有一个点权 (v)。点权只可能为 (0) 或 (1)。
现有一个空数列,每次可以向数列尾部添加一个点 (i) 的点权 (v_i),但必须保证此时 (i) 没有父结点。添加后将 (i) 删除。
这样可以一个长为 (n) 的数列 (x)。求 (x) 中逆序对数的最小值。
Hint
- (1le nle 2 imes 10^5)
- (v_i in {0, 1})
Solution
由于一个结点的父结点尚未被删除,那么现在该结点则无法被加入数列。可见题目要求我们 从树根自顶向下 删除。
但显然我们不会这样做——我们 将所有结点视作独立,向父亲方向合并。
我们不妨先考虑这样一个问题:对于一个根结点为 (x) 的树,其子结点为 (y_1, y_2, cdots y_k)。假设子树 (y_1, y_2, cdots y_k) 都已经合并好了,那么我们只要将这些子树合并答案,向上传答案即可。
首先,由题意得,结点 (x) 的点权必须排在最前面。接下来就需要合理安排顺序,使得 跨越子树的逆序对 数量最小。由于子树内在前期早已统计完毕,此处无需再做讨论。
为方便讨论,在这里我们还需维护子树中 (0, 1) 的个数,分别记为 ( ext{cnt}(cdots, 0), ext{cnt}(cdots, 1))。
若要使逆序对尽可能小,而权值就只有 (0, 1),第一直觉就是 贪心地把 (0) 尽量排前面。
但直觉是很模糊的,我们需要一个明确的标准。
对于两个子树 (y_i, y_j),如果 (y_i) 排在前面,那么会产生 ( ext{cnt}(y_i, 1) imes ext{cnt}(y_j, 0)) 个逆序对,反正则会产生 ( ext{cnt}(y_j, 1) imes ext{cnt}(y_i, 0)) 个。
显然我们应选择结果较少的策略——优先选取 (dfrac{ ext{cnt}(y, 0)}{ ext{cnt}(y, 1)}) 较小的。为避免除以零造成 RE,需要化除为乘。
但此题不能直接递归处理,需要全局一起算,即上文中“将所有结点视作独立,向父亲方向合并”的思路。
那么子树的 ( ext{cnt}) 值就变成了 连通块 的 ( ext{cnt}) 值,容易发现上面的贪心思路于此仍然有效。
此处涉及连通块整块信息的维护,不难想到 并查集。连通块的有序维护,可以使用 堆。
在每个点向上合并后,父亲方向结点需要删去,这对于堆来说就不太方便(当然可以考虑 multiset 或 可删堆)
但其实不用这么麻烦:直接根据 ( ext{cnt}) 值判断是否已经被合并然后选择性跳过即可。
最后做到 1 号点就不用重新插入堆中了。
Code
/*
* Author : _Wallace_
* Source : https://www.cnblogs.com/-Wallace-/
* Problem : AtCoder AGC023F 01 on Tree
*/
#include <algorithm>
#include <iostream>
#include <queue>
using namespace std;
const int N = 2e5 + 5;
int n, fa[N], dsu[N];
int cnt[N][2];
struct item {
int c0, c1, idx;
bool operator < (const item& t) const {
return c0 * 1ll * t.c1 < c1 * 1ll * t.c0;
}
};
priority_queue<item> pq;
int find(int x) {
return x == dsu[x] ? x : dsu[x] = find(dsu[x]);
}
signed main() {
ios::sync_with_stdio(false);
cin >> n;
for (register int i = 2; i <= n; i++)
cin >> fa[i];
for (register int i = 1, val; i <= n; i++)
cin >> val, cnt[i][val]++;
for (register int i = 1; i <= n; i++)
dsu[i] = i;
long long ans = 0;
for (register int i = 2; i <= n; i++)
pq.push({cnt[i][0], cnt[i][1], i});
while (!pq.empty()) {
item cur = pq.top(); pq.pop();
int x = find(cur.idx), c0 = cur.c0, c1 = cur.c1;
if (cnt[x][0] != c0 || cnt[x][1] != c1)
continue;
int y = find(fa[x]);
ans += cnt[y][1] * 1ll * cnt[x][0];
cnt[y][0] += cnt[x][0];
cnt[y][1] += cnt[x][1];
dsu[x] = y;
if (y > 1) pq.push({cnt[y][0], cnt[y][1], y});
}
cout << ans << endl;
return 0;
}