zoukankan      html  css  js  c++  java
  • 简述树链剖分

    @

    题目链接:luogu P3384 【模板】树链剖分
    先上完整代码,变量名解释[1]

    #include<cstdio>
    #include<algorithm>
    #include<iostream>
    using namespace std;
    typedef long long ll;
    #define N 500005
    #define RI register int
    int tot=0,n,m,rt,md;
    int fa[ N ],deep[ N ],head[ N ],size[ N ],son[ N ],id[ N ],w[ N ],nw[ N ],top[ N ];
    
    struct EDGE{
        int to,next;
    }e[ N ];
    
    inline void add( int from , int to ){
        e[ ++ tot ].to = to;
        e[ tot ].next = head[ from ];
        head[ from ] = tot;
    }
    
    template<class T>
    inline void read(T &res){
        static char ch;T flag = 1;
        while( ( ch = getchar() ) < '0' || ch > '9' ) if( ch == '-' ) flag = -1;
        res = ch - 48;
        while( ( ch = getchar() ) >= '0' && ch <= '9' ) res = res * 10 + ch - 48;
        res *= flag;
    }
    
    struct NODE{
        ll sum,flag;
        NODE *ls,*rs;
        NODE(){
            sum = flag = 0;
            ls = rs = NULL;
        }
        inline void pushdown( int l , int r )
        {
            if( flag )
            {
                int midd = ( l + r ) >> 1;
                ls->flag += flag;
                rs->flag += flag;
                ls->sum += flag * ( midd - l + 1 );
                rs->sum += flag * ( r - midd );
                flag = 0;
            }
        }
        inline void update()
        {
            sum = ls->sum + rs->sum;
        }
    }tree[ N * 2 + 5 ],*p = tree,*root;
    
    NODE *build( int l , int r )
    {
        NODE *nd = ++p;
        if( l == r )
        {
            nd->sum = nw[ l ];
            return nd;
        }
        int mid = ( l + r ) >> 1;
        nd->ls = build( l , mid );
        nd->rs = build( mid + 1 , r );
        nd->update();
        return nd;
    }
    
    ll sum( int l , int r , int x , int y , NODE *nd )
    {
        if( x <= l && r <= y )
        {
            return nd->sum;
        }
        nd->pushdown( l , r );
        int mid = ( l + r ) >> 1;
        ll res = 0;
        if( x <= mid )
          res += sum( l , mid , x , y , nd->ls );
        if( y >= mid + 1 )
          res += sum( mid + 1 , r , x , y , nd->rs );
        return res;
    }
    
    void modify( int l , int r , int x , int y , ll add , NODE *nd )
    {
        if( x <= l && r <= y ) 
        {
            nd->sum += ( r - l + 1 ) * add;
            nd->flag += add;
            return;
        }
        int mid = ( l + r ) >> 1;
        nd->pushdown( l , r );
        if( x <= mid )
            modify( l , mid , x , y , add , nd->ls );
        if( y > mid )
            modify( mid + 1 , r , x , y , add , nd->rs );
        nd->update();
    }
    
    void dfs1( int p ){
        size[ p ] = 1;
        deep[ p ] = deep[ fa[ p ] ] + 1;
        for( int i = head[ p ] ; i ; i = e[ i ].next ){
            int k = e[ i ].to;
            if( k == fa[ p ] )
              continue;
            fa[ k ] = p;
            dfs1( k );
            size[ p ] += size[ k ];
            if( size[ son[ p ] ] < size[ k ] || !son[ p ] )
              son[ p ] = k;
        }
    }
    
    void dfs2( int p , int tp ){ 
        id[ p ] = ++tot;
        nw[ tot ] = w[ p ];
        top[ p ] = tp;
        if( son[ p ] )
          dfs2( son[ p ] , tp );
        for( int i = head[ p ] ; i ; i = e[ i ].next ){
            int k = e[ i ].to;
            if( k == fa[ p ] || k == son[ p ] ) 
              continue;
            dfs2( k , k );
        }
    } 
    
    inline void ope1( int x , int y , ll add ){
        while( top[ x ] != top[ y ] ){
            if( deep[ top[ x ] ] < deep[ top[ y ] ] )
              swap( x , y );
            modify( 1 , n , id[ top[ x ] ] , id[ x ] , add , root );
            x = fa[ top[ x ] ];
        }
        if( deep[ x ] > deep[ y ] )
          swap( x , y );
        modify( 1 , n , id[ x ] , id[ y ] , add , root );
    }
    
    inline ll ope2( int x , int y ){
        ll res = 0;
        while( top[ x ] != top[ y ] ){
            if( deep[ top[ x ] ] < deep[ top[ y ] ] )
              swap( x , y );
            res += sum( 1 , n , id[ top[ x ] ] , id[ x ] , root );
            x = fa[ top[ x ] ];
        }
        if( deep[ x ] > deep[ y ] )
          swap( x , y );
        res += sum( 1 , n , id[ x ] , id[ y ] , root );
        return res;
    }
    
    inline void ope3( int x , int add ){
        modify( 1 , n , id[ x ] , id[ x ] + size[ x ] - 1 , add , root );
    } 
    
    inline ll ope4( int x ){
        return sum( 1 , n , id[ x ] , id[ x ] + size[ x ] - 1 , root );
    }
    
    int main()
    {
        cin>>n>>m>>rt>>md;
        for( RI i = 1 ; i <= n ; i ++ )
          read( w[ i ] ); 
        for( RI i = 1 ; i <= n - 1 ; i ++ ){
            int x,y;
            read( x ),read( y );
            add( x , y );
            add( y , x );
        }
        dfs1( rt ),tot = 0;
        dfs2( rt , rt );
        root = build( 1 , n );
        for( RI i = 1 ; i <= m ; i ++ ){
            int f;
            read( f );
            switch( f ){
                case 1:{
                    int x,y;
                    ll add;
                    read( x ),read( y ),read( add );
                    ope1( x , y , add ); 
                    break;
                }
                case 2:{
                    int x,y;
                    read( x ),read( y );
                    printf( "%lld\n" , ope2( x , y ) % md );
                    break;
                }
                case 3:{ 
                    int x;
                    ll add;
                    read( x ),read( add );
                    ope3( x , add );
                    break;
                }
                case 4:{ 
                    int x;
                    read( x );
                    printf( "%lld\n" , ope4( x ) % md );
                    break;
                }
            }
        }
        return 0;
    }
    

    前置知识

    请先能够熟练写出线段树并了解\(dfs\)序的性质

    预处理

    预处理分两次\(dfs\)
    第一次处理出各个结点的深度,\(size\),重儿子,父亲。
    第二次处理出重链,\(dfs\)序和每个点的\(top\)
    dfs1:

    void dfs1( int p ){
        size[ p ] = 1;
        deep[ p ] = deep[ fa[ p ] ] + 1;
        for( int i = head[ p ] ; i ; i = e[ i ].next ){
            int k = e[ i ].to;
            if( k == fa[ p ] )
              continue;
            fa[ k ] = p;
            dfs1( k );
            size[ p ] += size[ k ];
            if( size[ son[ p ] ] < size[ k ] || !son[ p ] )
              son[ p ] = k;
        }
    }
    

    dfs2:

    void dfs2( int p , int tp ){ 
        id[ p ] = ++tot;//每个点在dfs序里的位置
        nw[ tot ] = w[ p ];
        top[ p ] = tp;
        if( son[ p ] )
          dfs2( son[ p ] , tp );//重链
        for( int i = head[ p ] ; i ; i = e[ i ].next ){
            int k = e[ i ].to;
            if( k == fa[ p ] || k == son[ p ] ) 
              continue;
            dfs2( k , k );//轻链
        }
    } 
    

    维护

    为了更加高效的查询,我们选择用线段树来维护\(dfs\)序(树状数组等数据结构也可)。
    没什么技术含量,直接套模板即可。

    struct NODE{
        ll sum,flag;
        NODE *ls,*rs;
        NODE(){
            sum = flag = 0;
            ls = rs = NULL;
        }
        inline void pushdown( int l , int r ) 
        {
            if( flag )
            {
                int midd = ( l + r ) >> 1;
                ls->flag += flag;
                rs->flag += flag;
                ls->sum += flag * ( midd - l + 1 );
                rs->sum += flag * ( r - midd );
                flag = 0;
            }
        }
        inline void update()
        {
            sum = ls->sum + rs->sum;
        }
    }tree[ N * 2 + 5 ],*p = tree,*root;
    
    NODE *build( int l , int r )
    {
        NODE *nd = ++p;
        if( l == r )
        {
            nd->sum = nw[ l ];
            return nd;
        }
        int mid = ( l + r ) >> 1;
        nd->ls = build( l , mid );
        nd->rs = build( mid + 1 , r );
        nd->update();
        return nd;
    }
    
    ll sum( int l , int r , int x , int y , NODE *nd )
    {
        if( x <= l && r <= y )
        {
            return nd->sum;
        }
        nd->pushdown( l , r );
        int mid = ( l + r ) >> 1;
        ll res = 0;
        if( x <= mid )
          res += sum( l , mid , x , y , nd->ls );
        if( y >= mid + 1 )
          res += sum( mid + 1 , r , x , y , nd->rs );
        return res;
    }
    
    void modify( int l , int r , int x , int y , ll add , NODE *nd )
    {
        if( x <= l && r <= y ) 
        {
            nd->sum += ( r - l + 1 ) * add;
            nd->flag += add;
            return;
        }
        int mid = ( l + r ) >> 1;
        nd->pushdown( l , r );
        if( x <= mid )
            modify( l , mid , x , y , add , nd->ls );
        if( y > mid )
            modify( mid + 1 , r , x , y , add , nd->rs );
        nd->update();
    }
    

    查询

    这是核心操作(敲黑板)。

    子树有关操作

    子树查询

    由于\(dfs\)序的性质,以一个点为根的子树在\(dfs\)序中一定是连续的,所以我们只需要进行一次区间查询,需要查询的区间为:

    [根结点在\(dfs\)序中的位置,根结点在\(dfs\)序中的位置+\(size\) - 1 ]

    复杂度为\(O(logn)\)
    代码如下:

    inline ll ope4( int x ){
        return sum( 1 , n , id[ x ] , id[ x ] + size[ x ] - 1 , root );
    }
    

    子树修改

    同理,进行一次区间修改
    复杂度为\(O(logn)\)
    代码如下:

    inline void ope3( int x , int add ){
        modify( 1 , n , id[ x ] , id[ x ] + size[ x ] - 1 , add , root );
    } 
    

    树链有关操作

    这才是树剖的精髓所在啊!(战术后仰
    这里主要会利用重链在\(dfs\)序中一定是连续的性质,一定要记住,否则你将无法理解接下来的操作

    链查询

    操作流程:

    • 若两个点的top不同,则让top较深的点爬升到它的topfather,每次爬升进行一次区间查询[2],把结果加到res上,直到top相等为止
    • 此时两点的top为原来两点的LCA,且其中深度较浅的点就是LCA,再进行一次区间查询即可。

    最坏时间复杂度\(O(log_{2}n)\)
    代码如下:

    inline ll ope2( int x , int y ){
        ll res = 0;
        while( top[ x ] != top[ y ] ){
            if( deep[ top[ x ] ] < deep[ top[ y ] ] )//把x调整为top深度更深的的点
              swap( x , y );
            res += sum( 1 , n , id[ top[ x ] ] , id[ x ] , root );
            x = fa[ top[ x ] ];
        }
        if( deep[ x ] > deep[ y ] )
          swap( x , y );
        res += sum( 1 , n , id[ x ] , id[ y ] , root );
        return res;
    }
    

    链修改

    同理,爬升过程一模一样,只需要将链查询的区间查询改为区间修改即可。

    最坏时间复杂度O(log2n)

    代码如下:

    inline void ope1( int x , int y , ll add ){
        while( top[ x ] != top[ y ] ){
            if( deep[ top[ x ] ] < deep[ top[ y ] ] )
              swap( x , y );
            modify( 1 , n , id[ top[ x ] ] , id[ x ] , add , root );
            x = fa[ top[ x ] ];
        }
        if( deep[ x ] > deep[ y ] )
          swap( x , y );
        modify( 1 , n , id[ x ] , id[ y ] , add , root );
    }
    

    1. 变量名解释
      fa:每个结点的父结点
      deep:每个结点所在位置在树中的深度
      size:以每个结点为根的子树的大小
      son:每个结点的重儿子(即所有儿子中size最大的那个)
      id:每个结点在dfs序中的位置
      w:每个结点的权值
      nw:dfs序中,每个结点的权值
      top:每个点所在的重链的顶端 ↩︎

    2. 链查询和修改的区间为:[id[ top[ x ] ] , id[ x ]],即是这条重链。 ↩︎

  • 相关阅读:
    unomi漏洞复现
    xxl-job漏洞复现
    cgi漏洞复现
    celery漏洞复现
    bash漏洞复现
    学习ASP.NET的一些学习资源
    用EF DataBase First做一个简单的MVC3报名页面
    怎样在Word中插入代码并保持代码原始样式不变
    安装notepad++之后怎样在鼠标右键上加上Edit with notepad++
    安装Visual Studio 2010之后怎样安装MSDN Library
  • 原文地址:https://www.cnblogs.com/hzyhome/p/11658215.html
Copyright © 2011-2022 走看看