树链剖分
本质上,树链剖分是一种将树肢解成链平摊开来,再使用线段树对其进行维护的神奇算法。
我们需要通过两次 (dfs) ,预处理一些我们需要的东西,这里是第一次:
-
树上每个节点的父亲,这个不必多说,方便后续找 (LCA) 时的上跳过程;
-
树上每个节点的深度,这个也不必多说,方便后续决定对哪个节点进行操作;
-
每个节点的子树大小,同时标记每个节点的重儿子;
何为重儿子? 顾名思义,对于一个节点,他的所有儿子中子树大小最大的那个儿子就是重儿子。其他的儿子我们就称其为轻儿子。
对于一条树边,连接重儿子的边我们叫他重边,连接轻儿子的边我们叫他轻边(又称轻链),重边连成的链我们叫他重链。
这里是第二次:
-
每个节点的 (dfs) 序,将树肢解后节点就按这个顺序平摊,值得注意的是每次要先遍历该节点的重儿子,回溯之后再遍历其他出点;
-
(dfs) 编号所对应的节点编号,方便逆向访问到这个节点;
-
对于每各节点,我们记录顺着该节点所在的链向上走所能到达的最上方的节点(链顶);
形象地说,重链是我们修建的高速公路,在其上我们可以直接快速达到一条链的最顶部,而轻链则不同,在其上只能一步一步向上跳跃(也可以说,轻链的链顶就是当前点的父亲,轻链是长度为1
的重链)。
实际上,所谓轻链重链并没有什么特殊意义,其本质是将一棵树剖成几条链的一个较为方便的策略。
我们容易知道,树上任意两个点的最短路径都可以通过上文提到的轻重链来达到。对于每个链顶深度较大的点,我们让他跳跃到他链顶的父亲的位置,并对他途径的链进行区间操作。
这个时候,你就会发现将重链 (dfs) 序连续的妙处所在了:
每一个重链都处在一段连续的区间上,我们可以统一对其进行处理。
而对于轻链,由于轻链长度只有一,所以不在乎所处区间是否连续。
重复上述操作,直到两个点的链顶相同为止。通过这样的步骤,我们将树上两点间的最短路径拆分成了数个区间,然后转化为了区间操作。
那么对子树的操作如何实现呢?
不难发现,每一个节点的子树在区间上都连续,理由很简单,只有当前子树递归完毕之后才会访问另一颗子树。同时,我们也可以得到这个区间的左右端点,若设当前节点的 (dfs) 序为 (x):
(left node : x , right node : x + size[x] - 1)
然后,对这个区间进行区间操作即可。
怎么样,是不是觉得非常简单?
当然,在代码实现的过程中,依旧有一些小小的细节值得注意:
-
线段树,不用我说,写错了就拖出去枪毙十分钟(笔者至少被枪毙了半小时);
-
对于子树操作(等价于区间操作),参数是点的 (dfs) 序,而对于最短路操作,参数则是点原本的序号,务必要搞清楚 (dfs) 序与原本点的编号的异同;
-
先两次 (dfs) ,执行完毕后再建树(这问题太蠢我不忍直视);
-
(dfs1) 中注意不要又跑回父亲节点了,(dfs2) 中注意到了叶子节点要及时终止函数。
一时间就想到这么多,希望能对大家有所帮助。
以下提供模板代码,为了锻炼读者的代码阅读能力(我懒),没有加任何注释,愿各位食用愉快(逃)
模板代码
#include<iostream>
#include<cctype>
#include<cstdio>
using namespace std;
typedef long long ll;
const int maxn = 50005;
ll read(){
ll re = 0,ch = getchar();
while(!isdigit(ch)) ch = getchar();
while(isdigit(ch)) re = (re<<1) + (re<<3) + ch - '0',ch = getchar();
return re;
}
int n,m,r,p;
struct edge{
int v,nxt;
}e[maxn<<1];
int h[maxn],cnt;
void addedge(int u,int v){
e[++cnt].v = v;
e[cnt].nxt = h[u];
h[u] = cnt;
}
int fa[maxn],sz[maxn],son[maxn],dfn[maxn],rev[maxn],val[maxn],dis[maxn],top[maxn];
void dfs1(int u,int f){
dis[u] = dis[f] + 1;
fa[u] = f;
sz[u] = 1;
for(int i = h[u];i;i = e[i].nxt){
if(e[i].v != f){
dfs1(e[i].v,u);
sz[u] += sz[e[i].v];
if(sz[e[i].v] > sz[son[u]]) son[u] = e[i].v;
}
}
}
void dfs2(int u,int topf){
dfn[u] = ++cnt;
rev[cnt] = u;
top[u] = topf;
if(!son[u]) return;
dfs2(son[u],topf);
for(int i = h[u];i;i = e[i].nxt)
if(!dfn[e[i].v]) dfs2(e[i].v,e[i].v);
}
struct node{
int l,r;
ll sum,add;
#define l(x) t[x].l
#define r(x) t[x].r
#define sum(x) t[x].sum
#define add(x) t[x].add
#define mid(x) (t[x].r + t[x].l >> 1)
}t[maxn<<2];
void pushdown(int x){
if(add(x)){
sum(x<<1) += add(x) * (mid(x) - l(x) + 1);
sum(x<<1|1) += add(x) * (r(x) - mid(x));
add(x<<1) += add(x);
add(x<<1|1) += add(x);
add(x) = 0;
}
}
void pushup(int x){
sum(x) = (sum(x<<1) % p + sum(x<<1|1) % p) % p;
}
void build(int x,int l,int r){
l(x) = l;
r(x) = r;
if(l == r){
sum(x) = val[rev[l]];
return;
}
build(x<<1,l,mid(x));
build(x<<1|1,mid(x) + 1,r);
pushup(x);
}
void modify(int x,int l,int r,int v){
if(l <= l(x) && r >= r(x)){
sum(x) += (r(x) - l(x) + 1) * v;
add(x) += v;
sum(x) %= p;
add(x) %= p;
return;
}
pushdown(x);
if(l <= mid(x)) modify(x<<1,l,r,v);
if(r > mid(x)) modify(x<<1|1,l,r,v);
pushup(x);
}
ll quiry(int x,int l,int r){
ll ans = 0;
if(l <= l(x) && r >= r(x))
return sum(x);
pushdown(x);
if(l <= mid(x)) ans += quiry(x<<1,l,r);
if(r > mid(x)) ans += quiry(x<<1|1,l,r);
return ans % p;
}
void tadd(int x,int y,int v){
while(top[x] != top[y]){
if(dis[top[x]] < dis[top[y]]) swap(x,y);
modify(1,dfn[top[x]],dfn[x],v);
x = fa[top[x]];
}
if(dis[x] > dis[y]) swap(x,y);
modify(1,dfn[x],dfn[y],v);
}
ll task(int x,int y){
ll ans = 0;
while(top[x] != top[y]){
if(dis[top[x]] < dis[top[y]]) swap(x,y);
ans += quiry(1,dfn[top[x]],dfn[x]);
ans %= p;
x = fa[top[x]];
}
if(dis[x] > dis[y]) swap(x,y);
ans += quiry(1,dfn[x],dfn[y]);
return ans % p;
}
int main(){
n = read(),m = read(),r = read(),p = read();
for(int i = 1;i <= n;i++) val[i] = read();
for(int i = 1;i < n;i++){
int u = read(),v = read();
addedge(u,v);
addedge(v,u);
}
cnt = 0;
dfs1(r,0);
dfs2(r,0);
build(1,1,n);
for(int i = 1;i <= m;i++){
int op = read(),x,y,z;
if(op == 1){
x = read(),y = read(),z = read();
tadd(x,y,z);
}
if(op == 2){
x = read(),y = read();
printf("%lld
",task(x,y));
}
if(op == 3){
x = read(),z = read();
modify(1,dfn[x],dfn[x] + sz[x] - 1,z);
}
if(op == 4){
x = read();
printf("%lld
",quiry(1,dfn[x],dfn[x] + sz[x] - 1));
}
}
return 0;
}