zoukankan      html  css  js  c++  java
  • [学习笔记]吉司机线段树

    引入

    • 经典问题:给定一个序列,支持区间取 (min)(给定 (l,r,x) 把所有满足 (lle ile r)(a_i) 改成 (min(a_i,x)))和区间求和

    • 要求在 (O((n+q)log n)) 的时间内解决

    算法流程

    • 吉司机线段树是一种势能线段树,可以实现区间取 (min/max) 区间求和

    • (min) 为例,线段树上每个节点维护四个值:

    • (1)(mx):区间最大值

    • (2)(cnt):区间最大值的出现次数

    • (3)(md):区间次大值(严格小于最大值且最大的数)

    • (4)(sum):区间和

    • 实现区间取 (min) 时,递归到线段树上一个包含于询问区间的节点 (p) 时,进行如下处理:

    • (1)若 (xge mx_p),则显然这次修改不影响节点 (p),直接 return

    • (2)若 (xle md_p),则暴力往 (p) 的左右子节点递归

    • (3)否则 (md_p<x<mx_p),这次修改对 (sum_p) 可以计算出为 (cnt_p imes(mx_p-x)),打标记即可

    • 考虑说明这个算法的复杂度。

    • 注意到区间取 (min) 只会把区间内不同的数逐渐变成相同的

    • 易得对于线段树上一个节点 (p),一次修改操作最多只会让节点 (p) 代表的区间内不同数的种类数增加 (1)

    • 而修改操作时被暴力递归到的节点代表的区间内,不同数的种类数必然有减少

    • 也就是说,所有节点不同数种类数之和最多增加 (O(qlog n)) 次,必然也最多减少 (O((n+q)log n))

    • 故总时间复杂度 (O((n+q)log n))

    与区间加的结合

    • 吉司机线段树也可以和区间加相结合,只需多一个加法标记即可

    • 不过值得注意的是:如果有区间加操作则复杂度要多一个 (log)

    • 分析略

    CF1290E Cartesian Tree

    题意

    • 给定一个 (n) 元排列 (p)

    • 对于每个 (1le ile n),求出这个排列中所有 (le i) 的数构成的序列(维持相对位置关系)的笛卡尔树(是大根堆)所有节点的子树大小之和

    • (1le nle 150000)

    做法

    • 对于一个 (n) 个元素的序列 (a),其笛卡尔树上节点 (i) 的子树大小为 (nxt_i-pre_i-1)(pre_i)(nxt_i) 分别为 (a_i) 左边和右边第一个大于 (a_i) 的位置,如果不存在则分别为 (0)(n+1)

    • 于是 (i) 从小到大维护所有 (pre)(nxt) 的和即可

    • 但为了方便维护,我们不动态维护序列,而是在一个长度为 (n) 的序列上把所有数从小到大激活,并且把 (pre)(nxt) 重新定义:

    • (i) 的左边第一个比它大的被激活的位置为 (j),则 (pre_i) 等于 (j) 及其右边被激活的总个数,如果 (j) 不存在则 (pre_i) 为所有被激活的个数加 (1)

    • (i) 的右边第一个比它大的被激活的位置为 (j),则 (nxt_i) 等于 (j) 及其左边被激活的总个数,如果 (j) 不存在则 (nxt_i) 为所有被激活的个数加 (1)

    • 这样可以算出如果已经加入了前 (i) 小的数则答案为 (sum_j(pre_j+nxt_j)-i^2-2i)

    • 当第 (i) 次在序列上第 (x) 位的数被激活(也就是 (p_x=i))时:

    • (pre_x=nxt_x=i+1),可以直接处理

    • (y<x),则 (nxt_y) 的值应当对 (c)(min)(c) 为位置 (x) 左边被激活的数的个数,包括 (x) 本身)

    • (y>x),则 (nxt_y) 应当加 (1)

    • 也就是要实现区间取 (min) 和区间加,吉司机线段树维护即可,(pre) 同理

    • (O(nlog^2n))

    Code

    #include <bits/stdc++.h>
    #define p2 p << 1
    #define p3 p << 1 | 1
    
    template <class T>
    inline void read(T &res)
    {
    	res = 0; bool bo = 0; char c;
    	while (((c = getchar()) < '0' || c > '9') && c != '-');
    	if (c == '-') bo = 1; else res = c - 48;
    	while ((c = getchar()) >= '0' && c <= '9')
    		res = (res << 3) + (res << 1) + (c - 48);
    	if (bo) res = ~res + 1;
    }
    
    template <class T>
    inline T Max(const T &a, const T &b) {return a > b ? a : b;}
    
    template <class T>
    inline T Min(const T &a, const T &b) {return a < b ? a : b;}
    
    typedef long long ll;
    
    const int N = 15e4 + 5, M = 6e5 + 5, INF = 0x3f3f3f3f;
    
    int n, p[N], A[N];
    
    void change(int x, int v)
    {
    	for (; x <= n; x += x & -x)
    		A[x] += v;
    }
    
    int ask(int x)
    {
    	int res = 0;
    	for (; x; x -= x & -x) res += A[x];
    	return res;
    }
    
    struct elem
    {
    	int mx, cnt, md;
    	
    	friend inline elem operator + (elem a, elem b)
    	{
    		elem res;
    		if (a.mx > b.mx) res = a, res.md = Max(a.md, b.mx);
    		else if (a.mx < b.mx) res = b, res.md = Max(a.mx, b.md);
    		else res = a, res.cnt = a.cnt + b.cnt, res.md = Max(a.md, b.md);
    		return res;
    	}
    };
    
    struct seg
    {
    	elem T[M]; int add[M], tag[M], cnt[M]; ll sum[M];
    	
    	void build(int l, int r, int p)
    	{
    		T[p] = (elem) {-INF, 0, -INF}; add[p] = 0; tag[p] = INF;
    		if (l == r) return;
    		int mid = l + r >> 1;
    		build(l, mid, p2); build(mid + 1, r, p3);
    	}
    	
    	void down(int p)
    	{
    		T[p2].mx += add[p]; T[p2].md += add[p]; sum[p2] += 1ll * cnt[p2] * add[p];
    		add[p2] += add[p]; tag[p2] += add[p];
    		T[p3].mx += add[p]; T[p3].md += add[p]; sum[p3] += 1ll * cnt[p3] * add[p];
    		add[p3] += add[p]; tag[p3] += add[p];
    		if (tag[p] < T[p2].mx) sum[p2] -= 1ll * T[p2].cnt * (T[p2].mx - tag[p]),
    			tag[p2] = T[p2].mx = tag[p];
    		if (tag[p] < T[p3].mx) sum[p3] -= 1ll * T[p3].cnt * (T[p3].mx - tag[p]),
    			tag[p3] = T[p3].mx = tag[p];
    		add[p] = 0; tag[p] = INF;
    	}
    	
    	void upt(int p)
    	{
    		cnt[p] = cnt[p2] + cnt[p3]; sum[p] = sum[p2] + sum[p3];
    		T[p] = T[p2] + T[p3];
    	}
    	
    	void unlock(int l, int r, int pos, int v, int p)
    	{
    		if (l == r) return (void) (cnt[p] = T[p].cnt = 1, T[p].mx = sum[p] = v);
    		int mid = l + r >> 1; down(p);
    		if (pos <= mid) unlock(l, mid, pos, v, p2);
    		else unlock(mid + 1, r, pos, v, p3);
    		upt(p);
    	}
    	
    	void change(int l, int r, int s, int e, int v, int p)
    	{
    		if (e < l || s > r) return;
    		if (s <= l && r <= e) return (void) (add[p] += v, tag[p] += v,
    			T[p].mx += v, T[p].md += v, sum[p] += 1ll * cnt[p] * v);
    		int mid = l + r >> 1; down(p);
    		change(l, mid, s, e, v, p2); change(mid + 1, r, s, e, v, p3);
    		upt(p);
    	}
    	
    	void modify(int l, int r, int s, int e, int v, int p)
    	{
    		if (e < l || s > r) return;
    		if (s <= l && r <= e && v > T[p].md)
    		{
    			if (v < T[p].mx) sum[p] -= 1ll * T[p].cnt * (T[p].mx - v),
    				tag[p] = T[p].mx = v;
    			return;
    		}
    		int mid = l + r >> 1; down(p);
    		modify(l, mid, s, e, v, p2); modify(mid + 1, r, s, e, v, p3);
    		upt(p);
    	}
    } T1, T2;
    
    int main()
    {
    	int x;
    	read(n);
    	for (int i = 1; i <= n; i++) read(x), p[x] = i;
    	T1.build(1, n, 1); T2.build(1, n, 1);
    	for (int i = 1; i <= n; i++)
    	{
    		change(p[i], 1); int c = ask(p[i]);
    		T1.modify(1, n, 1, p[i], c, 1); T1.change(1, n, p[i], n, 1, 1);
    		T2.modify(1, n, p[i], n, i - c + 1, 1); T2.change(1, n, 1, p[i], 1, 1);
    		T1.unlock(1, n, p[i], i + 1, 1); T2.unlock(1, n, p[i], i + 1, 1);
    		printf("%lld
    ", T1.sum[1] + T2.sum[1] - 1ll * i * (i + 2));
    	}
    	return 0;
    }
    

    区间 K 小值

    题意

    • 一个长度为 (n) 的序列,所有数都是 ([1,n]) 内的正整数

    • (n) 次操作,每次操作为区间取 (min) 或求区间第 (k) 小值

    • (1le n,mle 8 imes10^4)

    做法

    • 容易想到位置线段树套权值线段树

    • 查询时把询问区间拆成的 (O(log n)) 个外层节点拎出来,在这些节点内嵌的线段树上二分即可

    • 对于修改,外层的线段树是可以打上取 (min) 的标记的,但这里存在的问题是:对于打上标记的节点,其祖先节点内嵌的线段树无法快速实现修改。

    • 于是考虑和吉司机线段树一样的思想:把相同的数放到一起修改。

    • 修改时对于在外层线段树上拆出的所有 (O(log n)) 个节点,找出这个节点上所有本质不同的大于 (x) 的数,在这个节点及其祖先上把这些数改为 (x) 即可

    • 注意到每把这个节点内本质不同的数的个数减 (1),都要在 (O(log n)) 棵内层线段树上做修改,也就是需要 (O(log^2n)) 的复杂度

    • 故总复杂度 (O((n+q)log^3n))

    Code

    #include <bits/stdc++.h>
    #define p2 p << 1
    #define p3 p << 1 | 1
    
    template <class T>
    inline void read(T &res)
    {
    	res = 0; bool bo = 0; char c;
    	while (((c = getchar()) < '0' || c > '9') && c != '-');
    	if (c == '-') bo = 1; else res = c - 48;
    	while ((c = getchar()) >= '0' && c <= '9')
    		res = (res << 3) + (res << 1) + (c - 48);
    	if (bo) res = ~res + 1;
    }
    
    template <class T>
    inline T Min(const T &a, const T &b) {return a < b ? a : b;}
    
    const int N = 8e4 + 5, M = N << 2, L = 3e7 + 5;
    
    int n, m, a[N], rt[M], tag[M], ToT, tot, pt[N], nc, del[L], qaq, pos[N], val[N];
    
    struct seg
    {
    	int lc, rc, sum;
    } T[L];
    
    inline int newnode() {return nc ? del[nc--] : ++ToT;}
    
    inline void delnode(int p) {T[p].lc = T[p].rc = T[p].sum = 0; del[++nc] = p;}
    
    void change(int l, int r, int pos, int v, int &p)
    {
    	if (!v) return;
    	if (!p) p = newnode(); T[p].sum += v;
    	if (l == r) return;
    	int mid = l + r >> 1;
    	if (pos <= mid) change(l, mid, pos, v, T[p].lc);
    	else change(mid + 1, r, pos, v, T[p].rc);
    }
    
    void dfs(int p)
    {
    	if (!p) return;
    	dfs(T[p].lc); dfs(T[p].rc);
    	delnode(p);
    }
    
    int ask(int l, int r, int x, int p)
    {
    	if (l == r) return 0;
    	int mid = l + r >> 1, res;
    	if (x <= mid) res = ask(l, mid, x, T[p].lc) + T[T[p].rc].sum,
    		dfs(T[p].rc), T[p].rc = 0;
    	else res = ask(mid + 1, r, x, T[p].rc);
    	return T[p].sum = T[T[p].lc].sum + T[T[p].rc].sum, res;
    }
    
    void zzq(int x, int &p) {change(1, n, x, ask(1, n, x, p), p);}
    
    void build(int l, int r, int p)
    {
    	tag[p] = n;
    	for (int i = l; i <= r; i++) change(1, n, a[i], 1, rt[p]);
    	if (l == r) return;
    	int mid = l + r >> 1;
    	build(l, mid, p2); build(mid + 1, r, p3);
    }
    
    void down(int p)
    {
    	tag[p2] = Min(tag[p2], tag[p]);
    	tag[p3] = Min(tag[p3], tag[p]);
    	zzq(tag[p], rt[p2]); zzq(tag[p], rt[p3]);
    	tag[p] = n;
    }
    
    void zhouzhouzka(int l, int r, int x, int p)
    {
    	if (x > r || !T[p].sum) return;
    	if (l == r) return (void) (pos[++qaq] = l, val[qaq] = T[p].sum);
    	int mid = l + r >> 1;
    	zhouzhouzka(l, mid, x, T[p].lc);
    	zhouzhouzka(mid + 1, r, x, T[p].rc);
    }
    
    void getmin(int l, int r, int s, int e, int x, int p)
    {
    	if (e < l || s > r) return;
    	if (s <= l && r <= e)
    	{
    		tag[p] = Min(tag[p], x); qaq = 0; zhouzhouzka(1, n, x + 1, rt[p]);
    		zzq(x, rt[p]);
    		for (int i = 1; i <= qaq; i++)
    			for (int q = p >> 1; q; q >>= 1)
    				change(1, n, pos[i], -val[i], rt[q]),
    					change(1, n, x, val[i], rt[q]);
    		return;
    	}
    	int mid = l + r >> 1; down(p);
    	getmin(l, mid, s, e, x, p2); getmin(mid + 1, r, s, e, x, p3);
    }
    
    void czx(int l, int r, int s, int e, int x, int p)
    {
    	if (e < l || s > r) return;
    	if (s <= l && r <= e) return (void) (pt[++tot] = rt[p]);
    	int mid = l + r >> 1; down(p);
    	czx(l, mid, s, e, x, p2); czx(mid + 1, r, s, e, x, p3);
    }
    
    int query(int l, int r, int k)
    {
    	tot = 0; czx(1, n, l, r, n, 1);
    	int d = 0;
    	for (int i = 1; i <= tot; i++) d += T[pt[i]].sum;
    	l = 1; r = n;
    	while (l < r)
    	{
    		int delta = 0, mid = l + r >> 1;
    		for (int i = 1; i <= tot; i++) delta += T[T[pt[i]].lc].sum;
    		if (k <= delta)
    		{
    			r = mid;
    			for (int i = 1; i <= tot; i++) pt[i] = T[pt[i]].lc;
    		}
    		else
    		{
    			k -= delta; l = mid + 1;
    			for (int i = 1; i <= tot; i++) pt[i] = T[pt[i]].rc;
    		}
    	}
    	return l;
    }
    
    int main()
    {
    	int op, l, r, x;
    	read(n); read(m);
    	for (int i = 1; i <= n; i++) read(a[i]);
    	build(1, n, 1);
    	while (m--)
    	{
    		read(op); read(l); read(r); read(x);
    		if (op == 1 && x > n) x = n;
    		if (op == 1) getmin(1, n, l, r, x, 1);
    		else printf("%d
    ", query(l, r, x));
    	}
    	return 0;
    }
    
  • 相关阅读:
    6月16日
    9月15日
    9月14日
    9月13日
    9月12日
    6月11日
    梦断代码阅读笔记
    11周总结
    梦断代码阅读笔记
    10第一阶段意见评论
  • 原文地址:https://www.cnblogs.com/xyz32768/p/12590112.html
Copyright © 2011-2022 走看看