树链剖分学习笔记
序
听 Singercoder 说树链剖分是码农题,然后我否认(雾)。夜已深,屏幕微凉,数据结构,他说:“A了我,快!”,莫名激动。
我有一壶酒,足以慰风尘,还是,我心中好不了的伤疤。
正文
预备知识:
- LCA 的思想
- vector 或 链式前向星的熟练使用
- 线段树(线段树1
- dfs 序
模板题目:P3384
要求给你一颗树,有以下几个操作
- 操作 1: 格式: (1 x y z) 表示将树从 x 到 y 结点最短路径上所有节点的值都加上 z。
- 操作 2: 格式: (2 x y) 表示求树从 x 到 y 结点最短路径上所有节点的值之和。
- 操作 3: 格式: (3 x z) 表示将以 x 为根节点的子树内所有节点值都加上 z。
- 操作 4: 格式: (4 x) 表示求以 x 为根节点的子树内所有节点值之和
树链剖分,顾名思义是在树上把他剖成链。至于题目叫重链剖分,因为我们又几个定义:
- 重儿子:这个点所有儿子节点为根的子树有最多节点的的那个儿子节点为重儿子
- 轻儿子:其他的叫做轻儿子
- 重边:连接重儿子和根的
- 轻边:其他的
- 重链:相邻重边连起来的连接一条重儿子的链叫重链,对于叶子节点,若其为轻儿子,则有一条以自己为起点的长度为1的链,每一条重链以轻儿子为起点
然后思想就是用两次 dfs 做出重链,然后根据 dfs 序,建立线段树,然后这样以上的就可以做区间修改了。
两次 dfs 没啥说的吧。第一次做出重儿子之类的,第二次做 dfs 序。
对于每个操作的
操作一:
不放图了,我觉得好理解。首先 (x, y) 应该分别在两个重链上,然后直接对这个点到重链的头的顶点开始操作即区间 (dfn[top[x]], dfn[x]),因为他们的坐标是连续的,至于为啥要挑一个深的重链(重链深浅看链头的 dep)然后十分显然,最后在如果在一个重链上,考虑对 ((dfn[x], dfn[y])) 进行区间操作
然后操作二和操作一等价。
操作三:
这不是显然的么?对于区间 (dfn[x], dfn[x] + sz[x] - 1) 操作即可,区间加
操作四和操作三等价。
不禁还是感叹,学的东西好多,我怕不是大龄选手。放弃了,刚学树链剖分感觉代码难度并不高,只是注意细节即可。
写代码时候遇见的问题
- 本来想用下
#define
来add_edge
但是会出莫名错误,所以这种的就不要#define
,但是rdi()
还是正常使用即可 - 取模多检查
- 再也不写
using namespace std;
了
代码
#include <bits/stdc++.h>
#define gc() std::getchar()
#define pc(i) std::putchar(i)
template <typename T>
inline T read()
{
T x = 0;
char ch = gc();
bool f = 0;
while(!std::isdigit(ch))
{
f = (ch == '-');
ch = gc();
}
while(std::isdigit(ch))
{
x = x * 10 + (ch - '0');
ch = gc();
}
return f ? -x : x;
}
template <typename T>
void put(T x)
{
if(x < 0)
{
x = -x;
pc('-');
}
if(x < 10) {
pc(x + 48);
return;
}
put(x / 10);
pc(x % 10 + 48);
return ;
}
#define vit std::vector <int>:: iterator
#define sit std::string:: iterator
#define vi std::vector <int>
#define lbd(i, j, k) lower_bound(i, j, k)
#define pii std::pair <int, int>
#define mkp(i, j) std::make_pair(i, j)
#define lowbit(i) (i & -i)
#define ispow(i) (i == lowbit(i))
#define rdi() read <int> ()
#define rdl() read <long long> ()
#define pti(i) put <int> (i), putchar('
')
#define ptl(i) put <long long> (i), putchar(' ')
#define For(i, j, k) for(int i = j; i <= k; ++i)
#define pub(i) push_back(i)
#define pob() pop_back(i)
#define DEBUG std::printf("Passing [%s] in LINE %d
", __FUNCTION__, __LINE__)
const int Maxn = 2004001;
std::vector <int> v[Maxn];
int sz[Maxn], dep[Maxn], fa[Maxn], n, m, r, p, son[Maxn], w[Maxn], dfn[Maxn], cnt, ww[Maxn], top[Maxn];
void dfs_1(int u, int f, int depth)
{
fa[u] = f;
dep[u] = depth;
sz[u] = 1;
for(auto to : v[u])
{
if(to == f) continue;
dfs_1(to, u, depth + 1);
sz[u] += sz[to];
if(sz[to] > sz[son[u]]) son[u] = to;
}
return void();
}
void dfs_2(int u, int topf)
{
w[dfn[u] = ++cnt] = ww[u];
top[u] = topf;
if(!son[u]) return void();
dfs_2(son[u], topf);
for(auto to : v[u])
{
if(to == fa[u] || to == son[u]) continue;
dfs_2(to, to);
}
return void();
}
class SegmentTree
{
private:
struct Node;
typedef Node* node;
struct Node
{
int l, r, sum, lazy;
};
protected:
Node segment[Maxn << 1];
int mod;
#define seg segment
inline int ls(int p) { return p << 1; }
inline int rs(int p) { return p << 1 | 1; }
inline void pushup(int x) { seg[x].sum = (seg[ls(x)].sum + seg[rs(x)].sum) % mod; }
inline void point_add(int k, int x) { seg[k].lazy += x; seg[k].sum = (seg[k].sum + (seg[k].r - seg[k].l + 1) * x) % mod; }
inline void down(int x)
{
point_add(ls(x), seg[x].lazy);
point_add(rs(x), seg[x].lazy);
seg[x].lazy = 0;
}
void _build(int k, int l, int r)
{
seg[k].l = l, seg[k].r = r;
if(l == r)
{
seg[k].sum = w[l];
return void();
}
int mid = (l + r) >> 1;
if(l <= mid) _build(ls(k), l, mid);
if(mid < r) _build(rs(k), mid + 1, r);
pushup(k);
return void();
}
int _query(int k, int l, int r)
{
if(l <= seg[k].l && seg[k].r <= r) return seg[k].sum;
if(seg[k].lazy) down(k);
int mid = (seg[k].l + seg[k].r) >> 1;
int res = 0;
if(l <= mid) res = (res + _query(ls(k), l, r)) % mod;
if(mid < r) res = (res + _query(rs(k), l, r)) % mod;
return res;
}
void _add(int k, int l, int r, int x)
{
// 区间 l, r 然后加 x
if(l <= seg[k].l && seg[k].r <= r)
{
point_add(k, x);
return void();
}
if(seg[k].lazy) down(k);
int mid = (seg[k].l + seg[k].r) >> 1;
if(l <= mid) _add(ls(k), l, r, x);
if(mid < r) _add(rs(k), l, r, x);
pushup(k);
}
public:
SegmentTree() { }
inline void build(int n, int p) { mod = p; return _build(1, 1, n); }
inline void add(int l, int r, int x) { return _add(1, l, r, x); }
inline int query(int l, int r) { return _query(1, l, r); }
}tree;
inline void add_edge(int i,int j)
{
v[i].push_back(j);
v[j].push_back(i);
}
int x, y, ans, k;
int main(int argc, char* argv[])
{
#ifdef _DEBUG
freopen("in.txt", "r", stdin);
#endif
n = rdi();
m = rdi();
r = rdi();
p = rdi();
For(i, 1, n) ww[i] = rdi();
For(i, 1, n - 1) add_edge(rdi(), rdi());
dfs_1(r, 0, 0);
dfs_2(r, r);
tree.build(n, p);
For(i, 1, m)
{
switch(rdi())
{
case 1:
x = rdi();
y = rdi();
k = rdi() % p;
while(top[x] != top[y])
{
if(dep[top[x]] < dep[top[y]]) std::swap(x, y);
tree.add(dfn[top[x]], dfn[x], k);
x = fa[top[x]];
}
if(dep[x] > dep[y]) std::swap(x, y);
tree.add(dfn[x], dfn[y], k);
break;
case 2:
ans = 0;
x = rdi();
y = rdi();
while(top[x] != top[y])
{
if(dep[top[x]] < dep[top[y]]) std::swap(x, y);
ans = (ans + tree.query(dfn[top[x]], dfn[x])) % p;
x = fa[top[x]];
}
if(dep[x] > dep[y]) std::swap(x, y);
ans = (ans + tree.query(dfn[x], dfn[y])) % p;
pti(ans);
break;
case 3:
x = rdi();
k = rdi();
tree.add(dfn[x], dfn[x] + sz[x] - 1, k);
break;
case 4:
x = rdi();
pti(tree.query(dfn[x], dfn[x] + sz[x] - 1));
break;
}
}
return 0;
}
嵬
嵬再也没有了。
祭我逝去的青春