zoukankan      html  css  js  c++  java
  • 2020牛客暑期多校(七) C

    2020牛客暑期多校(七) C - A National Pandemic(树链剖分)

    参考博客

    题意:

    一棵树支持3种操作:

    • 1 x w, 给x点加w,其它点y加 (w-dist(x, y)).
    • 2 x, 将x权值变为$min(0, f(x)) $;
    • 3 x, 查询x的权值(f(x))

    分析:

    先推荐一个题单: 树链剖分练习题 如果没有学过树链剖分可以做一下。

    首先2, 3操作用树链剖分处理都很直接,主要看1操作。给一个点x加w也还好处理,但给其他点加(w - dist(x, y)) 怎么加,难道要枚举点吗?显然点那么多会T。所以要处理这个操作可以观察这个式子可以写成 (w - dist(1, x) - dist(1, y) + 2*dist(1, lca(x,y))) 理解见下图,紫色是dist(1,x),绿色是dist(1, y) ,黄色是dist(1, lca(x,y))。

    (w - dist(1, x) - dist(1, y) + 2*dist(1, lca(x,y))​)

    观察式子可以看到(w-dist(1,x))(dist(1,y)) 都可以用变量去累计,因为对一个查询3操作,它前面的1操作时的(w-dist(1,x)) 你可以累计下来,然后减(dist(1, y)) 的个数就是前面1操作的个数,也可以用一个变量allnum记录树量。

    所以重点在处理(dist(1,lca(x,y))) 我们发现当查询一个点y时只要找到1到 y 路径上所以以前1操作标记的$lca(x,y) $点 ,求和这些点到 1 的距离即可,但这很麻烦不好处理。但是它是lca点到1的距离,所以我们可以在1处理时对1到x每个点权值+1,比如上图中处理x时,我把紫线上所有点+1,那么当处理2时我想要加的是1到lca的距离,可以发现此时1到lca的权值和就是1到lca的距离,这里用了差分的一个思想。当我们有很多x时,它们会在1到y条路径上1到某个点之间权值都加1,其实这个点就是lca,这个很好理解。所以我们只要用线段树维护权值和即可。但我们观察式子要2*dist(1,lca(x,y)).这只需要对每个1操作的x给线段树1到x之间的点+2即可。

    代码:

    #include<bits/stdc++.h>
    using namespace std;
    #define rep(i, a, n) for(int i = a; i <= n; ++ i);
    #define per(i, a, n) for(int i = n; i >= a; -- i);
    typedef long long ll;
    const int N = 50010;
    const ll mod = 1e9 + 7;
    const double Pi = acos(- 1.0);
    const int INF = 0x3f3f3f3f;
    const int G = 3, Gi = 332748118;
    ll qpow(ll a, ll b) { ll res = 1; while(b){ if(b & 1) res = (res * a) % mod; a = (a * a) % mod; b >>= 1;} return res; }
    ll gcd(ll a, ll b) { return b ? gcd(b, a % b) : a; }
    ll lcm(ll a, ll b) { return a * b / gcd(a, b);}
    bool cmp(int a, int b){ return a > b;}
    //
    
    int T, n, m;
    int head[N << 1], cnt = 0;
    struct node{
        int to, nxt;
    }edge[N * 4];
    
    struct Tree{
        int l, r; int val, lz;
    }tree[N * 4];
    int del[N]; 
    int son[N], dfn[N], dep[N], top[N], fa[N], siz[N];
    int tot;
    
    void add(int u, int v){
        edge[cnt].to = v, edge[cnt].nxt = head[u], head[u] = cnt ++;
        edge[cnt].to = u, edge[cnt].nxt = head[v], head[v] = cnt ++;
    }
    
    void pushdown(int index){
        if(tree[index].lz){
            int temp = tree[index].lz;
            tree[index].lz = 0;
            tree[index << 1].val += (tree[index << 1].r - tree[index << 1].l + 1) * temp;
            tree[index << 1 | 1].val += (tree[index << 1 | 1].r - tree[index << 1 | 1].l + 1) * temp;
            tree[index << 1].lz += temp;
            tree[index << 1 | 1].lz += temp;
        }
    }
    
    void Build(int l, int r, int index){
        tree[index].l = l, tree[index].r = r;
        tree[index].lz = 0;
        if(l == r){
            tree[index].val = 0;
            return;
        }
        int mid = (tree[index].l + tree[index].r) >> 1;
        Build(l, mid, index << 1);
        Build(mid + 1, r, index << 1 | 1);
        tree[index].val = tree[index << 1].val + tree[index << 1 | 1].val;
    }
    
    void updata(int l, int r, int index, int val){
        if(tree[index].l >= l &&  tree[index].r <= r){
            tree[index].lz += val;
            tree[index].val += val * (tree[index].r - tree[index].l + 1);
            return;
        }
        if(tree[index].lz)  pushdown(index);
        int mid = (tree[index].l + tree[index].r) >> 1;
        if(l <= mid) updata(l, r, index << 1, val);
        if(r > mid) updata(l, r, index << 1 | 1, val);
        tree[index].val = tree[index << 1].val + tree[index << 1 | 1].val;
    }
    
    int query(int l, int r, int index){
        if(l <= tree[index].l && tree[index].r <= r){
            return tree[index].val;
        }
        if(tree[index].lz) pushdown(index);
        int mid = (tree[index].l + tree[index].r) >> 1;
        int ans = 0;
        if(l <= mid) ans += query(l, r, index << 1);
        if(r > mid) ans += query(l, r, index << 1 | 1);
        return ans;
    }
    // -------------------------------------
    
    void Csol(int x, int y){
        while(top[x] != top[y]){
            if(dep[top[x]] < dep[top[y]]) swap(x, y);
            updata(dfn[top[x]], dfn[x], 1, 2);
            x = fa[top[x]];
        }
        if(dep[x] > dep[y]) swap(x, y);
        updata(dfn[x], dfn[y], 1, 2);
    }
    
    int Qsol(int x, int y){
        int ans = 0;
        while(top[x] != top[y]){
            if(dep[top[x]] < dep[top[y]]) swap(x, y);
            ans += query(dfn[top[x]], dfn[x], 1);
            x = fa[top[x]];
        }
        if(dep[x] > dep[y]) swap(x ,y);
        ans += query(dfn[x], dfn[y], 1);
        return ans;
    }
    
    void dfs1(int u, int pre){
        dep[u] = dep[pre] + 1;
        fa[u] = pre;
        siz[u] = 1;
        int maxx = -1;
        for(int i = head[u]; i != -1; i = edge[i].nxt){
            int v = edge[i].to;
            if(v == pre) continue;
            dfs1(v, u);
            siz[u] += siz[v];
            if(siz[v] > maxx){
                maxx = siz[v];
                son[u] = v;
            }
        }
    }
    
    void dfs2(int u, int topu){ //topu当前链的最顶端的节点
        dfn[u] = ++ tot;
        top[u] = topu;
        if(!son[u]) return;
        dfs2(son[u], topu);
        for(int i = head[u]; i != -1; i = edge[i].nxt){
            int v = edge[i].to;
            if(v == son[u] || v == fa[u]) continue;
            dfs2(v, v);
        }
    }
    
    int main()
    {
        scanf("%d",&T);
        while(T --){
            scanf("%d%d",&n,&m);
            cnt = 0; tot = 0;
            for(int i = 1; i <= n; ++ i){
                head[i] = -1; del[i] = 0;
                son[i] = 0;
            }
            int x, y; 
            for(int i = 1; i < n; ++ i){
                scanf("%d%d",&x,&y);
                add(x, y);
            }
            dep[0] = 0;
            dfs1(1, 0);
            dfs2(1, 1);
            Build(1, n, 1);
            
            int op;
            int wval = 0, allnum = 0;
            while(m --){
                scanf("%d",&op);
                if(op == 1){
                    scanf("%d%d",&x,&y);
                    Csol(x, 1);
                    wval = wval + y - dep[x];
                    allnum ++;
                }
                else if(op == 2){
                    scanf("%d",&x);
                    int res = Qsol(x, 1) + wval - allnum * dep[x];
                    if(res > del[x]) del[x] = res;
                }
                else{
                    scanf("%d",&x);
                    int res = Qsol(x, 1) + wval - allnum * dep[x] - del[x];
                    printf("%d
    ",res);
                }
            }
        }
        return 0;
    }
    
  • 相关阅读:
    GitLab 介绍
    git 标签
    git 分支
    git 仓库 撤销提交 git reset and 查看本地历史操作 git reflog
    git 仓库 回退功能 git checkout
    python 并发编程 多进程 练习题
    git 命令 查看历史提交 git log
    git 命令 git diff 查看 Git 区域文件的具体改动
    POJ 2608
    POJ 2610
  • 原文地址:https://www.cnblogs.com/A-sc/p/13531750.html
Copyright © 2011-2022 走看看