zoukankan      html  css  js  c++  java
  • 树链剖分(模板)

    树链剖分模板及拓展

    模板

    code

    //操作 1: 格式: 1 x y z 表示将树从 x 到 y 结点最短路径上所有节点的值都加上 z。
    //操作 2: 格式: 2 x y表示求树从 x 到 y 结点最短路径上所有节点的值之和。
    //操作 3: 格式: 3 x z表示将以 x 为根节点的子树内所有节点值都加上 z。
    //操作 4: 格式: 4 x表示求以 x 为根节点的子树内所有节点值之和
    #include <bits/stdc++.h>
    
    using namespace std;
    
    #define ll long long
    #define ls(x) (x << 1)
    #define rs(x) (x << 1 | 1)
    
    int n, m, rt, tot, cnt, id[100005], top[100005], nw[100005], f[100005], dep[100005], sz[100005], mson[100005], a[100005], hd[100005], to[200005], nxt[200005];
    
    ll mod;
    
    struct node
    {
    	int l, r;
    	ll sum, add;
    }t[400005];
    
    int read()
    {
    	int x = 0, fl = 1; char ch = getchar();
    	while (ch < '0' || ch > '9') { if (ch == '-') fl = -1; ch = getchar();}
    	while (ch >= '0' && ch <= '9') {x = (x << 1) + (x << 3) + ch - '0'; ch = getchar();}
    	return x * fl;
    }
    
    void add(int x, int y)
    {
    	tot ++ ;
    	to[tot] = y;
    	nxt[tot] = hd[x];
    	hd[x] = tot;
    	return;
    }
    
    void push_up(int p)
    {
    	t[p].sum = (t[ls(p)].sum + t[rs(p)].sum) % mod;
    	return;
    }
    
    void push_down(int p)
    {
    	if (!t[p].add) return;
    	t[ls(p)].add = (t[ls(p)].add + t[p].add) % mod;
    	t[rs(p)].add = (t[rs(p)].add + t[p].add) % mod;
    	t[ls(p)].sum = (t[ls(p)].sum + 1ll * t[p].add * (t[ls(p)].r - t[ls(p)].l + 1) % mod) % mod;
    	t[rs(p)].sum = (t[rs(p)].sum + 1ll * t[p].add * (t[rs(p)].r - t[rs(p)].l + 1) % mod) % mod;
    	t[p].add = 0;
    	return;
    }
    
    void update(int p, int l0, int r0, int d)
    {
    	if (l0 <= t[p].l && t[p].r <= r0)
    	{
    		t[p].add = (t[p].add + (ll)d) % mod;
    		t[p].sum = (t[p].sum + 1ll * (t[p].r - t[p].l + 1) * d % mod) % mod;
    		return;
    	}
    	push_down(p);
    	int mid = (t[p].l + t[p].r) >> 1;
    	if (l0 <= mid) update(ls(p), l0, r0, d);
    	if (r0 > mid) update(rs(p), l0, r0, d);
    	push_up(p);
    	return;
    }
    
    ll query(int p, int l0, int r0)
    {
    	if (l0 <= t[p].l && t[p].r <= r0) return t[p].sum;
    	push_down(p);
    	int mid = (t[p].l + t[p].r) >> 1; ll res = 0ll;
    	if (l0 <= mid) res = (res + query(ls(p), l0, r0) % mod) % mod;
    	if (r0 > mid) res = (res + query(rs(p), l0, r0) % mod) % mod;
    	return res % mod;
    }
    
    void build(int p, int l0, int r0)
    {
    	t[p].l = l0; t[p].r = r0;
    	if (l0 == r0)
    	{
    		t[p].sum = (ll)(nw[l0]);
    		return;
    	}
    	int mid = (l0 + r0) >> 1;
    	build(ls(p), l0, mid);
    	build(rs(p), mid + 1, r0);
    	push_up(p);
    	return;
    }
    
    void dfs1(int x, int fa)
    {
    	sz[x] = 1;
    	int mx = -1;
    	for (int i = hd[x]; i; i = nxt[i])
    	{
    		int y = to[i];
    		if (y == fa) continue;
    		dep[y] = dep[x] + 1;
    		f[y] = x;
    		dfs1(y, x);
    		sz[x] += sz[y];
    		if (sz[y] > mx)
    		{
    			mx = sz[y];
    			mson[x] = y;
    		}
    	}
    	return;
    }
    
    void dfs2(int x, int tp)
    {
    	id[x] = ++ cnt;
    	top[x] = tp;
    	nw[cnt] = a[x];
    	if (!mson[x]) return;
    	dfs2(mson[x], tp);
    	for (int i = hd[x]; i; i = nxt[i])
    	{
    		int y = to[i];
    		if (y == f[x] || y == mson[x]) continue;//特别注意不是if(y == tp || y == mson[x])
    		dfs2(y, y);
    	}
    	return;
    }
    
    void q1()
    {
    	int x = read(), y = read(), z = read();
    	while (top[x] != top[y])
    	{
    		if (dep[top[x]] < dep[top[y]]) swap(x, y);//特别注意这里不是if (dep[x] < dep[y])
    		update(1, id[top[x]], id[x], (ll)z);
    		x = f[top[x]];
    	}
    	if (dep[x] > dep[y]) swap(x, y);
    	update(1, id[x], id[y], (ll)z);
    	return;
    }
    
    void q2()
    {
    	int x = read(), y = read();
    	ll res = 0ll;
    	while (top[x] != top[y])
    	{
    		if (dep[top[x]] < dep[top[y]]) swap(x, y);
    		res = (res + query(1, id[top[x]], id[x])) % mod;
    		x = f[top[x]];
    	}
    	if (dep[x] > dep[y]) swap(x, y);
    	res = (res + query(1, id[x], id[y])) % mod;
    	printf("%lld
    ", res);
    	return;
    }
    
    void q3()
    {
    	int x = read(), z = read();
    	update(1, id[x], id[x] + sz[x] - 1, (ll)z);
    	return;
    }
    
    void q4()
    {
    	int x = read();
    	printf("%lld
    ", query(1, id[x], id[x] + sz[x] - 1) % mod);
    	return;
    }
    
    int main()
    {
    	n = read(); m = read(); rt = read(); mod = (ll)read();
    	for (int i = 1; i <= n; i ++ ) a[i] = read();
    	for (int i = 1; i <= n - 1; i ++ )
    	{
    		int x = read(), y = read();
    		add(x, y); add(y, x);
    	}
    	dfs1(rt, 0); dfs2(rt, rt);
    	build(1, 1, n);
    	while (m -- )
    	{
    		int opt = read();
    		if (opt == 1) q1();
    		else if (opt == 2) q2();
    		else if (opt == 3) q3();
    		else q4();
    	}
    	return 0;
    }
    

    拓展

    学了一个换根树剖。主要是解决多次换根的问题,其他和普通树剖差不多。

    • 换根链上操作:不变
    • 换根子树操作:如果新的根(r)(x)的子树内,(x)新的子树就是除去(x)(r)方向的那个子树外,所有的节点
    • 换根后求(lca)(lca'=lca(x,r)|lca(y,r)|lca(x,y)[dep=max])
    • 题目见P3979 遥远的国度CF916E Jamie and Tree
  • 相关阅读:
    融云技术分享:全面揭秘亿级IM消息的可靠投递机制
    企业微信的IM架构设计揭秘:消息模型、万人群、已读回执、消息撤回等
    喜马拉雅亿级用户量的离线消息推送系统架构设计实践
    uni-app 项目使用 npm 包中的小程序自定义组件
    Leetcode563. 二叉树的坡度
    vue看源码遇到 报错
    windows使用
    mysql 显示行号
    从优秀到卓越
    【面经】阿里蚂蚁金服22秋招-Java后端
  • 原文地址:https://www.cnblogs.com/andysj/p/13948234.html
Copyright © 2011-2022 走看看