zoukankan      html  css  js  c++  java
  • bzoj2243: [SDOI2011]染色(树链剖分)

    bzoj2243

    树链剖分好题啊!

    题目描述:给定一颗n个点的树,有m个操作,操作有两种。

                     1、将节点a到节点b路径上所有的点都染成颜色c。

                     2、询问节点a到节点b路径上的颜色段数量(连续的被认为是同一段)。

    输入格式:第一行包含两个整数n和m,表示节点数和操作个数。

                     第二行n个整数,表示每个节点的初始颜色。

                     接下来n - 1行,每行两个整数描述一棵树。

                     接下来m行,每行表示一个操作。

    输出格式:对于每一个询问颜色段数量的操作,输出一行一个整数,表示颜色段的数量。

    输入样例:

    6 5
    2 2 1 2 1 1
    1 2
    1 3
    2 4
    2 5
    2 6
    Q 3 5
    C 2 1 1
    Q 3 5
    C 5 1 2
    Q 3 5

    输出样例:

    3
    1
    2

    解析:很显然是树剖,问题是如何在线段树上维护不同颜色的段数。

              用sum[o]表示一段的不同颜色的段数,lf[o]表示这一段最左边的颜色,rt[o]表示这一段最右边的颜色。

              那么在进行合并时,lf[o] = lf[o << 1],rt[o] = rt[o << 1 | 1]。需要注意的是sum[o]的维护,若左端点的右端颜色等于右端点的左端颜色,则sum[o]要减1。

              即若颜色相同,sum[o] = sum[o << 1] + sum[o << 1 | 1] - 1;若颜色不同,sum[o] = sum[o << 1] + sum[o << 1 | 1]

              另一个需要思考的地方便是如何计算答案,由于在树上剖链时剖出的链是不连续的,所以不能单纯进行累加。

              这时就要用到 lca 了,可以求出两个点的lca,分别对两段进行累加,这样答案就可以计算了。

              由于在树剖时是从深度大的往深度小的剖,所以在线段树中较右的节点会先别访问到,所以可以记录一个last,表示上一次剖到的左端点颜色是last,这样就可以将答案累加。

              有很多细节需要注意,细节可以看代码。

    代码如下:

      1 #include<cstdio>
      2 #include<vector>
      3 #include<algorithm>
      4 #include<cstring>
      5 #define lc o << 1
      6 #define rc o << 1 | 1
      7 using namespace std;
      8 
      9 const int maxn = 1e5 + 5;
     10 int n, m, col[maxn], bj[maxn * 4], sum[maxn * 4], lf[maxn * 4], rt[maxn * 4], ans, last;
     11 int dep[maxn], fa[maxn], size[maxn], heavy[maxn], seq[maxn], dfn[maxn], top[maxn], cnt;
     12 char s[5];
     13 vector <int> ve[maxn];
     14 
     15 int read(void) {
     16     char c; while (c = getchar(), c < '0' || c >'9'); int x = c - '0';
     17     while (c = getchar(), c >= '0' && c <= '9') x = x * 10 + c - '0'; return x;
     18 }
     19 
     20 void dfs1(int u, int pre) {
     21     dep[u] = dep[pre] + 1;
     22     fa[u] = pre; size[u] = 1;
     23       for (int i = 0; i < ve[u].size(); ++ i) {
     24           int v = ve[u][i];
     25             if (v == pre) continue;
     26           dfs1(v, u);
     27           size[u] += size[v];
     28             if (size[v] > size[heavy[u]]) heavy[u] = v;
     29       }
     30 }
     31 
     32 void dfs2(int u, int cur) {
     33     dfn[u] = ++ cnt; seq[cnt] = u;
     34     top[u] = cur;
     35       if (!heavy[u]) return;
     36     dfs2(heavy[u], cur);
     37       for (int i = 0; i < ve[u].size(); ++ i) {
     38           int v = ve[u][i];
     39             if (v == fa[u] || v == heavy[u]) continue;
     40           dfs2(v, v);
     41       }
     42 }
     43 
     44 void maintain(int o) { //维护每段的信息 
     45     lf[o] = lf[lc]; rt[o] = rt[rc];
     46     if (rt[lc] == lf[rc]) sum[o] = sum[lc] + sum[rc] - 1;
     47       else sum[o] = sum[lc] + sum[rc];
     48 }
     49 
     50 void pushdown(int o) { //标记下放 
     51     sum[lc] = sum[rc] = 1;
     52     lf[lc] = lf[rc] = rt[lc] = rt[rc] = bj[o];
     53     bj[lc] = bj[rc] = bj[o]; bj[o] = -1; 
     54 }
     55 
     56 void build(int o, int l, int r) { //建树 
     57     if (l == r) {
     58       lf[o] = col[seq[l]];
     59       rt[o] = col[seq[l]];
     60       sum[o] = 1; 
     61       return;
     62     }
     63     int mid = l + r >> 1;
     64     build(lc, l, mid); build(rc, mid + 1, r);
     65     maintain(o);
     66 }
     67 
     68 void modify(int o, int l, int r, int ql, int qr, int c) { //区间修改 
     69     if (ql <= l && qr >= r) {
     70       lf[o] = rt[o] = c; 
     71       sum[o] = 1; bj[o] = c;
     72       return;
     73     }
     74     int mid = l + r >> 1;
     75     if (bj[o] != -1) pushdown(o);
     76     if (ql <= mid) modify(lc, l, mid, ql, qr, c);
     77     if (qr > mid) modify(rc, mid + 1, r, ql, qr, c);
     78     maintain(o);
     79 }
     80 
     81 void query(int o, int l, int r, int ql, int qr) {
     82     if (ql <= l && qr >= r) {
     83       if (rt[o] == last) ans += sum[o] - 1; //如果右端的颜色和上一个左端相同,就-1 
     84         else ans += sum[o];  
     85       last = lf[o]; //更新last表示的左端点 
     86       return;
     87     }
     88     int mid = l + r >> 1;
     89     if (bj[o] != -1) pushdown(o);
     90     if (qr > mid) query(rc, mid + 1, r, ql, qr); //由于是从右向左更新答案,所以线段树上询问时也要优先向右询问! 
     91     if (ql <= mid) query(lc, l, mid, ql, qr);
     92 }
     93 
     94 void chain_modify(int x, int y, int c) { //树上修改 
     95     int fax = top[x], fay = top[y];
     96       while (fax != fay) {
     97           if (dep[fax] < dep[fay]) {
     98               swap(fax, fay);
     99               swap(x, y);
    100           }
    101         modify(1, 1, n, dfn[fax], dfn[x], c);
    102         x = fa[fax];
    103         fax = top[x];
    104       }
    105       if (dep[x] > dep[y]) swap(x, y);
    106     modify(1, 1, n, dfn[x], dfn[y], c);
    107 }
    108  
    109 void chain_query(int x, int y) { //树上询问 
    110     int fax = top[x], fay = top[y];
    111       while (fax != fay) {
    112           if (dep[fax] < dep[fay]) {
    113               swap(fax, fay);
    114               swap(x, y);
    115           }
    116         query(1, 1, n, dfn[fax], dfn[x]);
    117         x = fa[fax];
    118         fax = top[x];
    119       }
    120       if (dep[x] > dep[y]) swap(x, y);
    121     query(1, 1, n, dfn[x], dfn[y]);
    122 }
    123 
    124 int getlca(int x, int y) { //求lca 
    125     int fax = top[x], fay = top[y];
    126       while (fax != fay) {
    127           if (dep[fax] < dep[fay]) {
    128               swap(fax, fay);
    129               swap(x, y);
    130           }
    131         x = fa[fax];
    132         fax = top[x];
    133       }
    134     if (dep[x] > dep[y]) swap(x, y);
    135     return x;
    136 }
    137 
    138 int main() {
    139     n = read(); m = read();
    140       for (int i = 1; i <= n; ++ i) col[i] = read();
    141       for (int i = 1; i < n; ++ i) {
    142           int x = read(), y = read();
    143           ve[x].push_back(y);
    144           ve[y].push_back(x);
    145       }
    146     dfs1(1, 0);
    147     dfs2(1, 1);
    148     build(1, 1, n);
    149     memset(bj, -1, sizeof(bj)); //颜色可以为0!所以初始标记是-1 
    150       while (m --) {
    151           scanf("%s", s + 1); 
    152             if (s[1] == 'C') {
    153                   int x = read(), y = read(), c = read();
    154                   int lca = getlca(x, y);
    155                   chain_modify(x, lca, c); chain_modify(lca, y, c);
    156             }
    157           else { //要求两次答案,并累加答案 
    158               int x = read(), y = read(); ans = 0; last = -1;
    159               int lca = getlca(x, y);
    160               chain_query(x, lca);
    161               last = -1; 
    162               chain_query(lca, y);
    163               printf("%d
    ", ans - 1); //这里ans必须-1,因为lca处的颜色必定相同 
    164           }
    165       }
    166     return 0;
    167 } 
  • 相关阅读:
    理解 Redis(3)
    理解 Redis(2)
    理解 Redis(1)
    git 的基本命令
    使用python实现计算器功能
    python函数说明内容格式错误
    python的小基础
    python去除读取文件中多余的空行
    数论-下属不可以和上司顶嘴!(可能是总结)
    其他-一大堆记录 (20 Dec
  • 原文地址:https://www.cnblogs.com/Gaxc/p/9928321.html
Copyright © 2011-2022 走看看