zoukankan      html  css  js  c++  java
  • SegmentTreeBeats 简单学习笔记

    SegmentTreeBeats 简单学习笔记

    ​ 有一天补 ( ext{CF}) 做到一个题,转化一波题意以后变成要求维护一个序列 (a)

    1. 对于 (i in [l,r], a_i =a_i+x)

    2. 对于 (i in [l,r], a_i =min(a_i, x))

    3. (sum_{i=l}^r a_i)

      ​ 其实就是 ( ext{Segment Tree Beats}) 的模板题,也就是那年吉老师营员交流课件的例题, 用线段树维护区间最大值 (mx) ,区间次大值 (se) ,区间和 (sum) ,区间最大值出现次数 (cnt) ,加法标记 (tag)

      ​ 对于第二种操作,如果一个区间 (mx leq x) 那么无事发生,可以跳过其所有子区间。如果 (se < x < mx) ,那么 (sum = sum - (mx-x) imes cnt, mx = x) ,注意这里子区间的 (mx, sum) 并没有更改,相当于 (mx) 同时作为一个修改标记,当前区间比子区间的 (mx) 小时,要进行 (sum = sum - (mx-mx[fa]) imes cnt)( ext{pushdown}) 操作,打完标记之后就可以跳过了。对于其它情况,暴力对其子区间求解。

      ​ 到这一步位置算法流程不难理解,但算法的复杂度证明比较难懂,目前 (mathcal O(nlog n)) 的证明我还不会,只能理解 (mathcal O(n log^2 n)) 的证明,在这里写一下简要证明:

      ​ 定义势能函数 (Phi) 为线段树中 (mx) 不等于其父亲节点 (mx) 的节点数量,考虑一次第二操作过程的任一终止节点 (v) 。如果 (v)(Phi) 有贡献,假设这一类节点的数量为 (A) ,到达这些节点的复杂度为 (mathcal O (Alog n)) ,结束后这些节点都对势能没贡献了,也就是说用了 (mathcal O(Alog n)) 的时间让势能减小了 (A)

      ​ 如果 (v)(Phi) 没贡献,记 (u)(v) 的父亲,(u) 的另外一儿子为 (c) ,那么 (mx[u] = mx[v], se[u] eq se[v]) ,也就是说 (se[u] = mx[c]) 。那么 (c) 的子树一定会被访问, 并在访问结束后 (c)(Phi) 没有贡献,假设这一类节点数量为 (A) ,同样也用 (mathcal O(Alog n)) 的时间让势能减小了 (A) 。也就是说对于修改操作,实际上是每减小一个势能用了 (mathcal O(log n)) 的代价。

      ​ 考虑修改操作,每次只会修改 (mathcal O(log n)) 节点,最多使势能增加 (mathcal O(log n)) 所以总复杂度是 (mathcal O(nlog^2 n))

      code: Codeforces 1290 E

      /*program by mangoyang*/
      #pragma GCC optimize("Ofast", "inline")
      #include<bits/stdc++.h>
      #define inf (0x7f7f7f7f)
      #define Max(a, b) ((a) > (b) ? (a) : (b))
      #define Min(a, b) ((a) < (b) ? (a) : (b))
      typedef long long ll;
      using namespace std;
      template <class T>
      inline void read(T &x){
          int ch = 0, f = 0; x = 0;
          for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = 1;
          for(; isdigit(ch); ch = getchar()) x = x * 10 + ch - 48;
          if(f) x = -x;
      }
      #define int ll
      const int N = 150005;
      int a[N], b[N], ans[N], n;
      namespace Seg{
          #define lson (u << 1)
          #define rson (u << 1 | 1)
          #define mid ((l + r) >> 1)
          int mx[N<<2], se[N<<2], sz[N<<2], cnt[N<<2], sum[N<<2], tag[N<<2];
          inline void clear(){
              memset(mx, 0, sizeof(mx));
              memset(se, 0, sizeof(se));
              memset(sz, 0, sizeof(sz));
              memset(cnt, 0, sizeof(cnt));
              memset(sum, 0, sizeof(sum));
              memset(tag, 0, sizeof(tag));
          }
          inline void update(int u){
              if(mx[lson] > mx[rson])
                  mx[u] = mx[lson], cnt[u] = cnt[lson];
              else
                  mx[u] = mx[rson], cnt[u] = cnt[rson];
              if(mx[lson] == mx[rson]) cnt[u] += cnt[lson];
              se[u] = max(se[lson], se[rson]);
              if(mx[lson] != mx[rson]){
                  int x = min(mx[lson], mx[rson]);
                  se[u] = max(se[u], x);
              }
              sum[u] = sum[lson] + sum[rson];
              sz[u] = sz[lson] + sz[rson];
          }
          inline void pushdown(int u){
              if(tag[u]){
                  if(mx[lson]) mx[lson] += tag[u];
                  if(se[lson]) se[lson] += tag[u];
                  if(mx[rson]) mx[rson] += tag[u];
                  if(se[rson]) se[rson] += tag[u];
                  sum[lson] += tag[u] * sz[lson];
                  sum[rson] += tag[u] * sz[rson];
                  tag[lson] += tag[u];
                  tag[rson] += tag[u];
                  tag[u] = 0;
              }
              if(mx[lson] > mx[u]){
                  sum[lson] -= (mx[lson] - mx[u]) * cnt[lson];
                  mx[lson] = mx[u];
              }
              if(mx[rson] > mx[u]){
                  sum[rson] -= (mx[rson] - mx[u]) * cnt[rson];
                  mx[rson] = mx[u];
              }
          }
          inline void ins(int u, int l, int r, int pos, int x){
              if(l == r){
                  mx[u] = sum[u] = x;
                  sz[u] = cnt[u] = 1;
                  return;
              }
              pushdown(u);
              if(pos <= mid) ins(lson, l, mid, pos, x);
              else ins(rson, mid + 1, r, pos, x);
              update(u);
          }
          inline void gao(int u, int l, int r, int L, int R, int x){
              if(l >= L && r <= R){
                  if(mx[u] <= x) return;
                  if(se[u] < x){
                      sum[u] -= (mx[u] - x) * cnt[u];
                      mx[u] = x;
                      return;
                  }
                  pushdown(u);
                  gao(lson, l, mid, L, R, x);
                  gao(rson, mid + 1, r, L, R, x);
                  update(u);
                  return;
              }
              pushdown(u);
              if(L <= mid) gao(lson, l, mid, L, R, x);
              if(mid < R) gao(rson, mid + 1, r, L, R, x);
              update(u);
          }
          inline void add(int u, int l, int r, int L, int R){
              if(l >= L && r <= R){
                  if(mx[u]) mx[u]++;
                  if(se[u]) se[u]++;
                  sum[u] += sz[u], tag[u]++;
                  return;
              }
              pushdown(u);
              if(L <= mid) add(lson, l, mid, L, R);
              if(mid < R) add(rson, mid + 1, r, L, R);
              update(u);
          }
          inline int query(int u, int l, int r, int L, int R){
              if(l >= L && r <= R) return sz[u];
              int res = 0; pushdown(u);
              if(L <= mid) res += query(lson, l, mid, L, R);
              if(mid < R) res += query(rson, mid + 1, r, L, R);
              return res;
          }
      }
      signed main(){
          read(n);
          for(int i = 1; i <= n; i++) read(a[i]);
          for(int i = 1; i <= n; i++) b[a[i]] = i;
          for(int i = 1; i <= n; i++){
              Seg::add(1, 1, n, b[i] + 1, n);
              int sz = Seg::query(1, 1, n, 1, b[i]);
              if(sz) Seg::gao(1, 1, n, 1, b[i], sz);
              Seg::ins(1, 1, n, b[i], Seg::sz[1] + 1);
              ans[i] = Seg::sum[1] + Seg::sz[1];
          }
          reverse(a + 1, a + n + 1);
          for(int i = 1; i <= n; i++) b[a[i]] = i;
          Seg::clear();
              for(int i = 1; i <= n; i++){
              Seg::add(1, 1, n, b[i] + 1, n);
              int sz = Seg::query(1, 1, n, 1, b[i]);
              if(sz) Seg::gao(1, 1, n, 1, b[i], sz);
              Seg::ins(1, 1, n, b[i], Seg::sz[1] + 1);
              ans[i] -= Seg::sz[1] * (Seg::sz[1] + 1) - Seg::sum[1];
          }
          for(int i = 1; i <= n; i++)
              printf("%lld
      ", ans[i]);
          return 0;
      }
      
  • 相关阅读:
    实验吧_简单的sql注入_1、2、3
    实验吧_天下武功唯快不破&让我进去(哈希长度拓展攻击)
    实验吧_密码忘记了(vim编辑器+代码审计)&天网管理系统(php弱比较+反序列化)
    实验吧_Guess Next Session&Once More(代码审计)
    实验吧_NSCTF web200&FALSE(代码审计)
    实验吧_程序逻辑问题(代码审计)&上传绕过
    实验吧_貌似有点难(php代码审计)&头有点大
    网络安全实验室_上传关writeup
    php文件包含漏洞(input与filter)
    我为什么要写LeetCode的博客?
  • 原文地址:https://www.cnblogs.com/mangoyang/p/12567727.html
Copyright © 2011-2022 走看看