经过了一系列的前置知识,终于学会了树链剖分!!
重链剖分的思想:
重链剖分可以将树上的任意一条路径划分成不超过(O(logn))条连续的链,每条链上的点深度互不相同(即是自底向上的一条链,链上所有点的(LCA)$为链的一个端点)。
重链剖分还能保证划分出的每条链上的节点(DFS)序连续,因此可以方便地用一些维护序列的数据结构(如线段树)来维护树上路径的信息。
如:
-
修改 树上两点之间的路径上 所有点的值。
-
查询 树上两点之间的路径上 节点权值的 和/极值/其它(在序列上可以用数据结构维护,便于合并的信息)。
我们给出一些定义:
-
重子节点 :表示其子节点中子树最大的子结点。如果有多个子树最大的子结点,取其一。如果没有子节点,就无重子节点。
-
轻子节点 :表示剩余的所有子结点。
-
重边 :从这个结点到重子节点的边为重边 。
-
轻边 :到其他轻子节点的边为 轻边 。
-
重链 :若干条首尾衔接的重边构成 重链 。
实现:
树剖的实现分两个(DFS)的过程
第一个(DFS)记录每个结点的父节点、深度、子树大小、重子节点。
-
(siz_x),表示子树(x)的大小
-
(dep_x),表示点(x)的深度
-
(fa_x),表示点(x)的父亲
-
(son_x),表示点(x)的重儿子
void dfs1(int x){
siz[x] = 1;dep[x] = dep[fa[x]]+1;
for (int i = head[x];i;i = ed[i].nxt){
int to = ed[i].to;
if (to == fa[x]) continue;
fa[to] = x;
dfs1(to);
if (siz[to] > siz[son[x]]) son[x] = to;
siz[x] += siz[to];
}
}
第二个(DFS)记录所在链的链顶((root),应初始化为结点本身)、重边优先遍历时的(DFS)序((dfn))、(DFS)序对应的节点编号((pos))。
-
(str_x),表示(x)所在重链的链顶
-
(dfn_x),表示点(x)的(dfs)序
-
(pos_x),表示(dfs)序为(x)的点
void dfs2(int x,int root){
str[x] = root;
dfn[x] = ++cnt;pos[cnt] = x;
if (son[x]) dfs2(son[x],root);
for (int i = head[x];i;i = ed[i].nxt){
int to = ed[i].to;
if (to == fa[x]||to == son[x]) continue;
dfs2(to,to);
}
}
路径上修改和查询:
链上的(DFS)序是连续的,可以使用线段树、树状数组维护,每次选择深度较大的链往上跳,直到两点在同一条链上。
void fix1(){
int a = read(),b = read(),x = read();
while (str[a] != str[b]){
if (dep[str[a]] < dep[str[b]]) swap(a,b);
modify(1,1,n,dfn[str[a]],dfn[a],x);
a = fa[str[a]];
}
if (dep[a] > dep[b]) swap(a,b);
modify(1,1,n,dfn[a],dfn[b],x);
}
void fix2(){
int a = read(),b = read();
int res = 0;
while (str[a] != str[b]){
if (dep[str[a]] < dep[str[b]]) swap(a,b);
(res += query(1,1,n,dfn[str[a]],dfn[a]))%=mod;
a = fa[str[a]];
}
if (dep[a] > dep[b]) swap(a,b);
(res += query(1,1,n,dfn[a],dfn[b]))%=mod;
printf("%lld
",res);
}
子树修改和查询:
在(DFS)搜索的时候,子树中的结点的(DFS)序是连续的,每一个结点到子树末端的结点的(dfs)序就为他本身的(dfs)序+子树大小-1。
void fix3(){
int a = read(),x = read();
modify(1,1,n,dfn[a],dfn[a]+siz[a]-1,x);
}
void fix4(){
int a = read();
printf("%lld
",query(1,1,n,dfn[a],dfn[a]+siz[a]-1));
}
例题的完整代码:
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#define int long long
using namespace std;
int read(){
int x = 1,a = 0;char ch = getchar();
while (ch < '0'||ch > '9'){if (ch == '-') x = -1;ch = getchar();}
while (ch >= '0'&&ch <= '9'){a = a*10+ch-'0';ch = getchar();}
return x*a;
}
const int maxn = 1e5+10;
int n,m,r,mod,a[maxn];
struct node{
int to,nxt;
}ed[maxn*2];
int head[maxn*2],tot;
void add(int u,int to){
ed[++tot].to = to;
ed[tot].nxt = head[u];
head[u] = tot;
}
int fa[maxn],siz[maxn],son[maxn],dep[maxn];
void dfs1(int x){
siz[x] = 1;dep[x] = dep[fa[x]]+1;
for (int i = head[x];i;i = ed[i].nxt){
int to = ed[i].to;
if (to == fa[x]) continue;
fa[to] = x;
dfs1(to);
if (siz[to] > siz[son[x]]) son[x] = to;
siz[x] += siz[to];
}
}
int cnt,str[maxn],dfn[maxn],pos[maxn];
void dfs2(int x,int root){
str[x] = root;
dfn[x] = ++cnt;pos[cnt] = x;
if (son[x]) dfs2(son[x],root);
for (int i = head[x];i;i = ed[i].nxt){
int to = ed[i].to;
if (to == fa[x]||to == son[x]) continue;
dfs2(to,to);
}
}
int tree[maxn*4],lazy[maxn*4];
int ls(int x){return x<<1;}
int rs(int x){return x<<1|1;}
void pushup(int x){
tree[x] = tree[ls(x)] + tree[rs(x)];
}
void build(int x,int l,int r){
if (l == r){tree[x] = a[pos[l]];return;}
int mid = (l+r)>>1;
build(ls(x),l,mid);build(rs(x),mid+1,r);
pushup(x);
}
void tag(int x,int l,int r,int k){
lazy[x] += k;
tree[x] += (r-l+1)*k;
}
void pushdown(int x,int l,int r){
int mid = (l+r)>>1;
tag(ls(x),l,mid,lazy[x]);
tag(rs(x),mid+1,r,lazy[x]);
lazy[x] = 0;
}
void modify(int x,int l,int r,int nl,int nr,int k){
if (nl <= l&&r <= nr){tag(x,l,r,k);return;}
int mid = (l+r)>>1;
pushdown(x,l,r);
if (nl <= mid) modify(ls(x),l,mid,nl,nr,k);
if (nr > mid) modify(rs(x),mid+1,r,nl,nr,k);
pushup(x);
}
int query(int x,int l,int r,int nl,int nr){
int res = 0;
if (nl <= l&&r <= nr) return tree[x];
int mid = (l+r)>>1;
pushdown(x,l,r);
if (nl <= mid) (res+=query(ls(x),l,mid,nl,nr))%=mod;
if (nr > mid) (res+=query(rs(x),mid+1,r,nl,nr))%=mod;
return res;
}
void fix1(){
int a = read(),b = read(),x = read();
while (str[a] != str[b]){
if (dep[str[a]] < dep[str[b]]) swap(a,b);
modify(1,1,n,dfn[str[a]],dfn[a],x);
a = fa[str[a]];
}
if (dep[a] > dep[b]) swap(a,b);
modify(1,1,n,dfn[a],dfn[b],x);
}
void fix2(){
int a = read(),b = read();
int res = 0;
while (str[a] != str[b]){
if (dep[str[a]] < dep[str[b]]) swap(a,b);
(res += query(1,1,n,dfn[str[a]],dfn[a]))%=mod;
a = fa[str[a]];
}
if (dep[a] > dep[b]) swap(a,b);
(res += query(1,1,n,dfn[a],dfn[b]))%=mod;
printf("%lld
",res);
}
void fix3(){
int a = read(),x = read();
modify(1,1,n,dfn[a],dfn[a]+siz[a]-1,x);
}
void fix4(){
int a = read();
printf("%lld
",query(1,1,n,dfn[a],dfn[a]+siz[a]-1));
}
signed main(){
n = read(),m = read(),r = read(),mod = read();
for (int i = 1;i <= n;i++) a[i] = read();
for (int i = 1;i <= n-1;i++){
int x = read(),y = read();
add(x,y),add(y,x);
}
dfs1(r);dfs2(r,r);
build(1,1,n);
for (int i = 1;i <= m;i++){
int op = read();
if (op == 1) fix1();
if (op == 2) fix2();
if (op == 3) fix3();
if (op == 4) fix4();
}
return 0;
}