Description
给定一颗 (n) 个点的树,点集为 (V)。(forall i in {1, 2, cdots, n}),求:
- 对于一个 (Ssubseteq V),将 (S) 中的点点权设为 (1),其他的点为 (0)。设 (f(S)) 为带权重心的个数。
- 求 (maxlimits_{|S|=i} f(S))。
Constraints
(1le nle 2cdot 10^5)。
Solution
我菜死了,什么点分治倍增我完全不会,那么只好乱胡了一个线段树合并。发现我这种方法好像有点少,那就随便写一个了。
首先是一个必要的转化:
- 如果 (i) 为奇数答案必然为 (1);
- 如果 (i) 为偶数,那么可以发现,如果有两个子树中带权的点数正好都为 ( frac i 2),那么两个子树根连成的路径(包括两端)都可以作为重心。而要答案尽可能大的话,那么就在这个无根树上找出两个 (ge frac i 2) 的子树即可,且两个根尽可能远。
- 为了方便,考虑直接用“(= frac i 2)”取代“(ge frac i 2)”,最后对答案数组做一遍后缀 (max) 即可。
然后就是比较传统的路径问题,只不过这里的子树大小不是对于当前分治区域而言的,而是从整棵树来看。然后大概是什么有根树点分治,不太会就抛掉了。
考虑树上问题的另一个经典思路:线段树合并。为了实现转化题意中两个 size 配对的过程(这种说法并不准确,下文会提到),我们的线段树以 size 为下标,以最大可能的深度为对应值。
在当前点 (x),合并两个子树的线段树时,如果合并到叶子就可以更新答案了。由于两点间距离的计算还需要 LCA 的深度,我们在合并是顺便带上 ({dep}_x)。但这样存在一个问题,就是选取一个子树中的若干个点,但却不一定选取整个子树——相当于我们需要用这个子树的 (dep_x) 去更新线段树上 ([1, {siz}_x]) 的每一个位置。解决方案:利用动态开点,打上区间更新的标记,并在下传时新建节点。
这样可以实现子树间的信息合并,但还不够:父亲方向上的我们尚未考虑。容易发现,上方的点可以选择 ([1, n-{siz}_x]) 个,不妨设每种选择方式距离 (x) 最近的点为 (fa_x),如果实际上可以更远那会在上面更新。现在对于位置 ([1, n-{siz}_x]) 我们在打一个上下标记,值为 (dep_x -1)。同样在叶子节点上统计对答案的贡献。
注意 pushdown 时应先下传区间更新标记,再下传上下标记,dfs 是也用这样的顺序。pushdown 内如果下传到叶子也要更新答案。最后一棵总线段树需要把标记传干净。其他细节在写的时候就能发现了。复杂度 (O(nlog n))。
Code
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <vector>
const int N = 2e5 + 5;
const int K = 50;
const int inf = 0x3f3f3f3f;
std::vector<int> adj[N];
int n, ans[N * 2];
int siz[N], dep[N];
void init(int x, int f) {
siz[x] = 1, dep[x] = dep[f] + 1;
for (int i = 0; i < (int)adj[x].size(); i++) {
int y = adj[x][i];
if (y == f) continue;
init(y, x), siz[x] += siz[y];
}
}
#define mid ((l + r) >> 1)
int lc[N * K], rc[N * K], tot;
int maxd[N * K], tag[N * K], fill[N * K];
inline void pushdown(int x, int l, int r) {
if (fill[x]) {
if (!lc[x]) lc[x] = ++tot;
maxd[lc[x]] = std::max(maxd[lc[x]], fill[x]);
fill[lc[x]] = std::max(fill[lc[x]], fill[x]);
if (!rc[x]) rc[x] = ++tot;
maxd[rc[x]] = std::max(maxd[rc[x]], fill[x]);
fill[rc[x]] = std::max(fill[rc[x]], fill[x]);
fill[x] = 0;
}
if (tag[x] != inf) {
if (lc[x]) {
tag[lc[x]] = std::min(tag[x], tag[lc[x]]);
if (r - l == 1)
ans[l * 2] = std::max(ans[l * 2], tag[lc[x]] - maxd[lc[x]] + 1);
}
if (rc[x]) {
tag[rc[x]] = std::min(tag[x], tag[rc[x]]);
if (r - l == 1)
ans[r * 2] = std::max(ans[r * 2], tag[rc[x]] - maxd[rc[x]] + 1);
}
tag[x] = inf;
}
}
void insert(int& x, int l, int r, int ql, int qr, int d) {
if (ql > r || l > qr) return;
if (!x) x = ++tot;
if (ql <= l && r <= qr) {
fill[x] = std::max(fill[x], d);
maxd[x] = std::max(maxd[x], d);
return;
}
pushdown(x, l, r);
insert(lc[x], l, mid, ql, qr, d);
insert(rc[x], mid + 1, r, ql, qr, d);
maxd[x] = std::max(maxd[lc[x]], maxd[rc[x]]);
}
int merge(int x, int y, int l, int r, int cur) {
if (!x || !y) return x | y;
if (l == r) {
ans[l * 2] = std::max(ans[l * 2], maxd[x] + maxd[y] - cur * 2 + 1);
maxd[x] = std::max(maxd[x], maxd[y]);
return x;
}
pushdown(x, l, r);
pushdown(y, l, r);
lc[x] = merge(lc[x], lc[y], l, mid, cur);
rc[x] = merge(rc[x], rc[y], mid + 1, r, cur);
maxd[x] = std::max(maxd[lc[x]], maxd[rc[x]]);
return x;
}
int at(int x, int l, int r, int ql, int qr) {
if (ql > r || l > qr || !x) return 0;
if (ql <= l && r <= qr) return maxd[x];
pushdown(x, l, r);
return std::max(at(lc[x], l, mid, l, r), at(rc[x], mid + 1, r, l, r));
}
void update(int x, int l, int r, int ql, int qr, int dtop) {
if (!x || ql > r || l > qr) return;
if (ql <= l && r <= qr) {
tag[x] = std::min(tag[x], dtop);
if (l == r) ans[l * 2] = std::max(ans[l * 2], tag[x] - maxd[x] + 1);
return;
}
pushdown(x, l, r);
update(lc[x], l, mid, ql, qr, dtop);
update(rc[x], mid + 1, r, ql, qr, dtop);
}
void flush(int x, int l, int r) {
if (!x) return;
if (l == r) {
ans[l * 2] = std::max(ans[l * 2], maxd[x] - tag[x] + 1);
return;
}
pushdown(x, l, r);
flush(lc[x], l, mid);
flush(rc[x], mid + 1, r);
}
#undef mid
int solve(int x, int f) {
int root = 0;
for (int i = 0; i < (int)adj[x].size(); i++) {
int y = adj[x][i];
if (y == f) continue;
int sub = solve(y, x);
root = merge(root, sub, 1, n, dep[x]);
}
insert(root, 1, n, 1, siz[x], dep[x]);
if (x != 1)
update(root, 1, n, 1, n - siz[x], dep[x] - 1);
return root;
}
signed main() {
memset(tag, inf, sizeof(tag));
scanf("%d", &n);
for (int i = 1, u, v; i < n; i++) {
scanf("%d%d", &u, &v);
adj[u].push_back(v);
adj[v].push_back(u);
}
init(1, 0);
int root = solve(1, 0);
flush(root, 1, n);
for (int i = n; i; --i) {
ans[i] = std::max(ans[i], 1);
if (!(i & 1)) ans[i] = std::max(ans[i], ans[i + 2]);
}
for (int i = 1; i <= n; i++)
printf("%d
", ans[i]);
return 0;
}