zoukankan      html  css  js  c++  java
  • 猫树学习笔记

    本文参考自算法发明者 immortalCO(猫锟) 的博客 一种高效处理无修改区间或树上询问的数据结构(附代码)

    感谢 猫锟 提供了对于一类题比较通用的解决办法,以及思路启发。

    问题描述

    给出一个某种元素的序列 (a_1,a_2,dots ,a_n),要求进行 (m) 次询问,每一次是询问一段区间 ([l,r]) 的某种支持结合律和快速合并的信息,要求在线。

    这类问题比较通用,比如在 DP 的优化中就常常见到。

    算法实现

    算法介绍

    对于常规问题,比如区间最值,区间最大子段和。我们常常能用线段树等数据结构达到,构造 (O(n)) ,询问 (O(log n)) 的时间复杂度。

    对于这些做法,只有一点不好,询问复杂度 不够优秀,且对于一些特定问题,线段树的 push_up 合并也不好写。

    但对于区间最值这类的问题, 我们可以类似 (RMQ) 那样,在一般的问题上,以预处理的时间和空间,换取快速的询问

    我们首先考虑询问一个区间 ([q_l, q_r]) 。如果 (q_l = q_r) ,就可以直接得到答案。否则会不断在线段树上定位,而且会在几个区间的中点 (mid) 处被分开。

    我习惯于线段树每个区间维护的一个闭区间 ([l, r]) ,其中中点 (displaystyle mid = lfloor frac{l + r}{2} floor)

    考虑第一次被分开的位置,假设为 (p) 。那么原来的区间 ([q_l, q_r]) 就被分为 ([q_l, mid]) ,与 ([mid + 1, q_r])

    我们考虑对于每一个 (mid) 预处理他向前的后缀 ([i, mid]) 的答案(其中 (l le i le mid)) ,以及他向后的前缀 ([mid + 1, j]) (其中 (mid + 1 le j le r))。

    如果我们知道了 (p) 点所在的位置,我们可以直接利用之前预处理的 ([q_l, mid]) 以及 ([mid + 1, q_r]) 的答案直接合并即可。

    不难发现预处理的复杂度是 (O(n log n)) 的(对于每一层每个数刚好被考虑一次 (O(n) imes O(log n) = O(n log n))


    然后怎么快速知道这个位置呢?

    不难发现这个 (p) 的位置,就是线段树上对应 ([l, l])([r, r]) 节点的 (lca) (最近公共祖先)所处的位置,我们可以考虑用 (st) 表预处理,然后可以直接查询 (p) 的位置,但这样显然太麻烦了。

    如果整棵线段树满足堆式存储(也就是对于点 (i) ,它的两个儿子分别为 (2 i)(2i + 1) ),就有一个很好的性质。

    对于任意两个同深度的点,他们的 (lca) 是他们二进制下 (lcp) (最长公共前缀)。

    这个是十分显然的,因为对于两个深度相同的点,他们第一次分开的位置,必然导致当前最后一位的二进制不同,而前面都是相同的。

    我们把整棵树建成一个满二叉树 ([1, 2^k]),那么对于任意一个区间 ([i, i]) 都是满足他们的深度是最深且在同一层的。

    注意对于不同深度的点不一定满足这个情况!!这就是为什么我们为什么要建满的原因。

    然后对于两个数 (x, y) 的二进制下的 (lcp)x >> Log2[x ^ y] 。(这个很显然,丢掉第一个不相同的位后面的所有位就行了)

    这样我们就可以实现询问 (O(1)) 啦。

    我们称这个数据结构为 猫树

    算法本质

    看了一下 UOJ 评论区。。。

    其实就是将分治进行离线,我们用一个东西来存储这个分治结构,以及前面按位置分治的答案。

    所以这个算法最重要的还是,寻找特定问题的分治方案。

    例题讲解

    区间最大子段和

    题意

    给你一个序列 (a_i) ,有 (m) 次询问,每次询问一个区间 ([l, r]) ,表示询问这段区间的最大子段和。

    题解

    如果没有区间,那么这个是个经典的分治问题。

    最大子段和,要么完全在左区间,要么完全在右区间,要么跨越中点。

    所以我们只需要预处理 ([i, mid]) 的最大前缀和与 ([mid + 1, j]) 的最大前缀和,这个一边遍历一边取 (max)

    以及这两个区间的最大子段和。至于那个最大子段和,可以利用前缀和相减,保存一个前缀和最小值就行了。

    这个算法比标准线段树上合并信息,要好写并且更快。

    代码

    对于第一道题,还是建议看看代码怎么写的。。(瓶颈在输入输出上也是没谁啦)

    #include <bits/stdc++.h>
    
    #define For(i, l, r) for(register int i = (l), i##end = (int)(r); i <= i##end; ++i)
    #define Fordown(i, r, l) for(register int i = (r), i##end = (int)(l); i >= i##end; --i)
    #define Set(a, v) memset(a, v, sizeof(a))
    #define Cpy(a, b) memcpy(a, b, sizeof(a))
    #define debug(x) cout << #x << ": " << x << endl
    #define DEBUG(...) fprintf(stderr, __VA_ARGS__)
    
    using namespace std;
    
    inline bool chkmin(int &a, int b) {return b < a ? a = b, 1 : 0;}
    inline bool chkmax(int &a, int b) {return b > a ? a = b, 1 : 0;}
    
    inline int read() {
    	int x = 0, fh = 1; char ch = getchar();
    	for (; !isdigit(ch); ch = getchar()) if (ch == '-') fh = -1;
    	for (; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + (ch ^ 48);
    	return x * fh;
    }
    
    inline void Out(int x) {
    	static char sta[18], top, flag = false;
    	if (!x) { puts("0"); return ; }
    	sta[top = 1] = '
    ';
    	if (x < 0) flag = true, x = -x;
    	for (; x; x /= 10) sta[++ top] = (x % 10) + 48;
    	if (flag) putchar ('-'), flag = false;
    	while (top) putchar (sta[top --]);
    }
    
    void File() {
    #ifdef zjp_shadow
    	freopen ("1043.in", "r", stdin);
    	freopen ("1043.out", "w", stdout);
    #endif
    }
    
    const int N = 50100, Maxn = N * 4, MaxLog = 20, inf = 0x7f7f7f7f;
    
    int pos[Maxn], val[Maxn], Log2[Maxn], maxlen;
    
    namespace CatTree {
    
    	inline int Max(int a, int b) { return a > b ? a : b; }
    
    	int Pre[MaxLog][Maxn], Sum[MaxLog][Maxn];
    
    	void Build(int o, int l, int r, int dep) {
    		if (l == r) { pos[l] = o; return ; }
    		int mid = (l + r) >> 1, sum, minv;
    
    		Sum[dep][mid] = Pre[dep][mid] = sum = minv = val[mid]; chkmin(minv, 0);
    		Fordown(i, mid - 1, l) {
    			Pre[dep][i] = Max(Pre[dep][i + 1], sum += val[i]);
    			Sum[dep][i] = Max(Sum[dep][i + 1], sum - minv);
    			chkmin(minv, sum);
    		}
    
    		Sum[dep][mid + 1] = Pre[dep][mid + 1] = sum = minv = val[mid + 1]; chkmin(minv, 0);
    		For (i, mid + 2, r) {
    			Pre[dep][i] = Max(Pre[dep][i - 1], sum += val[i]);
    			Sum[dep][i] = Max(Sum[dep][i - 1], sum - minv);
    			chkmin(minv, sum);
    		}
    
    		Build(o << 1, l, mid, dep + 1);
    		Build(o << 1 | 1, mid + 1, r, dep + 1);
    	}
    
    	inline int Query(int l, int r) {
    		if (l == r) return val[l];
    		register int dep = Log2[pos[l]] - Log2[pos[l] ^ pos[r]];
    		return Max(Max(Sum[dep][l], Sum[dep][r]), Pre[dep][l] + Pre[dep][r]);
    	}
    
    }
    
    int n;
    
    int main() {
    
    	File();
    
    	n = read();
    	For (i, 1, n) val[i] = read();
    
    	for (maxlen = 1; maxlen < n; maxlen <<= 1);
    	CatTree :: Build(1, 1, maxlen, 1);
    
    
    	For (i, 2, maxlen << 1) Log2[i] = Log2[i >> 1] + 1;
    
    	for (register int m = read(), l, r; m; -- m) {
    		l = read(), r = read();
    		Out(CatTree :: Query(l, r));
    	}
    
    #ifdef zjp_shadow
    	cerr << (double) clock() / CLOCKS_PER_SEC << endl;
    #endif 
    
    	return 0;
    
    }
    
  • 相关阅读:
    matplotlib formatters
    matplotlib locators
    mysql> 12 simple but staple commands
    mysql--> find your databases' local position
    ubuntu16.04安装caffe常见问题及其解决方案
    gitlab使用说明
    vim配置摘要
    shell 提示符个性化设置
    python拼接参数不确定的SQL时防注入问题--条件语句最后拼入
    python_opencv ——图片预处里(二)
  • 原文地址:https://www.cnblogs.com/zjp-shadow/p/9377742.html
Copyright © 2011-2022 走看看