题意:给定一棵树(测试数据中树的结点数中告知),求断掉每条边两棵子树所有重心的编号和之和。
题解:重心 + 树状数组
先随便求个重心出来,记做(k)。子树大小记为(sz),子结点子树大小的最大值记为(mx),考虑一个结点(u ot = k):
首先,割掉一条边使得(u)成为所在连通块的重心,这条边一定不在(u)子树内。原因很简单,考虑若(u)在(k)的(v)方向的子树中,把(u)子树中一条边删去后,(sz[v])会减小,那么重心(k)会往非(v)方向偏移或者不动,但绝不会进入(v)子树。
然后再考虑删去的子树应满足什么条件。容易发现是否合法只与子树大小有关:我们得保证与(u)相连的子树都满足大小不超过连通块大小除以(2)这个条件。令(S)为删去的子树(与(u)不连通的子树)大小,条件形式化即为:
[2 imes mx[u]leq n-S\
2 imes(n-S-sz[u]) leq n-S
]
解出
[S in[n - 2 imes sz[u], n-2 imes mx[u]]
]
再补上条件,删掉的边不在(u)子树中。
这个怎么求呢?看起来可以无脑主席树,但是有更好的做法:
一条边((u,fa[u])),(sz[u])对(u)子树外有贡献,(n-sz[u])对(u)子树内有贡献,(sz[u])对(u)的祖先结点有负的贡献。
对于第一个,把所有的贡献加入树状数组,到进入(u)的时候把贡献去掉,离开(u)是把贡献加上。
对于第二个,进入(u)的时候把贡献加入,离开(u)是把贡献去掉。
对于第三个,记录进出(u)贡献的增量,把一、二的贡献减去这个增量,就是(u)作为重心的次数。
最后考虑(u=k)的情况:
设最大儿子是(v),次大儿子是(v'),则可以分两类:
-
若删掉的子树在(v)中,(S leq n - 2 imes sz[v'])
-
若删掉的子树不在(v)中,(Sleq n-2 imes sz[v])
dfs的时候记录一下即可。
代码皮了一下。
#include <algorithm>
#include <cstdio>
#include <vector>
#define pb push_back
using namespace std;
typedef long long ll;
const int N = 3e5 + 5;
int n, rt, sz[N], mx[N], son1, son2;
vector<int> G[N];
ll ans;
struct Bit {
int *bit, len;
void resize(int n) {
bit = new int[n + 5]; len = n;
for(int i = 0; i <= n; i ++) bit[i] = 0;
}
void clear() { delete bit; }
void add(int u, int v) {
for(; u <= len; u += u & (-u)) {
bit[u] += v;
}
}
int qry(int u) {
int ans = 0;
for(; u >= 1; u &= u - 1) {
ans += bit[u];
}
return ans;
}
int qry(int l, int r) {
return qry(r) - qry(l - 1);
}
} *b1 = new Bit, *b2 = new Bit;
void dfs(int u, int fa = 0) {
bool tg = 1; sz[u] = 1; mx[u] = 0;
for(int i = 0; i < (int) G[u].size(); i ++) {
int v = G[u][i];
if(v == fa) continue ;
dfs(v, u); sz[u] += sz[v];
if(sz[v] > n / 2) tg = 0;
mx[u] = max(mx[u], sz[v]);
}
if(n - sz[u] > n / 2) tg = 0;
if(tg && !rt) rt = u;
}
bool s1[N];
void dfs2(int u, int fa = 0) {
ll c = 0; s1[u] = s1[fa] || u == son1;
if(u != rt) {
b1->add(sz[u], -1); b1->add(n - sz[u], 1); b2->add(sz[u], 1);
c = b1->qry(n - 2 * sz[u], n - 2 * mx[u]);
c += b2->qry(n - 2 * sz[u], n - 2 * mx[u]);
if(s1[u] && sz[u] <= n - 2 * sz[son2]) ans += rt;
if(!s1[u] && sz[u] <= n - 2 * sz[son1]) ans += rt;
}
for(int i = 0; i < (int) G[u].size(); i ++) {
int v = G[u][i];
if(v == fa) continue ;
dfs2(v, u);
}
if(u != rt) {
b1->add(sz[u], 1); b1->add(n - sz[u], -1);
c -= b2->qry(n - 2 * sz[u], n - 2 * mx[u]);
ans += u * c;
}
}
int main() {
int test; scanf("%d", &test);
while(test --) {
scanf("%d", &n); b1->resize(n); b2->resize(n);
for(int i = 1; i <= n; i ++) G[i].clear();
for(int u, v, i = 1; i < n; i ++) {
scanf("%d%d", &u, &v);
G[u].pb(v); G[v].pb(u);
}
rt = 0; dfs(1); dfs(rt);
ans = 0;
for(int i = 1; i <= n; i ++) if(i != rt) {
b1->add(sz[i], 1);
}
son1 = son2 = 0;
for(int i = 0; i < (int) G[rt].size(); i ++) {
int v = G[rt][i];
if(sz[v] > sz[son1]) {
son2 = son1; son1 = v;
} else if(sz[v] > sz[son2]) {
son2 = v;
}
}
dfs2(rt);
printf("%lld
", ans);
b1->clear(); b2->clear();
}
return 0;
}