zoukankan      html  css  js  c++  java
  • 树链剖分学习

    树链剖分,顾名思义,就是将一棵树上的节点按照一个特殊的方式重新编号,这样我们就可以利用一些数据结构去优化加速一些树上的操作;

    现在要介绍的是重链剖分;

    首先明确一些概念:

    重儿子:父亲节点的所有儿子中子树结点数目最多(size最大)的结点;

    轻儿子:父亲节点中除了重儿子以外的儿子;

    重边:父亲结点和重儿子连成的边;

    轻边:父亲节点和轻儿子连成的边;

    重链:由多条重边连接而成的路径;

    轻链:由多条轻边连接而成的路径;

    有了这些概念,我们就可以愉快地剖分了;

    具体操作就是先用两个 dfs作出以下变量:

    名称 解释
    fa[u] 保存结点u的父亲节点
    dep[u] 保存结点u的深度值
    size[u] 保存以u为根的子树节点个数
    son[u] 保存重儿子
    rk[u] 保存当前dfs标号在树中所对应的节点
    top[u] 保存当前节点所在链的顶端节点
    dfn[u] 保存树中每个节点剖分以后的新编号(DFS的执行顺序)

    然后在写一棵线段树(某数据结构),将树上节点以 dfs序映射到线段上,然后就可以优化树上操作了,这就是树链剖分;

    附上代码:

    #include<cstdio>
    #include<cstring>
    #include<iostream>
    #include<algorithm>
    using namespace std;
    const int N = 1e5+10;
    
    int n,m,r,mod;
    int val[N];
    int dfn[N],top[N],rk[N];
    int dep[N],fa[N],size[N],son[N];
    struct node{
        int l,r,ls,rs,sum,lazy;
    }a[N<<4];
    struct edge{
        int next,to;
    }e[N<<4];
    int head[N],cnt;
    
    void add(int u,int v){
        e[++cnt]=(edge){head[u],v};
        head[u]=cnt;
    }
    
    void dfs1(int x){
        size[x]=1;
        dep[x]=dep[fa[x]]+1;
        for(int v,i=head[x];i;i=e[i].next){
            v=e[i].to;
            if(dep[v]) continue;
            fa[v]=x;
            dfs1(v);
            size[x]+=size[v];
            if(size[v]>size[son[x]]) son[x]=v;
        }
    }
    
    void dfs2(int x,int t){
        top[x]=t;
        dfn[x]=++cnt;
        rk[cnt]=x;
        if(son[x]) dfs2(son[x],t);
        for(int i=head[x],v;i;i=e[i].next){
            v=e[i].to;
            if(v==fa[x]) continue;
            if(son[x]!=v) dfs2(v,v);
        }
    }
    
    void pushup(int o){a[o].sum=(a[a[o].ls].sum+a[a[o].rs].sum)%mod;}
    
    void build(int o,int l,int r){
        if(l==r){
            a[o].sum=val[rk[l]];
            a[o].l=a[o].r=l;
            return ;
        }
        int mid=(l+r)>>1;
        a[o].ls=++cnt,a[o].rs=++cnt;
        build(a[o].ls,l,mid);
        build(a[o].rs,mid+1,r);
        a[o].l=a[a[o].ls].l;
        a[o].r=a[a[o].rs].r;
        pushup(o);
    }
    
    void pushdown(int o){
        if(a[o].lazy){
            int ls=a[o].ls,rs=a[o].rs;
            a[ls].lazy=(a[ls].lazy+a[o].lazy)%mod;
            a[rs].lazy=(a[rs].lazy+a[o].lazy)%mod;
            a[ls].sum=(a[ls].sum+(a[ls].r-a[ls].l+1)*a[o].lazy)%mod;
            a[rs].sum=(a[rs].sum+(a[rs].r-a[rs].l+1)*a[o].lazy)%mod;
            a[o].lazy=0;
        }
    }
    
    void updata(int o,int x,int y,int d){
        if(a[o].l>=x&&a[o].r<=y){
            a[o].lazy+=d;
            a[o].sum=(a[o].sum+(a[o].r-a[o].l+1)*d)%mod;
            return ;
        }
        pushdown(o);
        int mid=(a[o].l+a[o].r)>>1;
        if(x<=mid) updata(a[o].ls,x,y,d);
        if(y>mid) updata(a[o].rs,x,y,d);
        pushup(o);
    }
    
    int query(int o,int x,int y){
        if(a[o].l>=x&&a[o].r<=y) return a[o].sum;
        pushdown(o);
        int mid=(a[o].l+a[o].r)>>1;
        int rel=0;
        if(x<=mid) rel=(rel+query(a[o].ls,x,y))%mod;
        if(y>mid) rel=(rel+query(a[o].rs,x,y))%mod;
        return rel;
    }
    
    int getsum(int x,int y){
        int rel=0;
        while(top[x]!=top[y]){
            if(dep[top[x]]<dep[top[y]]) swap(x,y);
            rel=(rel+query(1,dfn[top[x]],dfn[x]))%mod;
            x=fa[top[x]];
        }
        if(dep[x]>dep[y]) swap(x,y);
        rel=(rel+query(1,dfn[x],dfn[y]))%mod;
        return rel;
    }
    
    int change(int x,int y,int d){
        while(top[x]!=top[y]){
            if(dep[top[x]]<dep[top[y]]) swap(x,y);
            updata(1,dfn[top[x]],dfn[x],d);
            x=fa[top[x]];
        }
        if(dep[x]>dep[y]) swap(x,y);
        updata(1,dfn[x],dfn[y],d);
    }
    
    int main()
    {
        scanf("%d%d%d%d",&n,&m,&r,&mod);
        for(int i=1;i<=n;++i) scanf("%d",&val[i]);
        for(int i=1,x,y;i<n;++i){
            scanf("%d%d",&x,&y);
            add(x,y);
            add(y,x);
        }
        cnt=0,dfs1(r),dfs2(r,r);
        build(1,1,n);
        for(int i=1,op,x,y,z;i<=m;++i){
            scanf("%d",&op);
            if(op==1){
                scanf("%d%d%d",&x,&y,&z);
                change(x,y,z);
            }
            if(op==2){
                scanf("%d%d",&x,&y);
                printf("%d
    ",getsum(x,y));
            }
            if(op==3){
                scanf("%d%d",&x,&z);
                updata(1,dfn[x],dfn[x]+size[x]-1,z);
            }
            if(op==4){
                scanf("%d",&x);
                printf("%d
    ",query(1,dfn[x],dfn[x]+size[x]-1));
            }
        }
        return 0;
    }
  • 相关阅读:
    Java中编写代码出现异常,如何抛出异常,如何捕获异常
    用Java制作斗地主
    Java—Map接口中的常用方法
    Java—增强for循环与for循环的区别/泛型通配符/LinkedList集合
    Java—包装类/System类/Math类/Arrays类/大数据运算/Collection接口/Iterator迭代器
    Java—时间的原点 计算时间所使用的 Date类/DateFormat类/Calendar类
    Java—匿名对象/内部类/访问修饰符/代码块
    Java—构造方法及this/super/final/static关键字
    Java—接口
    URL Protocol打开应用程序并传递程序启动参数(Windows、Mac)
  • 原文地址:https://www.cnblogs.com/nnezgy/p/11578685.html
Copyright © 2011-2022 走看看