zoukankan      html  css  js  c++  java
  • 51nod1819 黑白树V2

    简单的题面

    给定一棵以1为根的有根树,点可能是黑色或白色,操作如下。


    1. 选定一个点x,将x的子树中所有到x的距离为奇数的点的颜色反转。
    2. 选定一个点x,将点x的颜色反转。
    3. 选定一个点x,询问所有黑点y(包括点x)与点x的lca(最近公共祖先)的和。

    果然自己码一码收获挺大的....

    首先考虑怎么回答3操作,不妨考虑枚举$lca$

    如果$lca = x$,可以发现,在$x$子树内的答案都是$x$

    否则,根节点到$x$形成了一条链,令其为$1 o x_1 o x_2 ...... o x$

    可以发现,$x_i$对答案的贡献为$(sz[x_i] -  sz[x_{i + 1}]) * x_i$

    由于答案分布在一条链上,考虑使用轻重链剖分

    考虑到2操作和1操作

    用线段树动态的维护

    $sz[i][2][2], col[i]$

    分别表示

    1.$i$节点子树中,深度为奇 / 偶数,颜色为 黑 / 白的节点数

    2.$i$节点的颜色

    以及

    $sum[2][2]$表示区间内所有$g[i][2][2]$的和

    其中$g[i][x][y] = sz[i][x][y] * i$

    还有一大堆的细节,直接看代码吧,无法用言语来描述

    $51nod ;rank1$,$hhh$

    还是有$8kb$......应该还能短点的

    #include <cstdio>
    #include <cstring>
    #include <iostream>
    #include <algorithm>
    using namespace std;
    
    extern inline char gc() {
        static char RR[23456], *S = RR + 23333, *T = RR + 23333;
        if(S == T) fread(RR, 1, 23333, stdin), S = RR;
        return *S ++;
    }
    inline int read() {
        int p = 0, w = 1; char c = gc();
        while(c > '9' || c < '0') { if(c == '-') w = -1; c = gc(); }
        while(c >= '0' && c <= '9') p = p * 10 + c - '0', c = gc();
        return p * w;
    }
    
    int wr[50], rw;
    char WR[40000005], *I = WR;
    #define pc(z) *I ++ = (z)
    template <typename re>
    inline void write(re x) {
        if(!x) pc('0');
        if(x < 0) pc('-'), x = -x;
        while(x) wr[++ rw] = x % 10, x /= 10;
        while(rw) pc(wr[rw --] + '0'); pc('
    ');
    }
    
    #define fe float
    #define de double
    #define le long double
    #define ll long long
    #define ui unsigned int
    #define ri register int
    #define ull unsigned long long
    #define sid 200050
    #define eid 400050
    
    int n, m, cnp, did;
    int dfn[sid], ord[sid], anc[sid], col[sid];
    int cap[sid], node[eid], nxt[eid];
    int sz[sid], dep[sid], pre[sid], fa[sid];
    
    inline void adeg(int u, int v) {
        nxt[++ cnp] = cap[u]; cap[u] = cnp; node[cnp] = v;
    }
    
    #define cur node[i]
    inline void dfs(int o) {
        sz[o] = 1;
        for(int i = cap[o]; i; i = nxt[i])
        if(cur != fa[o]) {
            fa[cur] = o; dep[cur] = dep[o] + 1;
            dfs(cur); sz[o] += sz[cur]; 
            if(sz[pre[o]] < sz[cur]) pre[o] = cur;
        }
    }
    
    inline void dfs(int o, int tp) {
        anc[o] = tp; dfn[++ did] = o; ord[o] = did;
        if(pre[o]) dfs(pre[o], tp); else return;
        for(int i = cap[o]; i; i = nxt[i])
        if(cur != fa[o] && cur != pre[o]) dfs(cur, cur);
    }
    
    int f[sid][2][2], g[sid][2][2];
    inline void dp(int o) {
        f[o][dep[o] & 1][col[o]] = 1;
        for(int i = cap[o]; i; i = nxt[i])
        if(cur != fa[o]) {
            dp(cur);
            for(ri d = 0; d <= 1; d ++)
            for(ri c = 0; c <= 1; c ++)
            f[o][d][c] += f[cur][d][c];
        }
        for(ri d = 0; d <= 1; d ++)
        for(ri c = 0; c <= 1; c ++)
        g[o][d][c] = f[o][d][c] - f[pre[o]][d][c];
    }
    
    struct Seg {
        ll s[2][2];
        int rev[2], mas[2][2];
    } t[sid * 4];
    
    #define ls (o << 1)
    #define rs (o << 1 | 1)
    
    inline void update(int o) {
        for(ri i = 0; i <= 1; i ++)
        for(ri j = 0; j <= 1; j ++)
        t[o].s[i][j] = t[ls].s[i][j] + t[rs].s[i][j];
    }
    
    inline void build(int o, int l, int r) {
        if(l == r) {
            int x = dfn[l];
            for(ri i = 0; i <= 1; i ++)
            for(ri j = 0; j <= 1; j ++)
            t[o].s[i][j] = 1ll * g[x][i][j] * x;
            return;
        }
        int mid = (l + r) >> 1;
        build(ls, l, mid); build(rs, mid + 1, r);
        update(o);
    }
    
    inline void prev(int o, int d, int l, int r) {
        if(l == r) {
            int x = dfn[l];
            if((dep[x] & 1) == d) col[x] ^= 1;
            swap(f[x][d][0], f[x][d][1]);
        }
        swap(t[o].s[d][0], t[o].s[d][1]);
        swap(t[o].mas[d][0], t[o].mas[d][1]);
        t[o].rev[d] ^= 1;
    }
    
    inline void premas(int o, int d, int c, int v, int l, int r) {
        if(l == r) {
            int x = dfn[l];
            f[x][d][c] -= v; f[x][d][c ^ 1] += v;
        }
        t[o].mas[d][c] += v;
    }
    
    inline void pushdown(int o, int l, int r) {
        int mid = (l + r) >> 1;
        for(ri i = 0; i <= 1; i ++)
        if(t[o].rev[i]) {
            t[o].rev[i] = 0;
            prev(ls, i, l, mid); 
            prev(rs, i, mid + 1, r);
        }
        for(ri i = 0; i <= 1; i ++)
        for(ri j = 0; j <= 1; j ++)
        if(t[o].mas[i][j] != 0) {
            premas(ls, i, j, t[o].mas[i][j], l, mid);
            premas(rs, i, j, t[o].mas[i][j], mid + 1, r);
            t[o].mas[i][j] = 0;
        }
    }
    
    inline void rev(int o, int l, int r, int ml, int mr, int d) {
        if(ml > mr) return;
        if(ml > r || mr < l) return;
        if(ml <= l && mr >= r) { prev(o, d, l, r); return; }
        int mid = (l + r) >> 1;
        pushdown(o, l, r); 
        rev(ls, l, mid, ml, mr, d);
        rev(rs, mid + 1, r, ml, mr, d);
        update(o);
    }
    
    inline void mas(int o, int l, int r, int ml, int mr, int d, int c, int v) {
        if(ml > mr) return;
        if(ml > r || mr < l) return;
        if(ml <= l && mr >= r) { premas(o, d, c, v, l, r); return; }
        int mid = (l + r) >> 1;
        pushdown(o, l, r);
        mas(ls, l, mid, ml, mr, d, c, v);
        mas(rs, mid + 1, r, ml, mr, d, c, v);
        update(o);
    }
    
    inline void mis(int o, int l, int r, int p, int d, int c, int v) {
        if(l == r) { 
            f[p][d][c ^ 1] += v; f[p][d][c] -= v;
            t[o].s[d][c ^ 1] += 1ll * v * p; t[o].s[d][c] -= 1ll * v * p;
            return; 
        }
        int mid = (l + r) >> 1;
        pushdown(o, l, r);
        if(ord[p] <= mid) mis(ls, l, mid, p, d, c, v);
        else mis(rs, mid + 1, r, p, d, c, v);
        update(o);
    }
    
    inline int qc(int o, int l, int r, int v) {
        if(l == r) return col[v];
        int mid = (l + r) >> 1;
        pushdown(o, l, r);
        if(ord[v] <= mid) return qc(ls, l, mid, v);
        else return qc(rs, mid + 1, r, v);
    }
    
    inline int dsz(int o, int l, int r, int v, int d) {
        if(l == r) return f[v][d][1] - f[v][d][0];
        int mid = (l + r) >> 1;
        pushdown(o, l, r);
        if(ord[v] <= mid) return dsz(ls, l, mid, v, d);
        else return dsz(rs, mid + 1, r, v, d);
    }
    
    inline int qsz(int o, int l, int r, int v) {
        if(l == r) return f[v][0][1] + f[v][1][1];
        int mid = (l + r) >> 1;
        pushdown(o, l, r);
        if(ord[v] <= mid) return qsz(ls, l, mid, v);
        else return qsz(rs, mid + 1, r, v);
    }
    
    inline ll qs(int o, int l, int r, int ml, int mr) {
        if(ml > mr) return 0;
        if(ml > r || mr < l) return 0;
        if(ml <= l && mr >= r) return t[o].s[0][1] + t[o].s[1][1];
        int mid = (l + r) >> 1;
        pushdown(o, l, r);
        return qs(ls, l, mid, ml, mr) + qs(rs, mid + 1, r, ml, mr);
    }
    
    inline void change(int x) { 
        int d = (dep[x] + 1) & 1, ff = anc[x];
        int der = dsz(1, 1, n, x, d);
        rev(1, 1, n, ord[x], ord[x] + sz[x] - 1, d);
        mas(1, 1, n, ord[ff], ord[x] - 1, d, 1, der);
        for(ri i = anc[fa[ff]], j = fa[ff]; j; j = fa[i], i = anc[j])
        mis(1, 1, n, j, d, 1, der), mas(1, 1, n, ord[i], ord[j] - 1, d, 1, der);
    }
    
    inline void put(int x) {
        int ff = anc[x];
        col[x] = qc(1, 1, n, x);
        int d = dep[x] & 1, c = col[x];
        mis(1, 1, n, x, d, c, 1);
        mas(1, 1, n, ord[ff], ord[x] - 1, d, c, 1);
        for(ri i = anc[fa[ff]], j = fa[ff]; j; j = fa[i], i = anc[j])
        mis(1, 1, n, j, d, c, 1), mas(1, 1, n, ord[i], ord[j] - 1, d, c, 1);
        col[x] ^= 1;
    }
    
    inline ll query(int x) {
        ll ans = 0;
        int f = anc[x];
        ans += 1ll * qsz(1, 1, n, x) * x;
        for(ri i = f, j = x, o = x; j; j = fa[i], o = i, i = anc[j]) {
            if(j != o) ans += 1ll * (qsz(1, 1, n, j) - qsz(1, 1, n, o)) * j;
            ans += qs(1, 1, n, ord[i], ord[j] - 1);
        }
        return ans;
    }
    
    int main() {
        n = read(); m = read();
        for(ri i = 1; i <= n; i ++) col[i] = read();
        for(ri i = 1; i < n; i ++) {
            int u = read(), v = read();
            adeg(u, v); adeg(v, u);
        } 
        dfs(1); dfs(1, 1); 
        dp(1); build(1, 1, n);
        for(ri i = 1; i <= m; i ++) {
            int opt = read(), x = read();
            if(opt == 1) change(x);
            if(opt == 2) put(x);
            if(opt == 3) write(query(x));
        }
        fwrite(WR, 1, I - WR, stdout);
        return 0;
    }
  • 相关阅读:
    知道这几 个正则表达式,能让你少写 1,000 行代码
    移除手机端a标签点击自动出现的边框和背景
    CSS 元素垂直居中的 6种方法
    当文本超出时出现省略号
    css清除select的下拉箭头样式
    设置透明边框
    js 输出语句document.write()及动态改变元素中内容innerHTML的使用
    LOCAL_EXPORT_××用法
    sprd测试系统跑vts
    C++ const用法
  • 原文地址:https://www.cnblogs.com/reverymoon/p/9486271.html
Copyright © 2011-2022 走看看