zoukankan      html  css  js  c++  java
  • 模板

    https://www.luogu.org/problem/P3384

    如题,已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:
    操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z
    操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和
    操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z
    操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和

    小心不要抄错。这个常数比较小。

    #include<bits/stdc++.h>
    #define lc (o<<1)
    #define rc (o<<1|1)
    typedef long long ll;
    using namespace std;
    
    const int MAXN = 100000 + 5;
    int dep[MAXN], siz[MAXN],  son[MAXN], fa[MAXN], top[MAXN], tid[MAXN], rnk[MAXN], cnt;
    
    int n, m, r, mod;
    int a[MAXN];
    
    int head[MAXN], etop;
    
    struct Edge {
        int v, next;
    } e[MAXN * 2];
    
    inline void init(int n) {
        etop = 0;
        memset(head, -1, sizeof(head[0]) * (n + 1));
    }
    
    inline void addedge(int u, int v) {
        e[++etop].v = v;
        e[etop].next = head[u];
        head[u] = etop;
        e[++etop].v = u;
        e[etop].next = head[v];
        head[v] = etop;
    }
    
    struct SegmentTree {
        int sum[MAXN * 4], lz[MAXN * 4];
        void pushup(int o) {
            sum[o] = (sum[lc] + sum[rc]) % mod;
        }
        void pushdown(int o, int l, int r) {
            if(lz[o]) {
                lz[lc] = (lz[lc] + lz[o]) % mod;
                lz[rc] = (lz[rc] + lz[o]) % mod;
                int m = l + r >> 1;
                sum[lc] = (1ll * lz[o] * (m - l + 1) + sum[lc]) % mod;
                sum[rc] = (1ll * lz[o] * (r - m) + sum[rc]) % mod;
                lz[o] = 0;
            }
        }
    
        void build(int o, int l, int r) {
            if (l == r)
                sum[o] = a[rnk[l]] % mod;
            else {
                int m = (l + r) >> 1;
                build(lc, l, m);
                build(rc, m + 1, r);
                pushup(o);
            }
            lz[o] = 0;
        }
    
        void update(int o, int l, int r, int ql, int qr, int v) {
            if (ql <= l && r <= qr) {
                lz[o] = (lz[o] + v) % mod;
                sum[o] = (sum[o] + 1ll * v * (r - l + 1)) % mod;
            } else {
                pushdown(o, l, r);
                int m = (l + r) >> 1;
                if (ql <= m)
                    update(lc, l, m, ql, qr, v);
                if (qr >= m + 1)
                    update(rc, m + 1, r, ql, qr, v);
                pushup(o);
            }
        }
    
        int query(int o, int l, int r, int ql, int qr) {
            if (ql <= l && r <= qr) {
                return sum[o];
            } else {
                pushdown(o, l, r);
                int m = (l + r) >> 1;
                int res = 0;
                if (ql <= m)
                    res += query(lc, l, m, ql, qr);
                if (qr >= m + 1)
                    res += query(rc, m + 1, r, ql, qr);
                return res % mod;
            }
        }
    } st;
    
    void init1() {
        dep[r] = 1;
    }
    
    void dfs1(int u, int t) {
        siz[u] = 1, son[u] = -1, fa[u] = t;
        for (int i = head[u]; i != -1; i = e[i].next) {
            int v = e[i].v;
            if(v == t)
                continue;
            dep[v] = dep[u] + 1;
            dfs1(v, u);
            siz[u] += siz[v];
            if (son[u] == -1 || siz[v] > siz[son[u]])
                son[u] = v;
        }
    }
    
    void init2() {
        cnt = 0;
    }
    
    void dfs2(int u, int t) {
        top[u] = t;
        tid[u] = ++cnt;
        rnk[cnt] = u;
        if (son[u] == -1)
            return;
        dfs2(son[u], t);
        for (int i = head[u]; i != -1; i = e[i].next) {
            int v = e[i].v;
            if(v==fa[u]||v==son[u])
                continue;
            dfs2(v, v);
        }
    }
    
    int query1(int u, int v) {
        ll ret = 0;
        int tu = top[u], tv = top[v];
        while (tu != tv) {
            if (dep[tu] >= dep[tv]) {
                ret += st.query(1, 1, n, tid[tu], tid[u]);
                u = fa[tu];
                tu = top[u];
            } else {
                ret += st.query(1, 1, n, tid[tv], tid[v]);
                v = fa[tv];
                tv = top[v];
            }
        }
        if(tid[u] <= tid[v])
            ret += st.query(1, 1, n, tid[u], tid[v]);
        else
            ret += st.query(1, 1, n, tid[v], tid[u]);
        return ret % mod;
    }
    
    inline int query2(int u) {
        return st.query(1, 1, n, tid[u], tid[u] + siz[u] - 1);
    }
    
    inline void update1(int u, int v, int val) {
        val %= mod;
        int tu = top[u], tv = top[v];
        while (tu != tv) {
            if (dep[tu] >= dep[tv]) {
                st.update(1, 1, n, tid[tu], tid[u], val);
                u = fa[tu];
                tu = top[u];
            } else {
                st.update(1, 1, n, tid[tv], tid[v], val);
                v = fa[tv];
                tv = top[v];
            }
        }
        if(tid[u] <= tid[v])
            st.update(1, 1, n, tid[u], tid[v], val);
        else
            st.update(1, 1, n, tid[v], tid[u], val);
    }
    
    inline void update2(int u, int val) {
        val %= mod;
        st.update(1, 1, n, tid[u], tid[u] + siz[u] - 1, val);
    }
    
    void op1() {
        int u, v, val;
        scanf("%d%d%d", &u, &v, &val);
        update1(u, v, val);
    }
    
    void op2() {
        int u, v;
        scanf("%d%d", &u, &v);
        printf("%d
    ", query1(u, v) % mod);
    }
    
    void op3() {
        int u, val;
        scanf("%d%d", &u, &val);
        update2(u, val);
    }
    
    void op4() {
        int u;
        scanf("%d", &u);
        printf("%d
    ", query2(u) % mod);
    }
    
    int main() {
    #ifdef Yinku
        freopen("Yinku.in", "r", stdin);
    #endif // Yinku
        scanf("%d%d%d%d", &n, &m, &r, &mod);
        for(int i = 1; i <= n; ++i) {
            scanf("%d", &a[i]);
        }
        init(n);
        for(int i = 1, u, v; i <= n - 1; ++i) {
            scanf("%d%d", &u, &v);
            addedge(u, v);
        }
        init1();
        dfs1(r, -1);
        init2();
        dfs2(r, r);
        st.build(1, 1, n);
        for(int i = 1, op; i <= m; ++i) {
            scanf("%d", &op);
            switch(op) {
            case 1:
                op1();
                break;
            case 2:
                op2();
                break;
            case 3:
                op3();
                break;
            case 4:
                op4();
                break;
            }
        }
        return 0;
    }
    
  • 相关阅读:
    IP掩码的作用
    linux shell 笔记
    ubuntu apt-get Failed to fetch Temporary failure resolving 'security.ubuntu.com'
    ubuntu 16.04 & 18.04 远程桌面使用
    取消Ubuntu开机硬盘自检
    linux shell 脚本输入参数解析
    Ubuntu 16.04 + python3 源码 安装+使用labelImg最新版
    用tinyxml2读写xml文件_C++实现
    常用工具问题及解决方案
    可视化调试工具
  • 原文地址:https://www.cnblogs.com/Yinku/p/11309586.html
Copyright © 2011-2022 走看看