zoukankan      html  css  js  c++  java
  • P6329 【模板】点分树 | 震波[点分树]

    点分树就是按照点分治的过程建出来,然后容斥一下.jpg

    // powered by c++11
    // by Isaunoya
    #include <bits/stdc++.h>
    
    // Inclusive index loops: rep counts up from x to y, Rep counts down from x
    // to y. The bound y is re-evaluated on every iteration. `register` was
    // dropped: it is deprecated since C++11 and removed in C++17, where
    // compilers warn or refuse to compile it.
    #define rep(i, x, y) for (int i = (x); i <= (y); ++i)
    #define Rep(i, x, y) for (int i = (x); i >= (y); --i)
    
    using namespace std;
    using db = double;
    using ll = long long;
    using uint = unsigned int;
    using ull = unsigned long long;
    
    #define pii pair<int, int>
    #define fir first
    #define sec second
    
    // Overwrites x with y when x < y, so x ends up holding the larger value.
    template <class T>
    void cmax(T& x, const T& y) {
      if (!(x < y)) return;
      x = y;
    }
    
    // Overwrites x with y when x > y, so x ends up holding the smaller value.
    template <class T>
    void cmin(T& x, const T& y) {
      if (!(x > y)) return;
      x = y;
    }
    
    // Container shorthands. Macro arguments are parenthesized so expressions
    // such as all(*ptr) or sz(*ptr) expand correctly; unparenthesized,
    // `*ptr.begin()` would parse as `*(ptr.begin())`.
    #define all(v) (v).begin(), (v).end()
    #define sz(v) ((int)(v).size())
    #define pb emplace_back
    
    // Sorts the whole vector in ascending order using operator<.
    template <class T>
    void sort(vector<T>& v) {
      std::sort(v.begin(), v.end());
    }
    
    // Reverses the element order of the whole vector in place.
    template <class T>
    void reverse(vector<T>& v) {
      std::reverse(v.begin(), v.end());
    }
    
    // Sorts v ascending and removes duplicates, leaving each distinct value once.
    template <class T>
    void unique(vector<T>& v) {
      std::sort(v.begin(), v.end());
      auto new_end = std::unique(v.begin(), v.end());
      v.erase(new_end, v.end());
    }
    
    // Reverses the characters of s in place.
    void reverse(string& s) { std::reverse(s.begin(), s.end()); }
    
    const int io_size = (1 << 23) | 233;  // capacity of the input and output buffers, in bytes
    const int io_limit = 1 << 22;          // io_out flushes once more than this many bytes are pending
    struct io_in {
      char ch;  // last character consumed by the most recent read
    #ifndef __WIN64
      // Buffered getchar(): when the buffer is exhausted it is refilled with a
      // single fread() from stdin; EOF is returned once fread() yields nothing.
      char getchar() {
        static char buf[io_size], *p1 = buf, *p2 = buf;
    
        return (p1 == p2) && (p2 = (p1 = buf) + fread(buf, 1, io_size, stdin), p1 == p2) ? EOF : *p1++;
      }
    #endif
      // Reads the next non-whitespace character into c.
      io_in& operator>>(char& c) {
        for (c = getchar(); isspace(c); c = getchar())
          ;
    
        return *this;
      }
      // Reads a whitespace-delimited token into s; s is left empty at EOF.
      io_in& operator>>(string& s) {
        for (s.clear(); isspace(ch = getchar());)
          ;
    
        if (!~ch) return *this;
    
        for (s = ch; !isspace(ch = getchar()) && ~ch; s += ch)
          ;
    
        return *this;
      }
    
      // Reads a whitespace-delimited token into str. The old contents are
      // zeroed up to their first NUL first. There is no bounds check: the
      // caller must supply a buffer large enough for the token.
      io_in& operator>>(char* str) {
        char* cur = str;
        while (*cur) *cur++ = 0;
    
        for (cur = str; isspace(ch = getchar());)
          ;
        if (!~ch) return *this;
    
        for (*cur = ch; !isspace(ch = getchar()) && ~ch; *++cur = ch)
          ;
    
        return *++cur = 0, *this;
      }
    
      // Parses an integer into x:
      // - characters below '0' are skipped, and each '-' among them toggles the sign;
      // - digits are then accumulated while the character code is above '/';
      // - x is 0 at EOF.
      // The terminating character is left in ch.
      template <class T>
    
      void read(T& x) {
        bool f = 0;
        while ((ch = getchar()) < 48 && ~ch) f ^= (ch == 45);
    
        x = ~ch ? (ch ^ 48) : 0;
        while ((ch = getchar()) > 47) x = x * 10 + (ch ^ 48);
        x = f ? -x : x;
      }
    
      io_in& operator>>(int& x) { return read(x), *this; }
    
      io_in& operator>>(ll& x) { return read(x), *this; }
    
      io_in& operator>>(uint& x) { return read(x), *this; }
    
      io_in& operator>>(ull& x) { return read(x), *this; }
    
      // Reads a decimal number with an optional fractional part.
      io_in& operator>>(db& x) {
        read(x);
        // No fractional part: read() has already applied the sign. (Previously
        // the sign was stripped before this return, so "-7" read as 7.)
        if (ch ^ '.') return *this;
    
        // read() yields -0.0 for inputs such as "-0.5", so test the sign bit;
        // `x < 0` is false for -0.0 and would lose the sign.
        bool f = std::signbit(x);
        x = f ? -x : x;
    
        double d = 0.1;
        while ((ch = getchar()) > 47) x += d * (ch ^ 48), d *= .1;
        return x = f ? -x : x, *this;
      }
    } in;
    
    struct io_out {
      char buf[io_size], *s = buf;
      int pw[233], st[233];
    
      io_out() {
        set(7);
        rep(i, pw[0] = 1, 9) pw[i] = pw[i - 1] * 10;
      }
    
      ~io_out() { flush(); }
    
      void io_chk() {
        if (s - buf > io_limit) flush();
      }
    
      void flush() { fwrite(buf, 1, s - buf, stdout), fflush(stdout), s = buf; }
    
      io_out& operator<<(char c) { return *s++ = c, *this; }
    
      io_out& operator<<(string str) {
        for (char c : str) *s++ = c;
        return io_chk(), *this;
      }
    
      io_out& operator<<(char* str) {
        char* cur = str;
        while (*cur) *s++ = *cur++;
        return io_chk(), *this;
      }
    
      template <class T>
    
      void write(T x) {
        if (x < 0) *s++ = '-', x = -x;
    
        do {
          st[++st[0]] = x % 10, x /= 10;
        } while (x);
    
        while (st[0]) *s++ = st[st[0]--] ^ 48;
      }
    
      io_out& operator<<(int x) { return write(x), io_chk(), *this; }
    
      io_out& operator<<(ll x) { return write(x), io_chk(), *this; }
    
      io_out& operator<<(uint x) { return write(x), io_chk(), *this; }
    
      io_out& operator<<(ull x) { return write(x), io_chk(), *this; }
    
      int len, lft, rig;
    
      void set(int _length) { len = _length; }
    
      io_out& operator<<(db x) {
        bool f = x < 0;
        x = f ? -x : x, lft = x, rig = 1. * (x - lft) * pw[len];
        return write(f ? -lft : lft), *s++ = '.', write(rig), io_chk(), *this;
      }
    } out;
    
    // Array bound shared by all per-vertex tables. The Euler tour in dfs uses
    // 2n - 1 slots of st/lg, so n must stay below roughly maxn / 2.
    const int maxn = 200000 + 52;
    
    struct smt {
      int ls[maxn << 7], rs[maxn << 7], val[maxn << 7];  // node pool: left/right child, sum of values
      int rt[maxn], frt[maxn], cnt;
      // rt[i]: segment tree keyed by distance to vertex i, storing summed weights.
      // frt[i]: segment tree keyed by distance to fa[i], storing summed weights.
      // Index 0 acts as the null node; cnt is the last allocated node index.
    
      smt() : cnt(0) {}
    
      // Adds v at key x in the tree rooted at p, which covers [l, r].
      // Missing nodes are allocated on the way down, and p is written back
      // when it was 0.
      void upd(int& p, int l, int r, int x, int v) {
        int* slot = &p;
        for (;;) {
          if (!*slot) *slot = ++cnt;
          const int node = *slot;
          val[node] += v;
          if (l == r) return;
          const int mid = (l + r) >> 1;
          if (x <= mid) {
            slot = &ls[node];
            r = mid;
          } else {
            slot = &rs[node];
            l = mid + 1;
          }
        }
      }
    
      // Returns the sum of values with keys in [a, b] in the tree p, which
      // covers [l, r]. Absent subtrees contribute 0.
      int qry(int p, int a, int b, int l, int r) {
        if (!p) return 0;
        if (a <= l && r <= b) return val[p];
        const int mid = (l + r) >> 1;
        const int from_left = a <= mid ? qry(ls[p], a, b, l, mid) : 0;
        const int from_right = b > mid ? qry(rs[p], a, b, mid + 1, r) : 0;
        return from_left + from_right;
      }
    } smt;
    
    vector<int> g[maxn];  // adjacency lists of the input tree
    pii st[maxn][22];     // Euler tour of (depth, vertex); column 0 filled here, higher columns in main
    int dfn[maxn], idx = 0, dep[maxn];
    
    // Euler-tour DFS rooted at u.
    // - dfn[u] gets u's first tour position.
    // - (dep[u], u) is appended on entry and again after each child returns.
    // - dep[v] = dep[u] + 1 is set for every child v.
    void dfs(int u, int parent) {
      dfn[u] = ++idx;
      st[idx][0] = pii(dep[u], u);
      for (int v : g[u]) {
        if (v == parent) continue;
        dep[v] = dep[u] + 1;
        dfs(v, u);
        ++idx;
        st[idx][0] = pii(dep[u], u);
      }
    }
    
    int mx[maxn], sz[maxn], vis[maxn], tot, rt;
    
    // Centroid search over the unvisited component reached from u, where from
    // is the vertex we came from.
    // - sz[] gets subtree sizes for this traversal.
    // - mx[u] gets the largest piece left if u were removed (tot is the
    //   component size).
    // - rt moves to u whenever mx[u] < mx[rt].
    void getroot(int u, int from) {
      sz[u] = 1;
      mx[u] = 0;
      for (int v : g[u]) {
        if (vis[v] || v == from) continue;
        getroot(v, u);
        sz[u] += sz[v];
        mx[u] = std::max(mx[u], sz[v]);
      }
      mx[u] = std::max(mx[u], tot - sz[u]);
      if (mx[u] < mx[rt]) rt = u;
    }
    
    int d[maxn], mxd;
    
    // Walks the unvisited part of the tree reachable from u, entered from
    // `from`, setting d[u] = d[from] + 1. When called with from = 0 the
    // start vertex gets d = 1, since d[0] is never written. mxd tracks the
    // largest d seen.
    void getdis(int u, int from) {
      d[u] = d[from] + 1;
      if (mxd < d[u]) mxd = d[u];
      for (int v : g[u]) {
        if (vis[v] || v == from) continue;
        getdis(v, u);
      }
    }
    
    int qwq[maxn], fa[maxn];  // qwq[u]: depth bound of u's component; fa[u]: parent in the point-divide tree
    
    // Processes centroid u:
    // - marks u visited and stores in qwq[u] the largest getdis depth of its
    //   component (u itself counts as 1);
    // - for every unvisited neighbour, finds the centroid of that piece,
    //   links it under u via fa[], and recurses.
    // NOTE(review): tot = sz[v] reuses sizes from the getroot pass that chose
    // u. When v was u's parent in that pass, sz[v] overstates the piece. tot
    // only feeds getroot's balance test, so this affects which centroid is
    // picked, not correctness.
    void solve(int u) {
      vis[u] = 1;
      mxd = 0;
      getdis(u, 0);
      qwq[u] = mxd;
      for (int v : g[u]) {
        if (vis[v]) continue;
        tot = sz[v];
        rt = 0;
        getroot(v, 0);
        fa[rt] = u;
        solve(rt);
      }
    }
    
    int lg[maxn];  // lg[i] = floor(log2(i)); filled in main
    
    // Lowest common ancestor of x and y: the shallowest (depth, vertex) entry
    // on the Euler-tour range [dfn[x], dfn[y]], looked up in the sparse table st.
    int lca(int x, int y) {
      x = dfn[x], y = dfn[y];
      // Plain swap: the former `x ^= y ^= x ^= y` modifies x twice without
      // sequencing, which is undefined behaviour before C++17.
      if (x > y) std::swap(x, y);
      int len = lg[y - x + 1];
      return min(st[x][len], st[y - (1 << len) + 1][len]).second;
    }
    // Number of edges on the tree path between x and y.
    int dis(int x, int y) { return dep[x] + dep[y] - (dep[lca(x, y)] << 1); }
    
    int n, m;
    int val[maxn];
    
    // Adds v to the weight of vertex x by walking x's ancestors in the
    // point-divide tree. At each ancestor `now`:
    // - the delta is recorded in rt[now] at key dis(now, x);
    // - if now has a parent, it is also recorded in frt[now] at key
    //   dis(fa[now], x).
    void change(int x, int v) {
      for (int now = x; now; now = fa[now]) {
        smt.upd(smt.rt[now], 0, qwq[now], dis(now, x), v);
        if (!fa[now]) continue;
        smt.upd(smt.frt[now], 0, qwq[fa[now]], dis(fa[now], x), v);
      }
    }
    
    // Sums the weights recorded for vertices within distance k of x, walking
    // x's ancestors in the point-divide tree.
    int qry(int x, int k) {
      int ans = 0;
      for (int now = x; now; now = fa[now]) {
        const int to_now = dis(now, x);
        if (to_now <= k) {
          // Add only the points at most k - dis(now, x) away from now.
          ans += smt.qry(smt.rt[now], 0, k - to_now, 0, qwq[now]);
        }
        if (!fa[now]) continue;
        const int to_parent = dis(fa[now], x);
        if (to_parent <= k) {
          // Inclusion-exclusion: subtract the points recorded through now that
          // lie at most k - dis(fa[now], x) from fa[now].
          ans -= smt.qry(smt.frt[now], 0, k - to_parent, 0, qwq[fa[now]]);
        }
      }
      return ans;
    }
    
    signed main() {
      // code begin.
      in >> n >> m;
      rep(i, 1, n) { in >> val[i]; }
      rep(i, 2, n) {
        int u, v;
        in >> u >> v;
        g[u].pb(v), g[v].pb(u);
      }
      // Euler tour plus sparse table for O(1) LCA / distance queries.
      dfs(1, 0);
      rep(i, 2, idx) lg[i] = lg[i >> 1] + 1;
      rep(j, 1, lg[idx]) {
        rep(i, 1, idx - (1 << j) + 1) st[i][j] = min(st[i][j - 1], st[i + (1 << (j - 1))][j - 1]);
      }
      // Build the point-divide tree. mx[0] = 1e9 makes vertex 0 the initial
      // worst candidate in getroot.
      mx[rt = 0] = 1e9, tot = n;
      getroot(1, 0), solve(rt);
      // Insert the initial weights. This reuses change() instead of
      // duplicating its loop.
      rep(i, 1, n) change(i, val[i]);
      int ans = 0;
      while (m--) {
        int opt, x, y;
        in >> opt >> x >> y;
        x ^= ans, y ^= ans;  // inputs are XOR-encoded with the previous answer
        if (!opt) {
          // The newline literal was split across two source lines, which
          // does not compile.
          out << (ans = qry(x, y)) << '\n';
        } else {
          change(x, y - val[x]), val[x] = y;
        }
      }
      return 0;
      // code end.
    }
    
  • 相关阅读:
    Flutter第一个应用--踩坑之路
    今天注册博客园了!
    广深小龙-基于unittest、pytest自动化测试框架之demo来学习啦!!!
    python接口自动化10-excel设计模式实战
    python接口自动化9-ddt数据驱动
    Docker学习4-学会如何让容器开机自启服务【坑】
    pytest-4-分布式运行与自定义顺序执行用例
    Docker学习10-docker-slenium进行web自动化测试
    linux+jenkins生成测试报告及任意IP打开链接能看到allure报告
    MySQL-Python实现-测试/生产环境各个表与字段进行对比的小工具
  • 原文地址:https://www.cnblogs.com/Isaunoya/p/12670817.html
Copyright © 2011-2022 走看看