链覆盖,询问子树最小值, 换根。
lct 即可。
可以使用树剖, 询问的时候分类讨论。
- id 和 root 相等, 直接询问即可。
- root ≠ lca(id, root) ≠ id, 直接询问即可
- id 是 root 的祖先, 相当于询问 id 除 root 所在的子树之外的位置, 在 dfs 序上分解即可
- root 是 id 的祖先, 直接询问即可
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5 + 3, inf = 2147483647;
int root;
int n, m, val[N], a[N];
int ecnt, hd[N], nt[N*2+1], vr[N*2+1];
void add(int x, int y) {nt[++ecnt]=hd[x],hd[x]=ecnt; vr[ecnt]=y;}
int dep[N], siz[N], fa[N], tp[N], son[N];
void dfs1(int x, int F, int D) {
dep[x] = D, siz[x] = 1, fa[x] = F;
for(int i = hd[x]; i; i = nt[i]) {
int y = vr[i];
if(y == F) continue;
dfs1(y, x, D + 1);
siz[x] += siz[y];
if(siz[y] > siz[son[x]]) son[x] = y;
}
}
int dfntot, in[N], out[N];
void dfs2(int x, int T) {
in[x] = ++dfntot; a[in[x]] = val[x];
tp[x] = T;
if(son[x]) dfs2(son[x], T);
for(int i = hd[x]; i; i = nt[i]) {
int y = vr[i];
if(y == fa[x] || y == son[x]) continue;
dfs2(y, y);
}
out[x] = dfntot;
}
int up(int x, int D, int Top) {
while(dep[tp[x]] > D) x = fa[tp[x]];
return dep[tp[x]] == D ? tp[x] : son[Top];
}
int t[N<<2], tag[N<<2];
#define li (me << 1)
#define ri (li | 1)
#define mid ((l+r)>>1)
#define ls li,l,mid
#define rs ri,mid+1,r
void build(int me, int l, int r) {
if(l == r) { t[me] = a[l]; return; }
build(ls), build(rs), t[me] = min(t[li], t[ri]);
}
void ad(int me, int v) { t[me] = tag[me] = v;}
void pushdown(int me) {
if(tag[me]) {ad(li, tag[me]), ad(ri, tag[me]), tag[me] = 0; }
}
void modi(int me, int l, int r, int x, int y, int v) {
if(x <= l && r <= y) { ad(me, v); return;}
pushdown(me);
if(x<=mid) modi(ls, x, y, v);
if(y >mid) modi(rs, x, y, v);
t[me] = min(t[li], t[ri]);
}
int ask(int me, int l, int r, int x, int y) {
if(x <= l && r <= y) return t[me];
pushdown(me);
int ret = inf;
if(x<=mid) ret = ask(ls, x, y);
if(y >mid) ret = min(ret, ask(rs, x, y));
return ret;
}
void road_modi(int x, int y, int v) {
while(tp[x] != tp[y]) {
if(dep[tp[y]] > dep[tp[x]]) swap(x, y);
modi(1, 1, n, in[tp[x]], in[x], v);
x = fa[tp[x]];
}
if(dep[y] > dep[x]) swap(x, y);
modi(1, 1, n, in[y], in[x], v);
}
int main()
{
scanf("%d%d", &n, &m);
for(int i = 1; i < n; ++i) {
int x, y; scanf("%d%d", &x, &y); add(x, y), add(y, x);
}
for(int i = 1; i <= n; ++i) scanf("%d", &val[i]);
scanf("%d", &root);
dfs1(1, 0, 1), dfs2(1, 1);
build(1, 1, n);
int opt, id, x, y, v;
while(m--)
{
scanf("%d", &opt);
if(opt == 1)
{
scanf("%d", &id);
root = id;
}
else
if(opt == 2)
{
scanf("%d%d%d", &x, &y, &v);
road_modi(x, y, v);
}
else
if(opt == 3)
{
scanf("%d", &id);
if(id == root) cout << t[1] << '
';
else
{
if(in[id] <= in[root] && out[root] <= out[id])
{
int tmp = up(root, dep[id] + 1, id);
int l = in[tmp] - 1, r = out[tmp] + 1;
int ret = inf;
if(1 <= l) ret = ask(1, 1, n, 1, l);
if(r <= n) ret = min(ret, ask(1, 1, n, r, n));
cout << ret << '
';
}
else
cout << ask(1, 1, n, in[id], out[id]) << '
';
}
}
}
return 0;
}
对于一颗 n 个点的树, 给定 m 个关键点, 可以构造一棵点数为 O(m) 的保持原树祖先关系的树, 称为虚树。为了方便,通常会把原树的根默认为关键点。
预处理出原树的 dfs 序, 这样,按这个 dfs 序访问每组关键点, 相当于在虚树上 dfs, dfs 的过程中连边并添加辅助点即可。
设 dp[u] 为割掉 u 子树内所有关键点的最小代价, 若 u 为关键点, 则 dp[u] 为 u 到根的最小边权,反之, dp[u] = min(u 到根的最小边权,Σv∈u.son dp[v])。
发现 dp 子树内没有关键点的点是没有意义的, 进而, 对于子树内关键点集合相同的若干点, 有很多都是无用的, 于是在虚树上 DP, 答案不变。
// 这段代码是优化过的, 不是原始的虚树 DP。具体地, 建出的虚树的关键点都不记录其子树。
#include<bits/stdc++.h>
typedef long long LL;
using namespace std;
int read() {
int x = 0; char c = getchar();
while(c < '0' || c > '9') c = getchar();
while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return x;
}
const int N = 250003;
const LL inf = 1e18;
int n;
int ecnt, hd[N], nt[N*2+1], vr[N*2+1], w[N*2+1];
void ad(int u, int v, int w_) {nt[++ecnt] = hd[u], hd[u] = ecnt, vr[ecnt] = v, w[ecnt] = w_; }
LL mi[N];
int dep[N], siz[N], fa[N], son[N], tp[N];
void dfs1(int x, int F, int D) {
dep[x] = D, siz[x] = 1, fa[x] = F;
for(int i = hd[x]; i; i = nt[i]) {
int y = vr[i];
if(y == F) continue;
mi[y] = min(mi[x], (LL)w[i]);
dfs1(y, x, D + 1);
siz[x] += siz[y];
if(siz[y] > siz[son[x]]) son[x] = y;
}
}
int dfntot, dfn[N];
void dfs2(int x, int T) {
dfn[x] = ++dfntot;
tp[x] = T;
if(son[x]) dfs2(son[x], T);
for(int i = hd[x]; i; i = nt[i]) {
int y = vr[i];
if(y == fa[x] || y == son[x]) continue;
dfs2(y, y);
}
}
int lca(int x, int y) {
while(tp[x] != tp[y]) dep[tp[x]] > dep[tp[y]] ? x = fa[tp[x]] : y = fa[tp[y]];
return dep[x] > dep[y] ? y : x;
}
int k, h[N];
bool cmp(int s1, int s2) { return dfn[s1] < dfn[s2]; }
vector<int> v[N];
int s[N], t;
void extnd(int x) {
if(t == 1) { s[++t] = x; return; }
int l = lca(s[t], x);
if(l == s[t]) { return; }
while(t > 1 && dfn[s[t - 1]] >= dfn[l]) v[s[t - 1]].push_back(s[t]), --t;
if(l != s[t]) v[l].push_back(s[t]), s[t] = l;
s[++t] = x;
}
LL dp(int x) {
if(v[x].size() == 0) return mi[x];
LL tmp = 0ll;
for(int i = 0; i < (int)v[x].size(); ++i) tmp += dp(v[x][i]);
v[x].clear();
return min(mi[x], tmp);
}
int main()
{
n = read();
for(int i = 1; i < n; ++i) {
int x = read(), y = read(), z = read(); ad(x, y, z), ad(y, x, z);
}
mi[1] = inf;
dfs1(1, 0, 1), dfs2(1, 1);
int m = read();
while(m--)
{
k = read();
for(int i = 1; i <= k; ++i) h[i] = read();
sort(h + 1, h + 1 + k, cmp);
s[t = 1] = 1;
for(int i = 1; i <= k; ++i) extnd(h[i]);
while(t > 0) v[s[t - 1]].push_back(s[t]), --t;
printf("%lld
", dp(1));
}
return 0;
}
定理:对于任何高度为 O(log n) 的树, Σu siz[u] = O(n log n)。
点分治时, 用当前层的中心当作下一层所有中心的父亲, 就可以得到一棵点分治重构树。
可根据上面的定理可以证明点分治的复杂度。
梳理下淀粉质板子。
//淀粉质的主过程,即构建淀粉树的过程
void calc(int sta) {
//第一步得到当前淀粉树子树的 size
get_mid(sta, 0);
//而后找出重心
sumsize = siz[sta];
get_mid(sta, 0);
vis[mid] = true;
//从这开始计算分治中心
//从这结束计算分治中心
//分治
for(int i = hd[x]; i; i = nt[i]) {
int y = vr[i];
if(vis[y]) continue;
calc(y);
}
}
可以直接点分治, 像 DSU on Tree 那样用 dfn 会爽许多。
由于 k 较小, 统计的时候可以用数据结构搞, 而不必用双指针+容斥。
我会将两种写法都写一遍。
// 数据结构统计, 缺点是难以应付 k 较大的情况
#include<bits/stdc++.h>
typedef long long LL;
using namespace std;
const int N = 4e4 + 233, K = 2e4 + 233;
int n, k;
int ecnt, hd[N], nt[N*2+1], vr[N*2+1], we[N*2+1];
void ad(int u, int v, int w) {nt[++ecnt] = hd[u], hd[u] = ecnt; vr[ecnt] = v, we[ecnt] = w; }
LL ans = 0ll;
bool vis[N];
int siz[N];
int sumsiz, mid;
void get_mid(int x, int F) { int mx = 0;
siz[x] = 1;
for(int i = hd[x]; i; i = nt[i]) {
int y = vr[i];
if(y == F || vis[y]) continue;
get_mid(y, x);
siz[x] += siz[y];
mx = max(mx, siz[y]);
}
mx = max(mx, sumsiz - siz[x]);
if(mx <= (sumsiz / 2)) mid = x;
}
int dis[N];
int dfntot, in[N], out[N], row[N];
void dfs(int x, int F) {
in[x] = ++dfntot; row[dfntot] = x;
for(int i = hd[x]; i; i = nt[i]) {
int y = vr[i];
if(y == F || vis[y]) continue;
dis[y] = dis[x] + we[i];
dfs(y, x);
}
out[x] = dfntot;
}
int t[K];
void ins(int x) {
if(!x) return;
for(; x <= k; x += (x & (-x))) ++t[x];
}
int ask(int x) {
int res = 0;
for(; x; x -= (x & (-x))) res += t[x];
return res;
}
void fuck(int x) {
dfntot = 0;
dis[x] = 0;
dfs(x, 0);
memset(t, 0, sizeof t);
for(int i = hd[x]; i; i = nt[i]) {
int y = vr[i];
if(vis[y]) continue;
for(int j = in[y]; j <= out[y]; ++j) {
int Y = row[j];
if(dis[Y] <= k) ++ans;
if(dis[Y] < k) ans += ask(k - dis[Y]);
}
for(int j = in[y]; j <= out[y]; ++j) {
int Y = row[j];
if(dis[Y] < k) ins(dis[Y]);
}
}
}
void calc(int sta) {
get_mid(sta, 0);
sumsiz = siz[sta];
get_mid(sta, 0);
vis[mid] = true;
if(siz[mid] > 1) fuck(mid);
for(int i = hd[mid]; i; i = nt[i]) {
int y = vr[i];
if(vis[y]) continue;
calc(y);
}
}
int main()
{
scanf("%d", &n);
for(int i = 1; i < n; ++i) {
int u, v, w; scanf("%d%d%d", &u, &v, &w); ad(u, v, w), ad(v, u, w);
}
scanf("%d", &k);
calc(1);
cout << ans;
return 0;
}
// 双指针+容斥, 有复杂度优势, 常数不足为惧
#include <bits/stdc++.h>
typedef long long LL;
using namespace std;
const int N = 4e4 + 23;
int n, k;
int ecnt, hd[N], nt[N*2+1], vr[N*2+1], we[N*2+1];
void ad(int u, int v, int w) { nt[++ecnt] = hd[u], hd[u] = ecnt; vr[ecnt] = v, we[ecnt] = w; }
LL ans = 0ll;
bool vis[N];
int siz[N];
int sumsiz, mid;
void get_mid(int x, int F) {int mx = 0;
siz[x] = 1;
for(int i = hd[x]; i; i = nt[i]) {
int y = vr[i];
if(y == F || vis[y]) continue;
get_mid(y, x);
siz[x] += siz[y];
mx = max(mx, siz[y]);
}
mx = max(mx, sumsiz - siz[x]);
if(mx <= (sumsiz / 2)) mid = x;
}
int dis[N];
int dfntot, in[N], out[N], row[N];
void dfs(int x, int F) {
in[x] = ++dfntot; row[dfntot] = x;
for(int i = hd[x]; i; i = nt[i]) {
int y = vr[i];
if(y == F || vis[y]) continue;
dis[y] = dis[x] + we[i];
dfs(y, x);
}
out[x] = dfntot;
}
int s[N], t;
LL gao() {
LL res = 0ll;
sort(s+1, s+1+t);
int l = 1, r = t;
while(l < r && l <= t) {
while(r >= 1 && s[r] + s[l] > k) --r;
if(l < r) res += (r - l);
++l;
}
return res;
}
void fuck(int x) {
dfntot = 0;
dis[x] = 0;
dfs(x, 0);
t = 0;
for(int i = in[x]; i <= out[x]; ++i) s[++t] = dis[row[i]];
ans += gao();
for(int i = hd[x]; i; i = nt[i]) {
int y = vr[i];
if(vis[y]) continue;
t = 0;
for(int j = in[y]; j <= out[y]; ++j) {
s[++t] = dis[row[j]];
}
ans -= gao();
}
}
void calc(int sta) {
get_mid(sta, 0);
sumsiz = siz[sta];
get_mid(sta, 0);
vis[mid] = true;
if(siz[mid] > 1) fuck(mid);
for(int i = hd[mid]; i; i = nt[i]) {
int y = vr[i];
if(vis[y]) continue;
calc(y);
}
}
int main() {
scanf("%d", &n);
for(int i = 1; i < n; ++i) {
int u, v, w; scanf("%d%d%d", &u, &v, &w); ad(u, v, w), ad(v, u, w);
}
scanf("%d", &k);
calc(1);
cout << ans;
return 0;
}