zoukankan      html  css  js  c++  java
  • 树链剖分(附带LCA和换根)——基于dfs序的树上优化

    。。。。

    有点懒;


    需要先理解几个概念:

    1. LCA

    2. 线段树(熟练,要不代码能调一天)

    3. 图论的基本知识(dfs序的性质)

    这大概就好了;

    定义

      1.重儿子:一个点所连点树size最大的,这个son被称为这个点的重儿子;

      2.轻儿子:一个点所连点除重儿子以外的都是轻儿子;

      3.重链:从一个轻儿子或根节点开始沿重儿子走所成的链;

    步骤

      在代码里,结合代码更清晰。。。(其实是太懒了)

     有重点需要注意的东西在code中有提到,仔细看。。。。

    #include<bits/stdc++.h>
    #define maxn 100007
    #define le(x) x<<1
    #define re(x) x<<1|1
    using namespace std;
    int n,m,root,mod,a[maxn],head[maxn],fa[maxn],son[maxn],cnt,tag[maxn<<2];
    //a:原始点值,fa:父亲节点,son:重儿子,tag:懒标记 
    int top[maxn],sz[maxn],id[maxn],dep[maxn],w[maxn],cent,tr[maxn<<2];
    //top:所在重链的头结点,sz:子树大小,id:dfs序,dep:深度 
    //w:dfs序所对应的值(建线段树),tr:线段树 
    struct node{
        int next,to;
    }edge[maxn<<2];
    
    template<typename type_of_scan>
    inline void scan(type_of_scan &x){
        type_of_scan f=1;x=0;char s=getchar();
        while(s<'0'||s>'9') f=s=='-'?-1:1,s=getchar();
        while(s>='0'&&s<='9') x=(x<<3)+(x<<1)+s-'0',s=getchar();
        x*=f;
    }
    
    inline void add(int u,int v){
        edge[++cent]=(node){head[u],v};head[u]=cent;
    }
    //-----------------------------------------------------线段树红色预警 
    void push_up(int p){
        tr[p]=tr[le(p)]+tr[re(p)];
        tr[p]%=mod;
    }
    
    void build(int l,int r,int p){
        if(l==r){
            tr[p]=w[l];
            return ;
        }
        int mid=(l+r)>>1;
        build(l,mid,le(p));
        build(mid+1,r,re(p));
        push_up(p);
    }
    
    void push_down(int l,int r,int p,int k){
        int mid=l+r>>1;
        tr[le(p)]+=k*(mid-l+1),tr[re(p)]+=k*(r-mid);
        tr[le(p)]%=mod,tr[re(p)]%=mod;
        tag[le(p)]+=k,tag[re(p)]+=k;
        tag[le(p)]%=mod,tag[re(p)]%=mod;
    }
    
    void r_add(int nl,int nr,int l,int r,int p,int k){
        if(nl<=l&&nr>=r){
            tr[p]+=k*(r-l+1);tag[p]+=k;
            tr[p]%=mod,tag[p]%=mod;
            return ;
        }
        push_down(l,r,p,tag[p]),tag[p]=0;
        int mid=(l+r)>>1;
        if(nl<=mid) r_add(nl,nr,l,mid,le(p),k);
        if(nr>mid) r_add(nl,nr,mid+1,r,re(p),k);
        push_up(p);
    }
    
    int r_query(int nl,int nr,int l,int r,int p){
        int ans=0;
        if(nl<=l&&nr>=r) return tr[p];
        push_down(l,r,p,tag[p]),tag[p]=0;
        int mid=l+r>>1;
        if(nl<=mid) ans+=r_query(nl,nr,l,mid,le(p)),ans%=mod;
        if(nr>mid) ans+=r_query(nl,nr,mid+1,r,re(p)),ans%=mod;
        push_up(p);
        return ans;
    }
    
    //-----------------------------------------------------线段树结束
    //-----------------------------------------------------开始预处理 
    
    void dfs1(int x){
        sz[x]=1;//sz初始化 
        int max_part=-1;//max_part更新寻找重儿子 
        for(int i=head[x];i;i=edge[i].next){
            int y=edge[i].to;
            if(y==fa[x]) continue;
            fa[y]=x,dep[y]+=dep[x]+1;//更新子节点,准备开始继续dfs1 
            dfs1(y);sz[x]+=sz[y];//更新自身的sz数组 
            if(max_part<sz[y]) son[x]=y,max_part=sz[y];//更新重儿子 
        }
    }
    /*dfs1功能介绍
    1.更新fa数组;
    2.更新dep数组;
    3.更新sz数组; 
    4.更新son数组; 
    */ 
    
    void dfs2(int x,int t){
        id[x]=++cnt,w[cnt]=a[x],top[x]=t;//更新dfs序,dfs序所对的值,重链头节点 
        if(!son[x]) return ;
        dfs2(son[x],t);
        for(int i=head[x];i;i=edge[i].next){
            int y=edge[i].to;
            if(y==fa[x]||y==son[x]) continue;
            dfs2(y,y);
        }
    }
    /*dfs2功能介绍
    1.更新id数组;
    2.更新w数组;
    3.更新top数组
    */ 
    
    //------------------------------------------------预处理结束 
    //------------------------------------------------开始主要操作 
    
    //其实没有说的这么简单,这里重点是理解重链之间的跳跃方式,线段树的优化 
    //一个性质:重链上的dfs序是连续的,dfs1在dfs2前的原因就在此 
    
    int road_query(int x,int y){
        int ans=0;
        while(top[x]!=top[y]){
            if(dep[top[x]]<dep[top[y]]) swap(x,y);//从最下面往上跳 
            ans+=r_query(id[top[x]],id[x],1,n,1);//更新重链 
            ans%=mod;
            x=fa[top[x]];//跳到重链头的fa 
        }
        if(dep[x]>dep[y]) swap(x,y);
        ans+=r_query(id[x],id[y],1,n,1);//已经在同一条重链上,直接加 
        return ans%mod;
    }
    
    int tree_query(int x){
        return r_query(id[x],id[x]+sz[x]-1,1,n,1)%mod;
    }//一个性质:在同一颗子树上的dfs序是连续的 
    
    void road_add(int x,int y,int k){
        while(top[x]!=top[y]){
            if(dep[top[x]]<dep[top[y]]) swap(x,y);
            r_add(id[top[x]],id[x],1,n,1,k);
            x=fa[top[x]];
        }
        if(dep[x]>dep[y]) swap(x,y);
        r_add(id[x],id[y],1,n,1,k);
        return ;
    }//类比 
    
    void tree_add(int x,int k){
        r_add(id[x],id[x]+sz[x]-1,1,n,1,k);
        return ;
    }//相同的性质 
    
    //-----------------------------------------------树链剖分 
    
    int main(){
        scan(n),scan(m),scan(root),scan(mod);
        for(int i=1;i<=n;i++) scan(a[i]);
        for(int i=1,u,v;i<=n-1;i++)
            scan(u),scan(v),add(u,v),add(v,u);
        dfs1(root),dfs2(root,root),build(1,n,1);
        for(int i=1;i<=m;i++){
            int type,x,y,z;
            scan(type);
            if(type==1) scan(x),scan(y),scan(z),
                road_add(x,y,z);
            else if(type==2) scan(x),scan(y),
                printf("%d
    ",road_query(x,y));
            else if(type==3) scan(x),scan(z),
                tree_add(x,z);
            else if(type==4) scan(x),
                printf("%d
    ",tree_query(x));
        }
        return 0;
    } 

    好了,可以开始调代码了

    拓展:

      树链剖分,作为一个优秀的暴力结构,以O(n logn logn)的时间复杂度完成路径查询,在子树查询做到了nlogn级别,所以不得不说其优秀;

      但是,它的作用远不及此:

      1.LCA查询:

        与倍增相同,树链剖分可以用logn的时间复杂度完成LCA查询(跳跃性好像更优),而他的初始化是两遍dfs O(n),理论上更优。

        可以猜测,LCA依旧运用重链跳法,然后比较即可,这里给出示范代码

    int Lca(int x,int y){
        while(top[x]!=top[y]){
            if(dep[top[x]]<dep[top[y]]) swap(x,y);
            x=fa[top[x]];
        }
        return dep[x]>dep[y]?y:x;
    }//只要看懂树链剖分的基本操作,这个很简单 

        可以看到,其实代码很短。。。

      2.换根操作:

        设现在的根是root,我们可以发现,换根对于路径上的操作并没有影响,但是子树操作就会影响了,所以我们分类讨论

          设u为我们要查的子树的根节点

          (1)如果root=u,那么子树即为整棵树;

          (2)设 lca 为root和u的LCA,这里可以用上面所讲的树链剖分做,如果lca!=u,那么root并不是u的子节点,所以对于查询并不影响,常规操作即可

          (3)如果lca=u,那么u节点的子树就是整颗树减去u-root这个路径上与u相挨的节点v的子树即可,这里给出logn求点v的方法

    //前提条件:要求的节点相挨的节点u,必须是root的LCA 
    int find(int x,int y){
        while(top[x]!=top[y]){
            if(dep[top[x]]<dep[top[y]]) swap(x,y);//从最下往上跳 
            if(fa[top[x]]==y) return top[x];//如果y是x所在重链top的父亲节点,那么就可以返回了 
            x=fa[top[x]];//
        }
        if(dep[x]<dep[y]) swap(x,y);//让y最浅 
        return son[y];// 因为在一条重链上,那么重儿子一定是路径上与要求节点相挨的 
    }

        整个操作的代码层次感我写的还是比较清楚了

    void tree_add(int x,int k){
        if(root==x) r_add(1,n,1,n,1,k);//CASE 1 
        else{
            int lca=Lca(x,root);
            if(lca!=x) r_add(id[x],id[x]+sz[x]-1,1,n,1,k);//CASE 2 
            else{
                int dson=find(x,root);
                r_add(1,n,1,n,1,k);
                r_add(id[dson],id[dson]+sz[dson]-1,1,n,1,-k);
            }//CASE 3 
        }
        return ;
    }
    
    ll tree_query(int x){
        if(root==x) return r_query(1,n,1,n,1);//CASE 1 
        else{
            int lca=Lca(x,root);
            if(lca!=x) return r_query(id[x],id[x]+sz[x]-1,1,n,1);//CASE 2 
            else{
                int dson=find(x,root);
                return r_query(1,n,1,n,1)-r_query(id[dson],id[dson]+sz[dson]-1,1,n,1);
            }//CASE 3 
        }
    }

    推荐评测网站LOJ 。。。(因为洛谷没有换根操作)

    AC代码附上

    #include<bits/stdc++.h>
    #define maxn 100007
    #define ol putchar('
    ')
    #define le(x) x<<1
    #define re(x) x<<1|1
    #define ll long long
    using namespace std;
    int n,m,head[maxn],cent,dep[maxn],son[maxn],fa[maxn],vis[maxn];
    int top[maxn],a[maxn],id[maxn],w[maxn],sz[maxn],cnt,ij,root;
    ll tr[maxn<<3],tag[maxn<<3];
    struct node{
        int next,to;
    }edge[maxn<<3];
    
    template<typename type_of_scan>
    inline void scan(type_of_scan &x){
        type_of_scan f=1;x=0;char s=getchar();
        while(s<'0'||s>'9') f=s=='-'?-1:1,s=getchar();
        while(s>='0'&&s<='9') x=(x<<3)+(x<<1)+s-'0',s=getchar();
        x*=f;
    }
    template<typename type_of_print>
    inline void print(type_of_print x){
        if(x<0) putchar('-'),x=-x;
        if(x>9) print(x/10);
        putchar(x%10+'0');
    }
    
    inline void add(int u,int v){
        edge[++cent]=(node){head[u],v};head[u]=cent;
    }
    
    void push_up(int p){
        tr[p]=tr[le(p)]+tr[re(p)];
    }
    
    void push_down(int l,int r,int p,ll k){
        int mid=l+r>>1;
        tr[le(p)]+=1ll*(mid-l+1)*k,
        tr[re(p)]+=1ll*(r-mid)*k,
        tag[le(p)]+=k,tag[re(p)]+=k;
    }
    
    void build(int l,int r,int p){
        if(l==r){
            tr[p]=w[l];
            return ;
        }
        int mid=l+r>>1;
        build(l,mid,le(p));
        build(mid+1,r,re(p));
        push_up(p);
    }
    
    void r_add(int nl,int nr,int l,int r,int p,int k){
        if(nl<=l&&nr>=r){
            tr[p]+=1ll*(r-l+1)*k;
            tag[p]+=1ll*k;
            return ;
        }
        push_down(l,r,p,tag[p]),tag[p]=0;
        int mid=l+r>>1;
        if(nl<=mid) r_add(nl,nr,l,mid,le(p),k);
        if(nr>mid) r_add(nl,nr,mid+1,r,re(p),k);
        push_up(p);
    }
    
    ll r_query(int nl,int nr,int l,int r,int p){
        ll ans=0;
        if(nl<=l&&nr>=r) return tr[p];
        push_down(l,r,p,tag[p]),tag[p]=0;
        int mid=l+r>>1;
        if(nl<=mid) ans+=r_query(nl,nr,l,mid,le(p));
        if(nr>mid) ans+=r_query(nl,nr,mid+1,r,re(p));
        return ans;
    }
    
    void dfs1(int x){
        sz[x]=1;int max_part=-1;vis[x]++;
        for(int i=head[x];i;i=edge[i].next){
            int y=edge[i].to;
            if(y==fa[x]) continue;
            fa[y]=x;dep[y]=dep[x]+1;
            dfs1(y);sz[x]+=sz[y];
            if(max_part<sz[y]) son[x]=y,max_part=sz[y];
        }
    }
    
    void dfs2(int x,int t){
        id[x]=++cnt;w[cnt]=a[x];top[x]=t;
        if(!son[x]) return ;
        dfs2(son[x],t);
        for(int i=head[x];i;i=edge[i].next){
            int y=edge[i].to;
            if(y==son[x]||fa[x]==y) continue;
            dfs2(y,y);
        }
    }
    
    int Lca(int x,int y){
        while(top[x]!=top[y]){
            if(dep[top[x]]<dep[top[y]]) swap(x,y);
            x=fa[top[x]];
        }
        return dep[x]>dep[y]?y:x;
    }//只要看懂树链剖分的基本操作,这个很简单 
    
    
    //前提条件:要求的节点相挨的节点u,必须是root的LCA 
    int find(int x,int y){
        while(top[x]!=top[y]){
            if(dep[top[x]]<dep[top[y]]) swap(x,y);//从最下往上跳 
            if(fa[top[x]]==y) return top[x];//如果y是x所在重链top的父亲节点,那么就可以返回了 
            x=fa[top[x]];//
        }
        if(dep[x]<dep[y]) swap(x,y);//让y最浅 
        return son[y];// 因为在一条重链上,那么重儿子一定是路径上与要求节点相挨的 
    }
    
    void tree_add(int x,int k){
        if(root==x) r_add(1,n,1,n,1,k);//CASE 1 
        else{
            int lca=Lca(x,root);
            if(lca!=x) r_add(id[x],id[x]+sz[x]-1,1,n,1,k);//CASE 2 
            else{
                int dson=find(x,root);
                r_add(1,n,1,n,1,k);
                r_add(id[dson],id[dson]+sz[dson]-1,1,n,1,-k);
            }//CASE 3 
        }
        return ;
    }
    
    ll tree_query(int x){
        if(root==x) return r_query(1,n,1,n,1);//CASE 1 
        else{
            int lca=Lca(x,root);
            if(lca!=x) return r_query(id[x],id[x]+sz[x]-1,1,n,1);//CASE 2 
            else{
                int dson=find(x,root);
                return r_query(1,n,1,n,1)-r_query(id[dson],id[dson]+sz[dson]-1,1,n,1);
            }//CASE 3 
        }
    }
    
    void road_add(int x,int y,ll k){
        while(top[x]!=top[y]){
            if(dep[top[x]]<dep[top[y]]) swap(x,y);
            r_add(id[top[x]],id[x],1,n,1,k);
            x=fa[top[x]];
        }
        if(dep[x]>dep[y]) swap(x,y);
        r_add(id[x],id[y],1,n,1,k);
        return ;
    }
    
    ll road_query(int x,int y){
        ll ans=0;
        while(top[x]!=top[y]){
            if(dep[top[x]]<dep[top[y]]) swap(x,y);
            ans+=r_query(id[top[x]],id[x],1,n,1);
            x=fa[top[x]];
        }
        if(dep[x]>dep[y]) swap(x,y);
        ans+=r_query(id[x],id[y],1,n,1);
        return ans;
    }
    
    int main(){
    //    freopen("cin.in","r",stdin);
    //    freopen("co.out","w",stdout);
        scan(n);
        for(int i=1;i<=n;i++) scan(a[i]);
        for(int i=2,v;i<=n;i++) scan(v),add(i,v),add(v,i);
        dfs1(1),dfs2(1,1),build(1,n,1);root=1;
        scan(m);
        for(int i=1;i<=m;i++){
            int type,x,y,z;
            scan(type),scan(x);
            if(type==1) root=x;
            else if(type==2) scan(y),scan(z),road_add(x,y,z);
            else if(type==3) scan(z),tree_add(x,z);
            else if(type==4) scan(y),printf("%lld
    ",road_query(x,y));
            else if(type==5) printf("%lld
    ",tree_query(x));
        }
        return 0;
    }

     

  • 相关阅读:
    CodeSmith实用技巧(十四):使用Progress对象
    .NET设计模式(5):工厂方法模式(Factory Method)
    CodeSmith实用技巧(七):从父模版拷贝属性
    CodeSmith实用技巧(十一):添加设计器的支持
    CodeSmith实用技巧(六):使用XML 属性
    CodeSmith实用技巧(三):使用FileDialogAttribute
    CodeSmith实用技巧(十二):自动执行SQL脚本
    CodeSmith中实现选择表字段的几点想法
    CodeSmith开发系列资料总结
    你真的了解.NET中的String吗?
  • 原文地址:https://www.cnblogs.com/waterflower/p/11239971.html
Copyright © 2011-2022 走看看