zoukankan      html  css  js  c++  java
  • 树链剖分学习笔记

    树链剖分主要解决在树上某条路径上或某棵子树的sum与最值

    入门树链剖分,最重要的概念是重儿子,用son[]记录。son[i]代表的是以i为根,节点最多子树根的编号。通过son,我们将树中的边分为两种,轻边和重边。重边是每一个点与它重儿子的连边。将链续的重边串起来,构成了一条条重链。对于每个点,它一定在某条重链上,特殊的,一个单独的点也可以是一条重链。对于一条重链上的点,他们的dfs序是连续的,对与一颗子树上的点,他们的dfs序也是连续的,于是我们将树上的点转化成一个区间,在区间用线段树上求解或修改。

    树链剖分的核心便是如何将树剖分成若干链。

    在此之前,先了解以下七个参数的意义。

    fa[x] x的父亲节点编号

    dep[x] x的深度

    size[x] 以x为根子树的节点,用来求son

    seg[x] 以son为基础的dfs2序,即其在线段树上的编号

    rev[x] 用来将线段树上的编号转化为原编号。

    son[x] 记录重儿子  

    top[x] 记录x所在重链dep最小的点

    树链剖分需要的前置知识有DFS+LCA+线段树

    我们在两次dfs中求出上述七个参数

    dfs1:求出dep,f,size,son。
    inline void dfs1(int u,int f){
        int i,v;
        size[u]=1;
        fa[u]=f;
        dep[u]=dep[f]+1;
        for(i=fir[u];v=to[i],i;i=nex[i]){
            if(v!=f){
                dfs1(v,u);
                size[u]+=size[v];
                if(size[v]>size[son[u]])//更新重儿子
                    son[u]=v;
            }
        }
    }

    dfs2:求出rev,seg,top。

    inline void dfs2(int u,int f){
        int i,v;
        if(son[u]){//优先遍历重儿子。
            seg[son[u]]=++tim;
            rev[tim]=son[u];
            top[son[u]]=top[u];//重儿子的top,就是u的top。
            dfs2(son[u],u);
        }
        for(i=fir[u];v=to[i],i;i=nex[i])
            if(!top[v]){//访问轻边
                seg[v]=++tim;
                rev[tim]=v;
                top[v]=v;//轻边单独开了一条链,top是本身
                dfs2(v,u);
            }
    }

    两遍dfs就把整棵树划分为若干条链,剩下的就交给线段树解决了。

    首先是建树

    inline void build(int k,int l,int r){
        if(l==r){
            sum[k]=ma[k]=w[l];//w是每个点的全值
            return;
        }
        int mid=l+r>>1;
        build(k<<1,l,mid);
        build(k<<1|1,mid+1,r);
        ma[k]=max(ma[k<<1],ma[k<<1|1]);
        sum[k]=sum[k<<1]+sum[k<<1|1];
    }

    线段树的查询和修改类似。

    inline void change(int k,int l,int r,int val,int pos){
    //pos是当前要修改的的位置,val是改变后的值
    if(l>pos||r<pos) return; if(l==r&&l==pos){ ma[k]=sum[k]=val; return; } int mid=l+r>>1; change(k<<1,l,mid,val,pos); change(k<<1|1,mid+1,r,val,pos); sum[k]=sum[k<<1]+sum[k<<1|1]; ma[k]=max(ma[k<<1],ma[k<<1|1]); } inline void query(int k,int x,int y,int l,int r){
    //l~r为需要修改的区间
    if(x>r||y<l) return; if(x>=l&&y<=r){ SUM+=sum[k]; MAX=max(ma[k],MAX); return; } int mid=x+y>>1; query(k<<1,x,mid,l,r); query(k<<1|1,mid+1,y,l,r); }

    最后,我们只需要知道只需要知道哪些点对这条路径有贡献,统计他们的贡献即可。

    inline void ask(int x,int y){

    inline void ask(int x,int y){  
    int fx=top[x],fy=top[y]; while(fx!=fy){//如果他们不在同一重链上 if(dep[fx]<dep[fy]) swap(x,y),swap(fx,fy);//选取深度大的那一条, query(1,1,tim,seg[fx],seg[x]);//注意要将原编号转化为dfs序编号 x=fa[x],fx=top[x]; }
      //如果他们在一条链上了,再统计x~y路径的贡献
    if(dep[x]>dep[y]) swap(x,y);//保证x的编号小等于y query(1,1,tim,seg[x],seg[y]); }

    下面附上一道模板题

    树的统计

    一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。
    
    我们将以下面的形式来要求你对这棵树完成一些操作:
    
    I. CHANGE u t : 把结点u的权值改为t
    
    II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值
    
    III. QSUM u v: 询问从点u到点v的路径上的节点的权值和
    
    注意:从点u到点v的路径上的节点包括u和v本身
    
    输入格式
    输入文件的第一行为一个整数n,表示节点的个数。
    
    接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有一条边相连。
    
    接下来一行n个整数,第i个整数wi表示节点i的权值。
    
    接下来1行,为一个整数q,表示操作的总数。
    
    接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
    
    输出格式
    对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。
    
    输入输出样例
    输入
    4
    1 2
    2 3
    4 1
    4 2 1 3
    12
    QMAX 3 4
    QMAX 3 3
    QMAX 3 2
    QMAX 2 3
    QSUM 3 4
    QSUM 2 1
    CHANGE 1 5
    QMAX 3 4
    CHANGE 3 6
    QMAX 3 4
    QMAX 2 4
    QSUM 3 4
    输出 
    4
    1
    2
    2
    10
    6
    5
    6
    5
    16
    题目描述
    #include<cstdio>
    #include<iostream>
    #include<cstring>
    #define max(x,y) (x>y?x:y)
    #define N 100000
    using namespace std;
    int n,m,tot,tim,SUM,MAX;
    int fir[N],to[N],nex[N];
    int seg[N],rev[N],size[N],son[N],dep[N],top[N],fa[N];
    int sum[N],ma[N],w[N];
    inline void r(int &x){
        bool sign=1;
        x=0;
        char ch=getchar();
        while(ch<'0'||ch>'9') ch=getchar();
        if(ch=='-') sign=0,ch=getchar();
        while(ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
        x=sign?x:-x;
    }
    inline void add(int x,int y){
        to[++tot]=y,nex[tot]=fir[x],fir[x]=tot;
        to[++tot]=x,nex[tot]=fir[y],fir[y]=tot;
    }
    inline void dfs1(int u,int f){
        int i,v;
        size[u]=1;
        fa[u]=f;
        dep[u]=dep[f]+1;
        for(i=fir[u];v=to[i],i;i=nex[i]){
            if(v!=f){
                dfs1(v,u);
                size[u]+=size[v];
                if(size[v]>size[son[u]])
                    son[u]=v;
            }
        }
    }
    inline void dfs2(int u,int f){
        int i,v;
        if(son[u]){
            seg[son[u]]=++tim;
            rev[tim]=son[u];
            top[son[u]]=top[u];
            dfs2(son[u],u);
        }
        for(i=fir[u];v=to[i],i;i=nex[i])
            if(!top[v]){
                seg[v]=++tim;
                rev[tim]=v;
                top[v]=v;
                dfs2(v,u);
            }
    }
    inline void build(int k,int l,int r){
        if(l==r){
            sum[k]=ma[k]=w[l];
            return;
        }
        int mid=l+r>>1;
        build(k<<1,l,mid);
        build(k<<1|1,mid+1,r);
        ma[k]=max(ma[k<<1],ma[k<<1|1]);
        sum[k]=sum[k<<1]+sum[k<<1|1];
    }
    inline void change(int k,int l,int r,int val,int pos){
        if(l>pos||r<pos)
            return;
        if(l==r&&l==pos){
            ma[k]=sum[k]=val;
            return;
        }
        int mid=l+r>>1;
        change(k<<1,l,mid,val,pos);
        change(k<<1|1,mid+1,r,val,pos);
        sum[k]=sum[k<<1]+sum[k<<1|1];
        ma[k]=max(ma[k<<1],ma[k<<1|1]);
    }
    inline void query(int k,int x,int y,int l,int r){
        if(x>r||y<l)
            return;
        if(x>=l&&y<=r){
            SUM+=sum[k];
            MAX=max(ma[k],MAX);
            return;
        }
        int mid=x+y>>1;
        query(k<<1,x,mid,l,r);
        query(k<<1|1,mid+1,y,l,r);
    }
    inline void ask(int x,int y){
        int fx=top[x],fy=top[y];
        while(fx!=fy){
            if(dep[fx]<dep[fy]) swap(x,y),swap(fx,fy);
            query(1,1,tim,seg[fx],seg[x]);
            x=fa[x],fx=top[x];
        }
        if(dep[x]>dep[y]) swap(x,y);
        query(1,1,tim,seg[x],seg[y]);
    }
    int main()
    {
        int i,j,x,y;
        char op[10];
        r(n);
        for(i=1;i<n;i++){
            r(x),r(y);
            add(x,y);
        }
        for(i=1;i<=n;i++)
            r(w[i]);
        tim=seg[1]=top[1]=rev[1]=1;
        dfs1(1,0);
        dfs2(1,0);
        build(1,1,tim);
        r(m);
        for(i=1;i<=m;i++){
            scanf("%s",op);
            r(x),r(y);
            SUM=0;
            MAX=-N;
            switch(op[1]){
                case 'M':{
                    ask(x,y);
                    printf("%d
    ",MAX);
                    break;
                }
                case 'S':{
                    ask(x,y);
                    printf("%d
    ",SUM);
                    break;
                }
                case 'H':{
                    change(1,1,tim,y,seg[x]);
                    break;
                }
            }
        }
    }
    View Code

    (代码有误)

    2019-09-04

  • 相关阅读:
    CentOS6.4 64位系统安装jdk
    oracle安装界面中文乱码解决
    亦步亦趋在CentOS 6.4下安装Oracle 11gR2(x64)
    CentOS 6.3(x86_64)下安装Oracle 10g R2
    Spring中映射Mongodb中注解的解释
    MongoDB 创建基础索引、组合索引、唯一索引以及优化
    MongoDB 用MongoTemplate查询指定时间范围的数据
    Java获取泛化类型
    SpringBoot标准Properties
    java如何获取一个对象的大小【转】
  • 原文地址:https://www.cnblogs.com/quitter/p/11455775.html
Copyright © 2011-2022 走看看