SegmentTreeBeats 简单学习笔记
有一天补 ( ext{CF}) 做到一个题,转化一波题意以后变成要求维护一个序列 (a) 。
-
对于 (i in [l,r], a_i =a_i+x) 。
-
对于 (i in [l,r], a_i =min(a_i, x)) 。
-
求 (sum_{i=l}^r a_i) 。
其实就是 ( ext{Segment Tree Beats}) 的模板题,也就是那年吉老师营员交流课件的例题, 用线段树维护区间最大值 (mx) ,区间次大值 (se) ,区间和 (sum) ,区间最大值出现次数 (cnt) ,加法标记 (tag) 。
对于第二种操作,如果一个区间 (mx leq x) 那么无事发生,可以跳过其所有子区间。如果 (se < x < mx) ,那么 (sum = sum - (mx-x) imes cnt, mx = x) ,注意这里子区间的 (mx, sum) 并没有更改,相当于 (mx) 同时作为一个修改标记,当前区间比子区间的 (mx) 小时,要进行 (sum = sum - (mx-mx[fa]) imes cnt) 的 ( ext{pushdown}) 操作,打完标记之后就可以跳过了。对于其它情况,暴力对其子区间求解。
到这一步位置算法流程不难理解,但算法的复杂度证明比较难懂,目前 (mathcal O(nlog n)) 的证明我还不会,只能理解 (mathcal O(n log^2 n)) 的证明,在这里写一下简要证明:
定义势能函数 (Phi) 为线段树中 (mx) 不等于其父亲节点 (mx) 的节点数量,考虑一次第二操作过程的任一终止节点 (v) 。如果 (v) 对 (Phi) 有贡献,假设这一类节点的数量为 (A) ,到达这些节点的复杂度为 (mathcal O (Alog n)) ,结束后这些节点都对势能没贡献了,也就是说用了 (mathcal O(Alog n)) 的时间让势能减小了 (A) 。
如果 (v) 对 (Phi) 没贡献,记 (u) 为 (v) 的父亲,(u) 的另外一儿子为 (c) ,那么 (mx[u] = mx[v], se[u] eq se[v]) ,也就是说 (se[u] = mx[c]) 。那么 (c) 的子树一定会被访问, 并在访问结束后 (c) 对 (Phi) 没有贡献,假设这一类节点数量为 (A) ,同样也用 (mathcal O(Alog n)) 的时间让势能减小了 (A) 。也就是说对于修改操作,实际上是每减小一个势能用了 (mathcal O(log n)) 的代价。
考虑修改操作,每次只会修改 (mathcal O(log n)) 节点,最多使势能增加 (mathcal O(log n)) 所以总复杂度是 (mathcal O(nlog^2 n))。
code: Codeforces 1290 E
/*program by mangoyang*/ #pragma GCC optimize("Ofast", "inline") #include<bits/stdc++.h> #define inf (0x7f7f7f7f) #define Max(a, b) ((a) > (b) ? (a) : (b)) #define Min(a, b) ((a) < (b) ? (a) : (b)) typedef long long ll; using namespace std; template <class T> inline void read(T &x){ int ch = 0, f = 0; x = 0; for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = 1; for(; isdigit(ch); ch = getchar()) x = x * 10 + ch - 48; if(f) x = -x; } #define int ll const int N = 150005; int a[N], b[N], ans[N], n; namespace Seg{ #define lson (u << 1) #define rson (u << 1 | 1) #define mid ((l + r) >> 1) int mx[N<<2], se[N<<2], sz[N<<2], cnt[N<<2], sum[N<<2], tag[N<<2]; inline void clear(){ memset(mx, 0, sizeof(mx)); memset(se, 0, sizeof(se)); memset(sz, 0, sizeof(sz)); memset(cnt, 0, sizeof(cnt)); memset(sum, 0, sizeof(sum)); memset(tag, 0, sizeof(tag)); } inline void update(int u){ if(mx[lson] > mx[rson]) mx[u] = mx[lson], cnt[u] = cnt[lson]; else mx[u] = mx[rson], cnt[u] = cnt[rson]; if(mx[lson] == mx[rson]) cnt[u] += cnt[lson]; se[u] = max(se[lson], se[rson]); if(mx[lson] != mx[rson]){ int x = min(mx[lson], mx[rson]); se[u] = max(se[u], x); } sum[u] = sum[lson] + sum[rson]; sz[u] = sz[lson] + sz[rson]; } inline void pushdown(int u){ if(tag[u]){ if(mx[lson]) mx[lson] += tag[u]; if(se[lson]) se[lson] += tag[u]; if(mx[rson]) mx[rson] += tag[u]; if(se[rson]) se[rson] += tag[u]; sum[lson] += tag[u] * sz[lson]; sum[rson] += tag[u] * sz[rson]; tag[lson] += tag[u]; tag[rson] += tag[u]; tag[u] = 0; } if(mx[lson] > mx[u]){ sum[lson] -= (mx[lson] - mx[u]) * cnt[lson]; mx[lson] = mx[u]; } if(mx[rson] > mx[u]){ sum[rson] -= (mx[rson] - mx[u]) * cnt[rson]; mx[rson] = mx[u]; } } inline void ins(int u, int l, int r, int pos, int x){ if(l == r){ mx[u] = sum[u] = x; sz[u] = cnt[u] = 1; return; } pushdown(u); if(pos <= mid) ins(lson, l, mid, pos, x); else ins(rson, mid + 1, r, pos, x); update(u); } inline void gao(int u, int l, int r, int L, int R, int x){ if(l >= L && r <= R){ if(mx[u] <= x) return; if(se[u] < x){ sum[u] -= (mx[u] - x) * cnt[u]; mx[u] = x; return; } pushdown(u); gao(lson, l, mid, L, R, x); gao(rson, mid + 1, r, L, R, x); update(u); return; } pushdown(u); if(L <= mid) gao(lson, l, mid, L, R, x); if(mid < R) gao(rson, mid + 1, r, L, R, x); update(u); } inline void add(int u, int l, int r, int L, int R){ if(l >= L && r <= R){ if(mx[u]) mx[u]++; if(se[u]) se[u]++; sum[u] += sz[u], tag[u]++; return; } pushdown(u); if(L <= mid) add(lson, l, mid, L, R); if(mid < R) add(rson, mid + 1, r, L, R); update(u); } inline int query(int u, int l, int r, int L, int R){ if(l >= L && r <= R) return sz[u]; int res = 0; pushdown(u); if(L <= mid) res += query(lson, l, mid, L, R); if(mid < R) res += query(rson, mid + 1, r, L, R); return res; } } signed main(){ read(n); for(int i = 1; i <= n; i++) read(a[i]); for(int i = 1; i <= n; i++) b[a[i]] = i; for(int i = 1; i <= n; i++){ Seg::add(1, 1, n, b[i] + 1, n); int sz = Seg::query(1, 1, n, 1, b[i]); if(sz) Seg::gao(1, 1, n, 1, b[i], sz); Seg::ins(1, 1, n, b[i], Seg::sz[1] + 1); ans[i] = Seg::sum[1] + Seg::sz[1]; } reverse(a + 1, a + n + 1); for(int i = 1; i <= n; i++) b[a[i]] = i; Seg::clear(); for(int i = 1; i <= n; i++){ Seg::add(1, 1, n, b[i] + 1, n); int sz = Seg::query(1, 1, n, 1, b[i]); if(sz) Seg::gao(1, 1, n, 1, b[i], sz); Seg::ins(1, 1, n, b[i], Seg::sz[1] + 1); ans[i] -= Seg::sz[1] * (Seg::sz[1] + 1) - Seg::sum[1]; } for(int i = 1; i <= n; i++) printf("%lld ", ans[i]); return 0; }