zoukankan      html  css  js  c++  java
  • 树链剖分

    POJ 3237 Tree

    学了一下树链剖分：把树按重链剖分成若干条链，再用线段树、树状数组、splay 等数据结构维护链上的信息。

    // POJ 3237 TREE
    
    
    /**DESC: 给出一棵树,有三种操作:
          1:第i条边的权值修改成v.
          2:a 到 b 的路径上的权值全都取反。
          3:在 a 到 b的路径上的权值找最大。
    */
    
    /** 思路:线段树维护树链剖分。
     * 道理我都懂，就是代码有点麻烦。T_T
     */
    
    #include <stdio.h>
    #include <string.h>
    #include <iostream>
    #include <vector>
    #define maxn 10010
    using namespace std;
    
    // Adjacency list kept as per-vertex singly linked lists of edges.
    // main() inserts every undirected tree edge twice (once per direction).
    struct Edge {
        int u;   // source vertex
        int v;   // target vertex
        int nxt; // index of the next edge leaving u, or -1 at the end of the list
    };
    Edge edge[maxn*2];

    int head[maxn]; // head[x]: most recently added edge leaving x, -1 if none (see init)
    int tot;        // number of edges stored so far

    /// Prepends the directed edge u -> v to u's adjacency list.
    void addEdge(int u, int v) {
        Edge &slot = edge[tot];
        slot.u = u;
        slot.v = v;
        slot.nxt = head[u];
        head[u] = tot;
        ++tot;
    }
    
    //树链剖分 把树剖成链
    int fa[maxn]; //dfs1
    int deep[maxn];
    int num[maxn];
    int son[maxn];
    
    int top[maxn]; //dfs2
    int p[maxn];
    int fp[maxn];
    int pos;
    
    /// First decomposition pass over the subtree rooted at u.
    /// Records fa[], deep[] and num[] (subtree size) and picks son[u], the child
    /// with the largest subtree. Expects son[] to be reset to -1 beforehand (init).
    /// pre is u's parent (skipped while walking the adjacency list); d is u's depth.
    void dfs1(int u, int pre, int d) {
        fa[u] = pre;
        deep[u] = d;
        num[u] = 1;
        for (int it = head[u]; it != -1; it = edge[it].nxt) {
            const int child = edge[it].v;
            if (child == pre) continue;          // never walk back to the parent
            dfs1(child, u, d + 1);
            num[u] += num[child];
            const bool heavier = (son[u] == -1) || (num[child] > num[son[u]]);
            if (heavier) son[u] = child;
        }
    }
    
    /// Second decomposition pass: hands out segment-tree positions so that every
    /// heavy chain occupies a contiguous range (the heavy child is visited first).
    /// u  - current vertex
    /// sp - top vertex of the chain u belongs to
    /// p[u] is the position of the edge between u and its parent; fp is the
    /// inverse map (fp[k] is the vertex whose parent edge sits at position k).
    void dfs2(int u, int sp) {
        top[u] = sp;
        const int slot = pos++;
        p[u] = slot;
        fp[slot] = u;
        const int heavy = son[u];
        if (heavy == -1) return;                 // leaf: nothing below
        dfs2(heavy, sp);                         // heavy child extends the current chain
        for (int it = head[u]; it != -1; it = edge[it].nxt) {
            const int child = edge[it].v;
            if (child == heavy || child == fa[u]) continue;
            dfs2(child, child);                  // each light child starts a new chain
        }
    }
    
    //线段树
    // Segment-tree node covering positions [l, r]; child indices are 2*i and 2*i+1.
    struct Node {
        int l, r;   // covered position range (inclusive)
        int maxm;   // maximum value in [l, r]
        int minn;   // minimum value in [l, r]; kept so a negation can produce the new maximum
        int ne;     // lazy flag: 1 means this node is up to date but its children still need negating
    }segTree[maxn*4];
    
    /// Builds the segment tree rooted at rt over positions [l, r] with every
    /// value set to 0 and no pending negation.
    /// segTree is a global reused for every test case, so the lazy flag `ne`
    /// must be cleared here too; otherwise a flag left over from the previous
    /// case would survive into this one.
    void build(int rt, int l, int r) {
        segTree[rt].l = l;
        segTree[rt].r = r;
        segTree[rt].maxm = 0;
        segTree[rt].minn = 0;
        segTree[rt].ne = 0;
        if (l == r) return;
        int mid = ((l+r)>>1);
        build(rt<<1, l, mid);
        build((rt<<1)|1, mid+1, r);
    }
    
    /// Pushes a pending negation of node i down to both children. For each child,
    /// max and min are negated and exchanged, and the child's own flag is toggled.
    void push_down(int i) {
        if (segTree[i].l == segTree[i].r || segTree[i].ne == 0)
            return;                               // leaf, or nothing pending
        const int kids[2] = { i << 1, (i << 1) | 1 };
        for (int k = 0; k < 2; ++k) {
            const int c = kids[k];
            const int oldMax = segTree[c].maxm;
            segTree[c].maxm = -segTree[c].minn;   // new max = -(old min)
            segTree[c].minn = -oldMax;            // new min = -(old max)
            segTree[c].ne ^= 1;
        }
        segTree[i].ne = 0;
    }
    
    /// Recomputes node i's maximum and minimum from its two children.
    void push_up(int i) {
        const int lc = i << 1;
        const int rc = lc | 1;
        segTree[i].maxm = segTree[lc].maxm > segTree[rc].maxm ? segTree[lc].maxm : segTree[rc].maxm;
        segTree[i].minn = segTree[lc].minn < segTree[rc].minn ? segTree[lc].minn : segTree[rc].minn;
    }
    
    /// Point assignment: sets position k to val.
    /// Pending negations along the path are pushed down first, so sibling
    /// subtrees keep correct values.
    void update(int i, int k, int val) {
        const int lo = segTree[i].l;
        const int hi = segTree[i].r;
        if (lo == k && hi == k) {
            segTree[i].maxm = segTree[i].minn = val;
            segTree[i].ne = 0;                    // a freshly assigned leaf has nothing pending
            return;
        }
        push_down(i);
        const int mid = (lo + hi) / 2;
        update(k <= mid ? (i << 1) : ((i << 1) | 1), k, val);
        push_up(i);
    }
    
    /// Resets per-test-case state: no edges, empty adjacency lists,
    /// no positions handed out yet, and no heavy children chosen.
    void init() {
        tot = 0;
        pos = 0;
        memset(head, -1, sizeof(head));
        memset(son, -1, sizeof(son));
    }
    
    int e[maxn][3]; // e[i] = {a, b, weight} of the i-th input edge (0-based); main swaps a/b so e[i][1] is the deeper endpoint
    
    /// Negates every value in positions [l, r], which must lie inside node rt's range.
    /// A node that is covered exactly is negated in place (max and min swap roles),
    /// and its flag is toggled so the change reaches its children lazily.
    void ne_update(int rt, int l, int r) {
        const int lo = segTree[rt].l;
        const int hi = segTree[rt].r;
        if (lo == l && hi == r) {
            const int oldMax = segTree[rt].maxm;
            segTree[rt].maxm = -segTree[rt].minn;
            segTree[rt].minn = -oldMax;
            segTree[rt].ne ^= 1;
            return;
        }
        push_down(rt);                            // children must be current before recursing
        const int mid = (lo + hi) / 2;
        const int lc = rt << 1;
        const int rc = lc | 1;
        if (r <= mid) {
            ne_update(lc, l, r);                  // entirely in the left half
        } else if (l > mid) {
            ne_update(rc, l, r);                  // entirely in the right half
        } else {
            ne_update(lc, l, mid);                // straddles the midpoint: split
            ne_update(rc, mid + 1, r);
        }
        push_up(rt);
    }
    
    
    void Negate(int u, int v) {
        int f1 = top[u], f2 = top[v];
        while(f1 != f2) {
            if (deep[f1] < deep[f2]) { //使得depp[f1] > deep[f2]
                swap(f1, f2);
                swap(u, v);
            }
            ne_update(1, p[f1], p[u]); ///
            u = fa[f1], f1 = top[u];
        }
        if (u == v) return;
        if (deep[u] > deep[v]) swap(u, v); //使得deep[u] < deep[v]
        ne_update(1, p[son[u]], p[v]);
    }
    
    /// Returns the maximum value over positions [l, r], which must lie inside
    /// node rt's range.
    /// Every branch returns, so no push_up is needed afterwards. push_down only
    /// changes the children; node rt's own aggregates are already correct.
    int query(int rt, int l, int r) {
        if (segTree[rt].l == l && segTree[rt].r == r)
            return segTree[rt].maxm;
        push_down(rt); // children must reflect any pending negation before being read
        int mid = (segTree[rt].l + segTree[rt].r) / 2;
        if (r <= mid) {
            return query(rt<<1, l, r);
        }else if (l > mid) {
            return query((rt<<1)|1, l, r);
        }
        return max(query(rt<<1, l, mid), query((rt<<1)|1, mid+1, r));
    }
    
    
    int findMax(int u, int v) {
        int f1 = top[u], f2 = top[v];
        int tmp = -100000000;
        while(f1 != f2) {
            if (deep[f1] < deep[f2]) {
                swap(f1, f2);
                swap(u, v);
            }
            tmp = max(tmp, query(1, p[f1], p[u]));
            u = fa[f1]; f1 = top[u];
        }
        if (u == v) return tmp;
        if (deep[u] > deep[v]) swap(u, v);
        return max(tmp, query(1, p[son[u]], p[v]));
    }
    
    int main() {
       // freopen("in.cpp", "r", stdin);
        int t;
        scanf("%d", &t);
        while(t--) {
            int n;
            scanf("%d", &n); // input
            init();
            for (int i=0; i<n; ++i) {
                scanf("%d%d%d", &e[i][0], &e[i][1], &e[i][2]);
                addEdge(e[i][0], e[i][1]);
                addEdge(e[i][1], e[i][0]);
            }
            dfs1(1, 0, 0);
            dfs2(1, 1);
            build(1, 0, pos-1);
    
            //线段树赋值
            for (int i=0; i<n-1; ++i) {
                if (deep[e[i][0]] > deep[e[i][1]]) {
                    swap(e[i][0], e[i][1]);
                }
                update(1, p[e[i][1]], e[i][2]);
            }
            char op[10];
            int u, v;
            while(~scanf("%s", op)) {
    //            printf("%s
    ",op);
                if (op[0] == 'D') break;
                scanf("%d%d", &u, &v);
                if (op[0] == 'C') {
                    update(1, p[e[u-1][1]], v);
                }else if (op[0] == 'N') {
                    Negate(u, v);
                }else printf("%d
    ", findMax(u, v));
            }
        }
        return 0;
    }
    View Code

    HYSBZ 1036 树的统计Count

    和上一题不同，这题是点权：线段树维护的是各点权值按剖分顺序排列成的数组。因为只有单点修改，所以不需要延迟标记。

    #include <stdio.h>
    #include <string.h>
    #include <iostream>
    #include <algorithm>
    #include <vector>
    #include <queue>
    #include <set>
    #include <map>
    #include <string>
    #include <math.h>
    #include <stdlib.h>
    using namespace std;
    
    const int MAXN = 30010;  // capacity for vertex ids / positions (presumably n <= 30000 — confirm against the problem)
    
    struct Edge
    {
        int to,next;  // target vertex; index of the next edge from the same source (-1 ends the list)
    }edge[MAXN*2];    // main stores each undirected edge in both directions
    
    int head[MAXN],tot;  // head[v]: first edge leaving v (-1 if none); tot: edges stored so far
    int top[MAXN]; // top[v]: shallowest vertex of the heavy chain containing v
    int fa[MAXN];  // parent vertex
    int deep[MAXN];// depth
    int num[MAXN]; // num[v]: number of vertices in the subtree rooted at v
    int p[MAXN];   // p[v]: position of vertex v in the segment tree
    int fp[MAXN];  // inverse of p: fp[k] is the vertex stored at position k
    int son[MAXN]; // heavy child (child with the largest subtree), -1 if none
    int pos;       // next unused segment-tree position
    
    /// Resets per-test-case state before a new tree is read.
    void init()
    {
        memset(son,-1,sizeof(son));
        memset(head,-1,sizeof(head));
        pos = 0;
        tot = 0;
    }
    
    void addedge(int u,int v)
    {
        edge[tot].to = v; edge[tot].next = head[u]; head[u] = tot++;
    }
    
    /// First DFS: fills fa, deep and num for the subtree rooted at u and
    /// chooses son[u], the child with the largest subtree (-1 for a leaf).
    void dfs1(int u,int pre,int d)
    {
        num[u] = 1;
        fa[u] = pre;
        deep[u] = d;
        for(int it = head[u]; it != -1; it = edge[it].next)
        {
            const int child = edge[it].to;
            if(child == pre)
                continue;                   // skip the edge back to the parent
            dfs1(child,u,d+1);
            num[u] += num[child];
            if(son[u] != -1 && num[child] <= num[son[u]])
                continue;                   // current heavy child is at least as large
            son[u] = child;
        }
    }
    
    /// Second DFS: numbers vertices so that each heavy chain is contiguous.
    /// u is the current vertex; sp is the top of the chain u belongs to.
    void getpos(int u,int sp)
    {
        top[u] = sp;
        p[u] = pos;                          // position of vertex u in the segment tree
        fp[pos] = u;                         // inverse map: position -> vertex
        ++pos;
        if(son[u] != -1)
        {
            getpos(son[u],sp);               // heavy child continues the chain
            for(int it = head[u]; it != -1; it = edge[it].next)
            {
                const int child = edge[it].to;
                if(child == son[u] || child == fa[u])
                    continue;
                getpos(child,child);         // light child starts its own chain
            }
        }
    }
    
    // Segment-tree node over positions [l, r]; positions are the p[] values from getpos.
    // NOTE(review): MAXN*3 nodes suffices here only because 3*MAXN exceeds 2*32768,
    // the index bound for up to 32768 leaves; 4*MAXN is the usual safe size.
    struct Node
    {
        int l,r;    // covered position range [l, r]
        int sum;    // sum of vertex weights in the range
        int Max;    // maximum vertex weight in the range
    }segTree[MAXN*3];
    
    /// Recomputes node i's sum and maximum from its two children.
    void push_up(int i)
    {
        const int lc = i<<1;
        const int rc = lc|1;
        segTree[i].Max = max(segTree[lc].Max,segTree[rc].Max);
        segTree[i].sum = segTree[lc].sum + segTree[rc].sum;
    }
    
    int s[MAXN]; // s[v]: weight of vertex v as read in main (consumed by build)
    
    /// Builds node i over positions [l, r]. Leaf k takes the weight of the vertex
    /// stored there, s[fp[k]], so getpos must have filled fp first.
    void build(int i,int l,int r)
    {
        segTree[i].l = l;
        segTree[i].r = r;
        if(l != r)
        {
            const int mid = (l + r)/2;
            build(i<<1,l,mid);
            build((i<<1)|1,mid+1,r);
            push_up(i);
        }
        else
        {
            const int w = s[fp[l]];
            segTree[i].Max = w;
            segTree[i].sum = w;
        }
    }
    
    /// Sets position k to val, then refreshes the aggregates on the way back up.
    void update(int i,int k,int val)
    {
        if(segTree[i].l == k && segTree[i].r == k)
        {
            segTree[i].Max = val;
            segTree[i].sum = val;
            return;
        }
        const int mid = (segTree[i].l + segTree[i].r)/2;
        const int child = (k <= mid) ? (i<<1) : ((i<<1)|1);
        update(child,k,val);
        push_up(i);
    }
    
    /// Maximum over positions [l, r], which must lie inside node i's range.
    int queryMax(int i,int l,int r)
    {
        const int nl = segTree[i].l;
        const int nr = segTree[i].r;
        if(nl == l && nr == r)
            return segTree[i].Max;
        const int mid = (nl + nr)/2;
        if(r <= mid)
            return queryMax(i<<1,l,r);
        if(l > mid)
            return queryMax((i<<1)|1,l,r);
        const int leftMax = queryMax(i<<1,l,mid);
        const int rightMax = queryMax((i<<1)|1,mid+1,r);
        return max(leftMax,rightMax);
    }
    
    /// Sum over positions [l, r], which must lie inside node i's range.
    int querySum(int i,int l,int r)
    {
        const int nl = segTree[i].l;
        const int nr = segTree[i].r;
        if(nl == l && nr == r)
            return segTree[i].sum;
        const int mid = (nl + nr)/2;
        if(r <= mid)
            return querySum(i<<1,l,r);
        if(l > mid)
            return querySum((i<<1)|1,l,r);
        const int leftSum = querySum(i<<1,l,mid);
        const int rightSum = querySum((i<<1)|1,mid+1,r);
        return leftSum + rightSum;
    }
    
    int findMax(int u,int v)//查询u->v路径上节点的最大权值
    {
        int f1 = top[u] , f2 = top[v];
        int tmp = -1000000000;
        while(f1 != f2)
        {
            if(deep[f1] < deep[f2])
            {
                swap(f1,f2);
                swap(u,v);
            }
            tmp = max(tmp,queryMax(1,p[f1],p[u]));
            u = fa[f1];
            f1 = top[u];
        }
        if(deep[u] > deep[v]) swap(u,v);
        return max(tmp,queryMax(1,p[u],p[v])); ///
    }
    
    int findSum(int u,int v) //查询u->v路径上节点的权值的和
    {
        int f1 = top[u], f2 = top[v];
        int tmp = 0;
        while(f1 != f2)
        {
            if(deep[f1] < deep[f2])
            {
                swap(f1,f2);
                swap(u,v);
            }
            tmp += querySum(1,p[f1],p[u]);
            u = fa[f1];
            f1 = top[u];
        }
        if(deep[u] > deep[v]) swap(u,v);
        return tmp + querySum(1,p[u],p[v]); ///
    }
    
    int main()
    {
        //freopen("in.txt","r",stdin);
        //freopen("out.txt","w",stdout);
        int n;
        int q;
        char op[20];
        int u,v;
        while(scanf("%d",&n) == 1)
        {
            init();
            for(int i = 1;i < n;i++)
            {
                scanf("%d%d",&u,&v);
                addedge(u,v);
                addedge(v,u);
            }
            for(int i = 1;i <= n;i++)
                scanf("%d",&s[i]);
            dfs1(1,0,0);
            getpos(1,1);
            build(1,0,pos-1);
            scanf("%d",&q);
            while(q--)
            {
                scanf("%s%d%d",op,&u,&v);
                if(op[0] == 'C')
                    update(1,p[u],v);//修改单点的值
                else if(strcmp(op,"QMAX") == 0)
                    printf("%d
    ",findMax(u,v));//查询u->v路径上点权的最大值
                else printf("%d
    ",findSum(u,v));//查询路径上点权的和
            }
        }
        return 0;
    }
    View Code

    参考链接:http://www.cnblogs.com/kuangbin/category/507663.html

  • 相关阅读:
    java
    java
    java
    js
    java
    异常之异常处理
    面向对象之元类
    面向对象之内置方法
    面向对象之反射
    面向对象之类方法与静态方法
  • 原文地址:https://www.cnblogs.com/icode-girl/p/6036206.html
Copyright © 2011-2022 走看看