树链剖分,顾名思义,就是把树分成链。
通过这个方法,可以优化对树上两点间路径、某一点子树的修改和查询的操作,等。
流程
$dfs1()$
在这个函数中,要处理出每个节点的:
- 深度dep[]
- 父亲fa[]
- 大小siz[]
- 重儿子编号hson[]
一个节点的siz[],是包括它自己、它的儿子、它儿子的儿子……一共的节点数量。
所谓的重儿子,就是一个节点的儿子中,siz[]最大的那一个。
叶子节点没有儿子,所以也没有重儿子。
这个函数就是普通的遍历整棵树,每到一个点记录dpth[],siz[]初始值为1。
对于所连的每一条边,标记儿子的fa[]并递归$dfs$,回溯时将子树大小加给父亲的siz[],并判断是否更新重儿子hson[]。
void dfs1(int u){ dpth[u] = dpth[fa[u]]+1; siz[u] = 1; for(int i = head[u];i;i = nxt[i]){ int v = to[i]; if(v == fa[u])continue; fa[v] = u; dfs1(v); siz[u] += siz[v]; if(siz[v] > siz[hson[u]]) hson[u] = v; } }
知道了重儿子,如果沿着树走一遍,并且每次都先走重儿子,就可以把整个树拆成从大到小的很多链!
$dfs2()$
在这个函数中,要处理出每个节点的:
- 新编号,也就是先走重儿子的$dfs$序dfn[]
- 新编号对应的原编号point[]
- 所在链的顶端的点top[]
每到一个点,首先直接记录它的dfn[]、point[]、top[](有点类似father)。
如果没有重儿子,则是叶子节点,直接$return$;
如果有重儿子,则这条链还没有结束,继续递归,top[]不变;
如果除了重儿子还有其他的儿子,则每个轻儿子都是一条新链,继续递归,且新链的top[]是这个轻儿子。
void dfs2(int u,int t){ dfn[u] = ++cnt; point[cnt] = u; top[u] = t; if(!hson[u])return; dfs2(hson[u],t); for(int i = head[u];i;i = nxt[i]){ int v = to[i]; if(v == fa[u] || v == hson[u])continue; dfs2(v,v); } }
这样,“链”的部分就算是处理好了$-w-$
接下来,根据新的编号dfn[],建一棵线段树。
线段树需要的函数:$build$、$pushdown$、$modify$、$query$
下面来看需要支持的操作:
- 将树从x到y结点最短路径上所有节点的值都加上z
- 求树从x到y结点最短路径上所有节点的值之和
- 将以x为根节点的子树内所有节点值都加上z
- 求以x为根节点的子树内所有节点值之和
树上路径
求树上路径,听起来很像$LCA$……实际上,树链剖分的确可以求出$LCA$。
需要再写两个函数:$getmodify$,$getquery$。这里以查询操作为例。
如果x,y的top[]不同,则它们一定不在同一条链上。
和倍增$LCA$的做法相似,每次将top[]较深,即所在链的顶端比较靠下的节点,跳到它的top[]的父亲,这样就上升到了上边的链。
查询这一段路径的值(dfn[top[x]],dfn[x]),并重复操作,直到x,y在同一条链上为止。
同一链上,深度较小的点一定dfn[]较小,也就是在线段树上的编号较小。通过swap,使较浅的作为x,并查询(dfn[x],dfn[y])
int getquery(int x,int y) { int ans = 0; while(top[x] != top[y]) { if(dpth[top[x]] < dpth[top[y]]) swap(x,y); ans += query(dfn[top[x]],dfn[x],1,n,1); x = fa[top[x]]; } if(dpth[x] > dpth[y]) swap(x,y); ans += query(dfn[x],dfn[y],1,n,1); return ans; }
子树
根据dfs序,可以知道,一棵子树内的序号是连续的。
那么,对应到线段树上,即为(dfn[x],dfn[x]+siz[x]-1)。
完整代码如下
#include<iostream> #include<cstdio> #include<cstring> #define MogeKo qwq using namespace std; const int maxn = 1e5+10; int n,m,rt,opt,x,y,z,mod,cnt; int to[maxn<<1],head[maxn<<1],nxt[maxn<<1]; long long sum[maxn<<2],lazy[maxn<<2]; int dfn[maxn],dpth[maxn],siz[maxn],hson[maxn],fa[maxn],top[maxn],point[maxn]; long long w[maxn]; void add(int x,int y) { to[++cnt] = y; nxt[cnt] = head[x]; head[x] = cnt; } void dfs1(int u) { dpth[u] = dpth[fa[u]]+1; siz[u] = 1; for(int i = head[u]; i; i = nxt[i]) { int v = to[i]; if(v == fa[u])continue; fa[v] = u; dfs1(v); siz[u] += siz[v]; if(siz[v] > siz[hson[u]]) hson[u] = v; } } void dfs2(int u,int t) { dfn[u] = ++cnt; point[cnt] = u; top[u] = t; if(!hson[u])return; dfs2(hson[u],t); for(int i = head[u]; i; i = nxt[i]) { int v = to[i]; if(v == fa[u] || v == hson[u])continue; dfs2(v,v); } } void build(int l,int r,int now) { if(l == r) { sum[now] = w[point[l]] %mod; return; } int mid = l+r>>1; build(l,mid,now<<1),build(mid+1,r,now<<1|1); sum[now] = (sum[now<<1] + sum[now<<1|1]) %mod; } void pushdown(int l,int r,int now) { sum[now] += (r-l+1)*lazy[now]%mod; (lazy[now<<1] += lazy[now]) %=mod; (lazy[now<<1|1] += lazy[now]) %=mod; lazy[now] = 0; } void modify(int L,int R,int l,int r,int c,int now) { if(L == l && R == r) { lazy[now] += c; return; } (sum[now] += (R-L+1)*c ) %=mod; int mid = l+r>>1; if(R <= mid) modify(L,R,l,mid,c,now<<1); else if(L >= mid+1) modify(L,R,mid+1,r,c,now<<1|1); else modify(L,mid,l,mid,c,now<<1),modify(mid+1,R,mid+1,r,c,now<<1|1); } long long query(int L,int R,int l,int r,int now) { if(L == l && R == r) { return (sum[now]+lazy[now]*(r-l+1))%mod; } pushdown(l,r,now); int mid = l+r>>1; if(R <= mid) return query(L,R,l,mid,now<<1); else if(L >= mid+1) return query(L,R,mid+1,r,now<<1|1); else return (query(L,mid,l,mid,now<<1) + query(mid+1,R,mid+1,r,now<<1|1)) %mod; } void getmodify(int x,int y,int c) { while(top[x] != top[y]) { if(dpth[top[x]] < dpth[top[y]]) swap(x,y); modify(dfn[top[x]],dfn[x],1,n,c,1); x = fa[top[x]]; } if(dpth[x] > dpth[y]) swap(x,y); modify(dfn[x],dfn[y],1,n,c,1); } long long getquery(int x,int y) { long long ans = 0; while(top[x] != top[y]) { if(dpth[top[x]] < dpth[top[y]]) swap(x,y); (ans += query(dfn[top[x]],dfn[x],1,n,1)) %=mod; x = fa[top[x]]; } if(dpth[x] > dpth[y]) swap(x,y); (ans += query(dfn[x],dfn[y],1,n,1)) %=mod; return ans; } int main() { scanf("%d%d%d%d",&n,&m,&rt,&mod); for(int i = 1; i <= n; i++) scanf("%lld",&w[i]); for(int i = 1; i <= n-1; i++) { scanf("%d%d",&x,&y); add(x,y),add(y,x); } cnt = 0; dfs1(rt),dfs2(rt,rt); build(1,n,1); for(int i = 1; i <= m; i++) { scanf("%d",&opt); if(opt == 1) { scanf("%d%d%d",&x,&y,&z); getmodify(x,y,z%mod); } if(opt == 2) { scanf("%d%d",&x,&y); printf("%lld ",getquery(x,y)); } if(opt == 3) { scanf("%d%d",&x,&z); modify(dfn[x],dfn[x]+siz[x]-1,1,n,z%mod,1); } if(opt == 4) { scanf("%d",&x); printf("%lld ",query(dfn[x],dfn[x]+siz[x]-1,1,n,1)); } } return 0; }