树的统计
题目链接:ybt金牌导航5-1-1 / luogu P2590
题目大意
对于一个树要支持一些操作。
修改一个点的权值,询问一个点到另一个点路径的节点权值和或者权值最大值。
思路
这道题是树链剖分的模板。
它主要就是对于一个点的儿子,我们把它分成两种:重儿子和轻儿子。
重儿子只有一个,就是它的子树大小比其他儿子的子树大小都大。
然后连向它的就是重边,连向轻儿子的就是轻边。
然后定义重链就是一条都是由重边组成的链,一个点也算重链。
然后你会想到,那一个点肯定是只在一个重链上,因为在多个重链,就说明它不止一个儿子是重儿子,就矛盾了嘛。
然后你还会想到,对于一条树上的路径,它肯定是这样的:重链,轻边,重链,轻边,……,重链。这样交错的出现,是不会有两个轻边在一起的,因为如果在一起,中间的点也会被算作重链,那就不在一起了。
然后你可以通过跑图维护一些值。
第一次跑图,我们可以维护基本信息:点的父亲,点的重儿子,点对应子树的大小,点在树中的深度。
第二次跑图,我们可以维护有关重边的信息,就可以维护这个点所在重链的顶点(深度最小的点)。
怎么维护呢,现在你到了一个点,如果你走向重儿子,那重儿子的重链所在定点就是这个点的重链所在定点。如果走向轻儿子,那它的重链就是它自己。
然后你会发现,你可以把重链都拿出来,放在一起形成一个数组。按的顺序我们可以从深度小的到深度大的顺序,也就是 dfs 的顺序。
那你可以这样:对于一条路径,它会被分为一些重链和一些轻边,那这些重链已经涵盖了路径上所有的点,那我们就对于所有重链搞一个线段树,然后维护它们的值,然后你要的时候就相当于你要这个重链的一个区间,那就是区间查询。
那线段树上的点会对应你原来树上的某个点,那我们就弄一个数组来表示对于的点,当然也可以弄一个互逆的数组记录树上的点对应线段树上的哪个点。
那这些重链的线段树是可以和在一起的,翻转你查询的时候只要自己控制不跑到另一个重链就可以了。
接着问题就是如何找到路径上的每个重链。
那这时候我们之前维护的“某个点所在重链的顶点”这个数组就有用了。
我们可以每次就两个点所在的重链选深度大的处理。(好让两个点一起被提起来,因为交点是 LCA,跑多了的部分就不属于这个路径)
那你要区间查询的就是你现在要上提的点到这个点所在重链的顶点之间的这一段重链。
然后这样一直搞,会使得两个点最后在同一个重链中。
那我们就区间查询深度小的到深度大的。(因为你深度小的是在前面,深度大的在线段树的后面,这样才能使得线段树的区间查询正常运行)
(前面的区间查询也是,查询顶点到当前点,位置不能反)
然后这道题目要我们维护区间最大值和区间和,那就在线段树里面按着线段树的样子正常写就可以了。
然后点的修改其实就直接修改它对应的线段树的点就可以了,也是线段树的普通修改。
代码
#include<cstdio>
#include<cstring>
#include<iostream>
using namespace std;
struct node {
int to, nxt;
}e[60001];
int n, x, y, le[30001], KK, tot, q;
int number[30001], top[30001], tree_pl[30001], normal_pl[120001];
int fa[30001], num[30001], deg[30001], son[30001];
int Max[120001], sum[120001], maxn, Sum, size[30001];
string c;
void add(int x, int y) {
e[++KK] = (node){y, le[x]}; le[x] = KK;
}
void dfs1(int now, int father) {//第一个跑图记录父亲,深度,子树大小,重儿子
fa[now] = father;
deg[now] = deg[father] + 1;
size[now] = 1;
for (int i = le[now]; i; i = e[i].nxt)
if (father != e[i].to) {
dfs1(e[i].to, now);
size[now] += size[e[i].to];
if (size[e[i].to] > size[son[now]])
son[now] = e[i].to;
}
}
void dfs2(int now, int father) {
//第二个跑图记录点在线段树上的位置以及线段树某个位置的点在这里是哪个点
//也就是说记录的这两个东西是互逆的
//还记录这个点属于哪个重链(记录的是这个链深度最小的点)
if (son[now]) {//先弄重儿子
tree_pl[son[now]] = ++tot;
top[son[now]] = top[now];
normal_pl[tree_pl[son[now]]] = son[now];
dfs2(son[now], now);
}
for (int i = le[now]; i; i = e[i].nxt)//弄其它儿子(轻儿子)
if (e[i].to != father && e[i].to != son[now]) {//轻儿子不是父亲,也不是重儿子
tree_pl[e[i].to] = ++tot;
top[e[i].to] = e[i].to;
normal_pl[tree_pl[e[i].to]] = e[i].to;
dfs2(e[i].to, now);
}
}
void up(int now) {//线段树的向上传递值
Max[now] = max(Max[now << 1], Max[now << 1 | 1]);
sum[now] = sum[now << 1] + sum[now << 1 | 1];
}
void build(int now, int l, int r) {//按照给出的原来权值构造出线段树
if (l == r) {
Max[now] = sum[now] = number[normal_pl[l]];
return ;
}
int mid = (l + r) >> 1;
build(now << 1, l, mid);
build(now << 1 | 1, mid + 1, r);
up(now);
}
void change(int now, int l, int r, int lr, int change_num) {//单点修改
if (l > lr || r < lr) return ;
if (l == r) {
Max[now] = sum[now] = change_num;
return ;
}
int mid = (l + r) >> 1;
if (lr <= mid) change(now << 1, l, mid, lr, change_num);
else change(now << 1 | 1, mid + 1, r, lr, change_num);
up(now);
}
void query(int now, int l, int r, int L, int R) {//区间求和+区间最大值
if (r < L || l > R) return ;
if (l >= L && r <= R) {
maxn = max(maxn, Max[now]);
Sum += sum[now];
return ;
}
int mid = (l + r) >> 1;
if (L <= mid) query(now << 1, l, mid, L, R);
if (mid + 1 <= R) query(now << 1 | 1, mid + 1, r, L, R);
}
void ask(int x, int y) {//把路径划分成很多个重链,然后每个都线段树
while (top[x] != top[y]) {
if (deg[top[x]] < deg[top[y]]) {//跳深度大的,好让两边都一起被提上来
swap(x, y);
}
query(1, 1, tot, tree_pl[top[x]], tree_pl[x]);
x = fa[top[x]];
}
//已经在同一条重链上
if (deg[x] > deg[y]) swap(x, y);//这里就把顺序排好,从前面到后面,让线段树能正常跑
query(1, 1, tot, tree_pl[x], tree_pl[y]);
}
int main() {
scanf("%d", &n);
for (int i = 1; i < n; i++) {
scanf("%d %d", &x, &y);
add(x, y);
add(y, x);
}
for (int i = 1; i <= n; i++) scanf("%d", &number[i]);
//跑两次图
dfs1(1, 0);
tot = 1;//根节点的值可以直接预处理出来
top[1] = 1;
tree_pl[1] = 1;
normal_pl[1] = 1;
dfs2(1, 0);
build(1, 1, tot);//建树
scanf("%d", &q);
for (int i = 1; i <= q; i++) {
cin >> c;
if (c[0] == 'C') {
scanf("%d %d", &x, &y);
change(1, 1, tot, tree_pl[x], y);//修改
}
else if (c[1] == 'M') {
scanf("%d %d", &x, &y);
maxn = -1e9;
Sum = 0;
ask(x, y);
printf("%d
", maxn);//输出求出的最大值
}
else {
scanf("%d %d", &x, &y);
maxn = -1e9;
Sum = 0;
ask(x, y);
printf("%d
", Sum);//输出取出的和
}
}
return 0;
}