zoukankan      html  css  js  c++  java
  • 树上启发式合并(dsu on tree)

    有丶抽象,学到自闭

    参考的文章:

    zcysky:【学习笔记】dsu on tree

    Arpa:[Tutorial] Sack (dsu on tree)


    先康一康模板题吧:CF 600E($Lomsat$ $gelral$)

    虽然已经用莫队搞过一遍了(可以参考之前写的博客~),但这个还是差距挺大

    我们如果对于每个节点暴力统计答案,是$O(N^2)$的复杂度:最坏情况下整棵树是一条链,对于每个节点的统计平均下来是$O(N)$的

    具体是怎么做的呢?

    对于以当前节点$x$为根的子树,我们建立$cnt$和$sum$两个数组(其实只要$sum$数组就够用啦)

    $cnt[i]$:颜色$i$在子树中出现的次数

    $sum[i]$:在子树中出现次数为$i$的颜色,其颜色的序号之和

    我们还可以建立一个指针$top$,表示出现次数最多的颜色出现了多少次,在改变$cnt$数组的时候可以顺便维护下$top$

    那么,对于这个子树,我们只要跑一边$dfs$,把所有后代全部统计一波,最后的结果就是$ans[x]=sum[top]$

    现在我们希望能够降低对于每个节点统计的复杂度

    $dsu$ $on$ $tree$是$O(Ncdot logN)$的做法,需要用到一些树剖的知识

    在这道题目中,拿到了这颗树的连边,我们先用树剖怼上去

    不用太着急,只要进行第一个$dfs$、得到$son$数组(即每个节点的重儿子)就够了

    接下来的蛇皮操作需要理解一下

    对于以节点$x$为根的子树,我们这样计算其结果$ans[x]$:

    1. 将$x$的儿子分成两种,一种是重儿子,另一种是轻儿子
    2. 我们先按照最上面方法的递归计算所有轻儿子的结果,计算完以后,不对$cnt$、$sum$、$top$进行任何保留(保留与否的操作在下一层递归的第$5$步实现)
    3. 我们再递归计算重儿子的的结果,但是计算完后,保留计算重儿子答案时的$cnt$、$sum$、$top$
    4. 结束递归、回到当前节点$x$这层以后,由于保留了计算重儿子时的统计信息,我们此时对重儿子及其子树的信息是完全清楚的,但是我们依然不清楚轻儿子和它们的子树的信息,所以我们再递归的$dfs$一遍轻儿子,将信息放进计算重儿子的数组;这时,我们得到的$cnt$、$sum$、$top$与暴力统计得到的是一模一样的
    5. 但是节点$x$不一定是其父节点$fa[x]$的重儿子!如果不是,那么$dfs$一遍$x$将所有信息清空;否则就保留(这里就是第$2$、$3$步中的保留/不保留的具体实现的位置)

    我们在第$2$步仅仅是为了计算轻儿子的结果,且不想让其统计信息干扰我们需要的重儿子的统计信息

    绕的地方在于,我们在第$2$步的“不保留”实际上是在计算某个轻儿子后,将以这个轻儿子为根的子树信息全部从$cnt$、$sum$、$top$中抹去(即是在轻儿子的第$5$步实现,并不是在$x$的第$2$步) ←我一开始因为没理解这个自闭了好久

    看上去挺暴力的...分析一下为什么是$O(Ncdot logN)$

    对于每个节点$x$,它仅可能被 以其祖先为根的子树 统计,所以它被统计的次数与其到根节点的路径长度相关

    但是由于我们将重节点的统计信息保留,所以对于每条重链,只会真正意义上$dfs$到$x$一次:在重链的底端;重链上的其余节点可以通过被保留的统计信息了解$x$的情况、不需要$dfs$

    综上,$x$被统计的次数即为其到根节点的路径上链的个数,是$O(logN)$级别的

    用这个思路写出的代码如下:

    #include <cstdio>
    #include <cstring>
    #include <cmath>
    #include <vector>
    using namespace std;
    
    typedef long long ll;
    const int N=100005;
    
    int n;
    int c[N];
    vector<int> v[N];
    
    int fa[N],sz[N],son[N];
    
    inline void dfs(int x,int f)
    {
        fa[x]=f;
        sz[x]=1;
        
        for(int i=0;i<v[x].size();i++)
        {
            int next=v[x][i];
            if(next==fa[x])
                continue;
            
            dfs(next,x);
            sz[x]+=sz[next];
            if(!son[x] || sz[son[x]]<sz[next])
                son[x]=next;
        }
    }
    
    int top,cnt[N];
    ll sum[N],ans[N];
    
    inline void Add(int x,int num)
    {
        sum[cnt[c[x]]]-=(ll)c[x];
        cnt[c[x]]+=num;
        sum[cnt[c[x]]]+=(ll)c[x];
        if(sum[top+1])
            top++;
        if(!sum[top])
            top--;
        
        for(int i=0;i<v[x].size();i++)
        {
            int next=v[x][i];
            if(next==fa[x])
                continue;
            Add(next,num);
        }
    }
    
    inline void Solve(int x,int keep)
    {
        for(int i=0;i<v[x].size();i++)
        {
            int next=v[x][i];
            if(next==fa[x] || next==son[x])
                continue;
            Solve(next,0);
        }
        
        if(son[x])
                Solve(son[x],1);
        
        for(int i=0;i<v[x].size();i++)
        {
            int next=v[x][i];
            if(next==fa[x] || next==son[x])
                continue;
            Add(next,1);
        }
        
        sum[cnt[c[x]]]-=(ll)c[x];
        cnt[c[x]]++;
        sum[cnt[c[x]]]+=(ll)c[x];
        if(sum[top+1])
            top++;
        ans[x]=sum[top];
        
        if(!keep)
            Add(x,-1);
    }
    
    int main()
    {
    //    freopen("input.txt","r",stdin);
        scanf("%d",&n);
        for(int i=1;i<=n;i++)
            scanf("%d",&c[i]);
        for(int i=1;i<n;i++)
        {
            int x,y;
            scanf("%d%d",&x,&y);
            v[x].push_back(y);
            v[y].push_back(x);
        }
        
        dfs(1,0);
        Solve(1,1);
        
        for(int i=1;i<=n;i++)
            printf("%lld ",ans[i]);
        return 0;
    }
    View Code

    再整一道题:CF 1009F($Dominant$ $Indices$)

    这题做完后,感觉这种玩法比莫队还暴力...

    如果想写暴力的话,可以对于每个节点$x$遍历子树,用$cnt$数组统计各深度的节点一共有多少个、并不断更新答案深度$val$

    由于我们是一个个统计节点的,所以正确的$val$一定能够在统计的时候被更新出来:如果$cnt[dep[x]]>cnt[val]$或者$cnt[dep[x]]==cnt[val]$ && $dep[x]<val$,那么$val=dep[x]$

    如果想删除一个节点怎么办?是不是要用set什么的...

    不需要!在树上启发式合并中,唯一的删除操作存在于不保留节点信息时删空子树,所以在得到答案的过程中,我们只会加入节点、不会删除节点,即相当于每次得到的信息跟暴力得到的是完全相同的

    相比而言,莫队还得考虑删除是否为$O(1)$呢

    #include <cstdio>
    #include <cstring>
    #include <cmath>
    #include <vector>
    using namespace std;
    
    const int N=1000005;
    
    int n;
    vector<int> v[N];
    
    int fa[N],sz[N],dep[N],son[N];
    
    inline void dfs(int x,int f)
    {
        fa[x]=f;
        sz[x]=1;
        dep[x]=dep[fa[x]]+1;
        
        for(int i=0;i<v[x].size();i++)
        {
            int next=v[x][i];
            if(next==fa[x])
                continue;
            
            dfs(next,x);
            
            sz[x]+=sz[next];
            if(!son[x] || sz[next]>sz[son[x]])
                son[x]=next;
        }
    }
    
    int ans[N],cnt[N],val;
    
    inline void Add(int x,int num)
    {
        cnt[dep[x]]+=num;
        val=(cnt[dep[x]]>cnt[val] || (cnt[dep[x]]==cnt[val] && dep[x]<val)?dep[x]:val);
        
        for(int i=0;i<v[x].size();i++)
        {
            int next=v[x][i];
            if(next==fa[x])
                continue;
            Add(next,num);
        }
    }
    
    inline void Solve(int x,int keep)
    {
        for(int i=0;i<v[x].size();i++)
        {
            int next=v[x][i];
            if(next==fa[x] || next==son[x])
                continue;
            Solve(next,0);
        }
        
        if(son[x])
            Solve(son[x],1);
        
        for(int i=0;i<v[x].size();i++)
        {
            int next=v[x][i];
            if(next==fa[x] || next==son[x])
                continue;
            Add(next,1);
        }
        
        cnt[dep[x]]+=1;
        val=(cnt[dep[x]]>cnt[val] || (cnt[dep[x]]==cnt[val] && dep[x]<val)?dep[x]:val);
        ans[x]=val;
        
        if(!keep)
            Add(x,-1);
    }
    
    int main()
    {
    //    freopen("input.txt","r",stdin);
        scanf("%d",&n);
        for(int i=1;i<n;i++)
        {
            int x,y;
            scanf("%d%d",&x,&y);
            v[x].push_back(y);
            v[y].push_back(x);
        }
        
        dfs(1,0);
        Solve(1,1);
        
        for(int i=1;i<=n;i++)
            printf("%d
    ",ans[i]-dep[i]);
        return 0;
    }
    View Code

    一道被我做难的题:CF 375D($Tree$ $and$ $Queries$)

    题意跟$Lomsat$ $gelral$有点像,但是求的是出现次数不少于$k$的颜色有多少种

    如果统计出现次数$cnt$,那么我们可以在当前节点$x$合并完子树后,在$cnt$数组上求区间和,也就是把第一题中的$sum$数组上的更新搬到线段树上

    不过有种更妙的方法,用类似栈的思想,只需要开一个数组来统计:

    • 如果颜色$color$被加入,那么出现次数不少于$cnt[color]$的颜色数不变(因为$color$已经在之前被统计过了),次数不少于$cnt[color]+1$的颜色数增加$1$,最后$cnt[color]++$
    • 如果颜色$color$被删去,那么出现次数不少于$cnt[color]$的颜色数减$1$,次数不少于$cnt[color]-1$的颜色数不变,最后$cnt[color]--$

    不过如果最后询问的是出现次数不少于$k$的颜色数量$ imes$颜色序号,应该是没法避免线段树了

    还是应该再多想想的...

    (这题好像莫队也能玩的很开心,没反应过来)

    #include <cstdio>
    #include <cstring>
    #include <cmath>
    #include <vector>
    using namespace std;
    
    typedef pair<int,int> pii;
    const int N=100005;
    
    int SZ=1;
    int t[N<<2];
    
    inline void Modify(int k,int x)
    {
        k=k+SZ;
        t[k]+=x;
        k>>=1;
        
        while(k)
        {
            t[k]=(t[k<<1]+t[k<<1|1]);
            k>>=1;
        }
    }
    
    inline int Query(int k,int p,int a,int b)
    {
        if(b<p)
            return 0;
        if(a>=p)
            return t[k];
        
        int mid=(a+b)>>1;
        return Query(k<<1,p,a,mid)+Query(k<<1|1,p,mid+1,b);
    }
    
    int n,m;
    int c[N];
    vector<int> v[N];
    
    int fa[N],sz[N],son[N];
    
    inline void dfs(int x,int f)
    {
        fa[x]=f;
        sz[x]=1;
        
        for(int i=0;i<v[x].size();i++)
        {
            int next=v[x][i];
            if(next==fa[x])
                continue;
            
            dfs(next,x);
            sz[x]+=sz[next];
            if(!son[x] || sz[next]>sz[son[x]])
                son[x]=next;
        }
    }
    
    vector<pii> q[N];
    int ans[N],cnt[N];
    
    inline void Add(int x,int num)
    {
        Modify(cnt[c[x]],-1);
        cnt[c[x]]+=num;
        Modify(cnt[c[x]],1);
        
        for(int i=0;i<v[x].size();i++)
        {
            int next=v[x][i];
            if(next==fa[x])
                continue;
            Add(next,num);
        }
    }
    
    inline void Solve(int x,int keep)
    {
        for(int i=0;i<v[x].size();i++)
        {
            int next=v[x][i];
            if(next==fa[x] || next==son[x])
                continue;
            Solve(next,0);
        }
        
        if(son[x])
            Solve(son[x],1);
        
        for(int i=0;i<v[x].size();i++)
        {
            int next=v[x][i];
            if(next==fa[x] || next==son[x])
                continue;
            Add(next,1); 
        }
        
        Modify(cnt[c[x]],-1);
        cnt[c[x]]++;
        Modify(cnt[c[x]],1);
        
        for(int i=0;i<q[x].size();i++)
        {
            int rnk=q[x][i].first,id=q[x][i].second;
            ans[id]=Query(1,rnk,0,SZ-1);
        }
        
        if(!keep)
            Add(x,-1);
    }
    
    int main()
    {
    //    freopen("input.txt","r",stdin);
        scanf("%d%d",&n,&m);
        while(SZ<n+1)
            SZ<<=1;
        for(int i=1;i<=n;i++)
            scanf("%d",&c[i]);
        for(int i=1;i<n;i++)
        {
            int x,y;
            scanf("%d%d",&x,&y);
            v[x].push_back(y);
            v[y].push_back(x);
        }
        
        dfs(1,0);
        
        for(int i=1;i<=m;i++)
        {
            int x,y;
            scanf("%d%d",&x,&y);
            q[x].push_back(pii(y,i));
        }
        
        Solve(1,1);
        
        for(int i=1;i<=m;i++)
            printf("%d
    ",ans[i]);
        return 0;
    }
    View Code

    也是一种树上统计的技巧吧

    以后遇到题目再补在这里

    Luogu P1600 (天天爱跑步,$NOIP2016$)

    题解写在树上差分里面了

    CF Gym 259514K  ($Tree$,$2019ICPC$南昌)

    有点类似树上背包的思想,在solve轻儿子时也将信息合并到当前节点上去

    至于其他的就是pbds NB

    #include <cstdio>
    #include <vector>
    #include <cstring>
    #include <algorithm>
    using namespace std;
    
    #include <ext/pb_ds/tree_policy.hpp>
    #include <ext/pb_ds/assoc_container.hpp>
    using namespace __gnu_pbds;
    
    typedef long long ll;
    typedef pair<int,int> pii;
    typedef tree<pii,null_type,less<pii>,rb_tree_tag,tree_order_statistics_node_update> rbtree;
    
    const int INF=1<<30;
    const int N=200005;
    
    int n,k;
    int val[N];
    vector<int> v[N];
    
    int dep[N],sz[N],son[N];
    
    void dfs(int x,int fa)
    {
        dep[x]=dep[fa]+1;
        sz[x]=1;
        
        for(int i=0;i<v[x].size();i++)
        {
            int nxt=v[x][i];
            dfs(nxt,x);
            
            sz[x]+=sz[nxt];
            if(!son[x] || sz[nxt]>sz[son[x]])
                son[x]=nxt;
        }
    }
    
    ll ans;
    rbtree t[N];
    
    void add(int x,int dlt)
    {
        if(dlt)
            t[val[x]].insert(pii(dep[x],x));
        else
            t[val[x]].erase(pii(dep[x],x));
        
        for(int i=0;i<v[x].size();i++)
            add(v[x][i],dlt);
    }
    
    void calc(int x,int lca)
    {
        int rem=val[lca]*2-val[x];
        if(rem>=0)
        {
            int ord=t[rem].order_of_key(pii(2*dep[lca]-dep[x]+k,INF));
            ans+=ord;
        }
        
        for(int i=0;i<v[x].size();i++)
            calc(v[x][i],lca);
    }
    
    void solve(int x,int lca,int keep)
    {
        for(int i=0;i<v[x].size();i++)
        {
            int nxt=v[x][i];
            if(nxt!=son[x])
                solve(nxt,nxt,0);
        }
        
        if(son[x])
            solve(son[x],son[x],1);
        
        for(int i=0;i<v[x].size();i++)
        {
            int nxt=v[x][i];
            if(nxt!=son[x])
            {
                calc(nxt,lca);
                add(nxt,1);
            }
        }
        
        t[val[x]].insert(pii(dep[x],x));
        if(!keep)
            add(x,0);
    }
    
    int main()
    {
        scanf("%d%d",&n,&k);
        for(int i=1;i<=n;i++)
            scanf("%d",&val[i]);
        for(int i=2;i<=n;i++)
        {
            int x;
            scanf("%d",&x);
            v[x].push_back(i);
        }
        
        dfs(1,0);
        solve(1,1,1);
        
        printf("%lld
    ",ans*2);
        return 0;
    }
    View Code

    (完)

  • 相关阅读:
    工作中遇到的java 内存溢出,问题排查
    java线上内存溢出问题排查步骤
    性能测试-java内存溢出问题排查
    164 01 Android 零基础入门 03 Java常用工具类01 Java异常 04 使用try…catch…finally实现异常处理 04 终止finally执行的方法
    163 01 Android 零基础入门 03 Java常用工具类01 Java异常 04 使用try…catch…finally实现异常处理 03 使用多重catch结构处理异常
    162 01 Android 零基础入门 03 Java常用工具类01 Java异常 04 使用try…catch…finally实现异常处理 02 使用try-catch结构处理异常
    161 01 Android 零基础入门 03 Java常用工具类01 Java异常 04 使用try…catch…finally实现异常处理 01 try-catch-finally简介
    160 01 Android 零基础入门 03 Java常用工具类01 Java异常 03 异常处理简介 01 异常处理分类
    159 01 Android 零基础入门 03 Java常用工具类01 Java异常 02 异常概述 02 异常分类
    158 01 Android 零基础入门 03 Java常用工具类01 Java异常 02 异常概述 01 什么是异常?
  • 原文地址:https://www.cnblogs.com/LiuRunky/p/DSU_on_Tree.html
Copyright © 2011-2022 走看看