zoukankan      html  css  js  c++  java
  • 树上启发式合并(dsu on tree)

    有丶抽象,学到自闭

    参考的文章:

    zcysky:【学习笔记】dsu on tree

    Arpa:[Tutorial] Sack (dsu on tree)


    先康一康模板题吧:CF 600E($Lomsat$ $gelral$)

    虽然已经用莫队搞过一遍了(可以参考之前写的博客~),但这个还是差距挺大

    我们如果对于每个节点暴力统计答案,是$O(N^2)$的复杂度:最坏情况下整棵树是一条链,对于每个节点的统计平均下来是$O(N)$的

    具体是怎么做的呢?

    对于以当前节点$x$为根的子树,我们建立$cnt$和$sum$两个数组(其实只要$sum$数组就够用啦)

    $cnt[i]$:颜色$i$在子树中出现的次数

    $sum[i]$:在子树中出现次数为$i$的颜色,其颜色的序号之和

    我们还可以建立一个指针$top$,表示出现次数最多的颜色出现了多少次,在改变$cnt$数组的时候可以顺便维护下$top$

    那么,对于这个子树,我们只要跑一边$dfs$,把所有后代全部统计一波,最后的结果就是$ans[x]=sum[top]$

    现在我们希望能够降低对于每个节点统计的复杂度

    $dsu$ $on$ $tree$是$O(Ncdot logN)$的做法,需要用到一些树剖的知识

    在这道题目中,拿到了这颗树的连边,我们先用树剖怼上去

    不用太着急,只要进行第一个$dfs$、得到$son$数组(即每个节点的重儿子)就够了

    接下来的蛇皮操作需要理解一下

    对于以节点$x$为根的子树,我们这样计算其结果$ans[x]$:

    1. 将$x$的儿子分成两种,一种是重儿子,另一种是轻儿子
    2. 我们先按照最上面方法的递归计算所有轻儿子的结果,计算完以后,不对$cnt$、$sum$、$top$进行任何保留(保留与否的操作在下一层递归的第$5$步实现)
    3. 我们再递归计算重儿子的的结果,但是计算完后,保留计算重儿子答案时的$cnt$、$sum$、$top$
    4. 结束递归、回到当前节点$x$这层以后,由于保留了计算重儿子时的统计信息,我们此时对重儿子及其子树的信息是完全清楚的,但是我们依然不清楚轻儿子和它们的子树的信息,所以我们再递归的$dfs$一遍轻儿子,将信息放进计算重儿子的数组;这时,我们得到的$cnt$、$sum$、$top$与暴力统计得到的是一模一样的
    5. 但是节点$x$不一定是其父节点$fa[x]$的重儿子!如果不是,那么$dfs$一遍$x$将所有信息清空;否则就保留(这里就是第$2$、$3$步中的保留/不保留的具体实现的位置)

    我们在第$2$步仅仅是为了计算轻儿子的结果,且不想让其统计信息干扰我们需要的重儿子的统计信息

    绕的地方在于,我们在第$2$步的“不保留”实际上是在计算某个轻儿子后,将以这个轻儿子为根的子树信息全部从$cnt$、$sum$、$top$中抹去(即是在轻儿子的第$5$步实现,并不是在$x$的第$2$步) ←我一开始因为没理解这个自闭了好久

    看上去挺暴力的...分析一下为什么是$O(Ncdot logN)$

    对于每个节点$x$,它仅可能被 以其祖先为根的子树 统计,所以它被统计的次数与其到根节点的路径长度相关

    但是由于我们将重节点的统计信息保留,所以对于每条重链,只会真正意义上$dfs$到$x$一次:在重链的底端;重链上的其余节点可以通过被保留的统计信息了解$x$的情况、不需要$dfs$

    综上,$x$被统计的次数即为其到根节点的路径上链的个数,是$O(logN)$级别的

    用这个思路写出的代码如下:

    #include <cstdio>
    #include <cstring>
    #include <cmath>
    #include <vector>
    using namespace std;
    
    typedef long long ll;
    const int N=100005;
    
    int n;
    int c[N];
    vector<int> v[N];
    
    int fa[N],sz[N],son[N];
    
    inline void dfs(int x,int f)
    {
        fa[x]=f;
        sz[x]=1;
        
        for(int i=0;i<v[x].size();i++)
        {
            int next=v[x][i];
            if(next==fa[x])
                continue;
            
            dfs(next,x);
            sz[x]+=sz[next];
            if(!son[x] || sz[son[x]]<sz[next])
                son[x]=next;
        }
    }
    
    int top,cnt[N];
    ll sum[N],ans[N];
    
    inline void Add(int x,int num)
    {
        sum[cnt[c[x]]]-=(ll)c[x];
        cnt[c[x]]+=num;
        sum[cnt[c[x]]]+=(ll)c[x];
        if(sum[top+1])
            top++;
        if(!sum[top])
            top--;
        
        for(int i=0;i<v[x].size();i++)
        {
            int next=v[x][i];
            if(next==fa[x])
                continue;
            Add(next,num);
        }
    }
    
    inline void Solve(int x,int keep)
    {
        for(int i=0;i<v[x].size();i++)
        {
            int next=v[x][i];
            if(next==fa[x] || next==son[x])
                continue;
            Solve(next,0);
        }
        
        if(son[x])
                Solve(son[x],1);
        
        for(int i=0;i<v[x].size();i++)
        {
            int next=v[x][i];
            if(next==fa[x] || next==son[x])
                continue;
            Add(next,1);
        }
        
        sum[cnt[c[x]]]-=(ll)c[x];
        cnt[c[x]]++;
        sum[cnt[c[x]]]+=(ll)c[x];
        if(sum[top+1])
            top++;
        ans[x]=sum[top];
        
        if(!keep)
            Add(x,-1);
    }
    
    int main()
    {
    //    freopen("input.txt","r",stdin);
        scanf("%d",&n);
        for(int i=1;i<=n;i++)
            scanf("%d",&c[i]);
        for(int i=1;i<n;i++)
        {
            int x,y;
            scanf("%d%d",&x,&y);
            v[x].push_back(y);
            v[y].push_back(x);
        }
        
        dfs(1,0);
        Solve(1,1);
        
        for(int i=1;i<=n;i++)
            printf("%lld ",ans[i]);
        return 0;
    }
    View Code

    再整一道题:CF 1009F($Dominant$ $Indices$)

    这题做完后,感觉这种玩法比莫队还暴力...

    如果想写暴力的话,可以对于每个节点$x$遍历子树,用$cnt$数组统计各深度的节点一共有多少个、并不断更新答案深度$val$

    由于我们是一个个统计节点的,所以正确的$val$一定能够在统计的时候被更新出来:如果$cnt[dep[x]]>cnt[val]$或者$cnt[dep[x]]==cnt[val]$ && $dep[x]<val$,那么$val=dep[x]$

    如果想删除一个节点怎么办?是不是要用set什么的...

    不需要!在树上启发式合并中,唯一的删除操作存在于不保留节点信息时删空子树,所以在得到答案的过程中,我们只会加入节点、不会删除节点,即相当于每次得到的信息跟暴力得到的是完全相同的

    相比而言,莫队还得考虑删除是否为$O(1)$呢

    #include <cstdio>
    #include <cstring>
    #include <cmath>
    #include <vector>
    using namespace std;
    
    const int N=1000005;
    
    int n;
    vector<int> v[N];
    
    int fa[N],sz[N],dep[N],son[N];
    
    inline void dfs(int x,int f)
    {
        fa[x]=f;
        sz[x]=1;
        dep[x]=dep[fa[x]]+1;
        
        for(int i=0;i<v[x].size();i++)
        {
            int next=v[x][i];
            if(next==fa[x])
                continue;
            
            dfs(next,x);
            
            sz[x]+=sz[next];
            if(!son[x] || sz[next]>sz[son[x]])
                son[x]=next;
        }
    }
    
    int ans[N],cnt[N],val;
    
    inline void Add(int x,int num)
    {
        cnt[dep[x]]+=num;
        val=(cnt[dep[x]]>cnt[val] || (cnt[dep[x]]==cnt[val] && dep[x]<val)?dep[x]:val);
        
        for(int i=0;i<v[x].size();i++)
        {
            int next=v[x][i];
            if(next==fa[x])
                continue;
            Add(next,num);
        }
    }
    
    inline void Solve(int x,int keep)
    {
        for(int i=0;i<v[x].size();i++)
        {
            int next=v[x][i];
            if(next==fa[x] || next==son[x])
                continue;
            Solve(next,0);
        }
        
        if(son[x])
            Solve(son[x],1);
        
        for(int i=0;i<v[x].size();i++)
        {
            int next=v[x][i];
            if(next==fa[x] || next==son[x])
                continue;
            Add(next,1);
        }
        
        cnt[dep[x]]+=1;
        val=(cnt[dep[x]]>cnt[val] || (cnt[dep[x]]==cnt[val] && dep[x]<val)?dep[x]:val);
        ans[x]=val;
        
        if(!keep)
            Add(x,-1);
    }
    
    int main()
    {
    //    freopen("input.txt","r",stdin);
        scanf("%d",&n);
        for(int i=1;i<n;i++)
        {
            int x,y;
            scanf("%d%d",&x,&y);
            v[x].push_back(y);
            v[y].push_back(x);
        }
        
        dfs(1,0);
        Solve(1,1);
        
        for(int i=1;i<=n;i++)
            printf("%d
    ",ans[i]-dep[i]);
        return 0;
    }
    View Code

    一道被我做难的题:CF 375D($Tree$ $and$ $Queries$)

    题意跟$Lomsat$ $gelral$有点像,但是求的是出现次数不少于$k$的颜色有多少种

    如果统计出现次数$cnt$,那么我们可以在当前节点$x$合并完子树后,在$cnt$数组上求区间和,也就是把第一题中的$sum$数组上的更新搬到线段树上

    不过有种更妙的方法,用类似栈的思想,只需要开一个数组来统计:

    • 如果颜色$color$被加入,那么出现次数不少于$cnt[color]$的颜色数不变(因为$color$已经在之前被统计过了),次数不少于$cnt[color]+1$的颜色数增加$1$,最后$cnt[color]++$
    • 如果颜色$color$被删去,那么出现次数不少于$cnt[color]$的颜色数减$1$,次数不少于$cnt[color]-1$的颜色数不变,最后$cnt[color]--$

    不过如果最后询问的是出现次数不少于$k$的颜色数量$ imes$颜色序号,应该是没法避免线段树了

    还是应该再多想想的...

    (这题好像莫队也能玩的很开心,没反应过来)

    #include <cstdio>
    #include <cstring>
    #include <cmath>
    #include <vector>
    using namespace std;
    
    typedef pair<int,int> pii;
    const int N=100005;
    
    int SZ=1;
    int t[N<<2];
    
    inline void Modify(int k,int x)
    {
        k=k+SZ;
        t[k]+=x;
        k>>=1;
        
        while(k)
        {
            t[k]=(t[k<<1]+t[k<<1|1]);
            k>>=1;
        }
    }
    
    inline int Query(int k,int p,int a,int b)
    {
        if(b<p)
            return 0;
        if(a>=p)
            return t[k];
        
        int mid=(a+b)>>1;
        return Query(k<<1,p,a,mid)+Query(k<<1|1,p,mid+1,b);
    }
    
    int n,m;
    int c[N];
    vector<int> v[N];
    
    int fa[N],sz[N],son[N];
    
    inline void dfs(int x,int f)
    {
        fa[x]=f;
        sz[x]=1;
        
        for(int i=0;i<v[x].size();i++)
        {
            int next=v[x][i];
            if(next==fa[x])
                continue;
            
            dfs(next,x);
            sz[x]+=sz[next];
            if(!son[x] || sz[next]>sz[son[x]])
                son[x]=next;
        }
    }
    
    vector<pii> q[N];
    int ans[N],cnt[N];
    
    inline void Add(int x,int num)
    {
        Modify(cnt[c[x]],-1);
        cnt[c[x]]+=num;
        Modify(cnt[c[x]],1);
        
        for(int i=0;i<v[x].size();i++)
        {
            int next=v[x][i];
            if(next==fa[x])
                continue;
            Add(next,num);
        }
    }
    
    inline void Solve(int x,int keep)
    {
        for(int i=0;i<v[x].size();i++)
        {
            int next=v[x][i];
            if(next==fa[x] || next==son[x])
                continue;
            Solve(next,0);
        }
        
        if(son[x])
            Solve(son[x],1);
        
        for(int i=0;i<v[x].size();i++)
        {
            int next=v[x][i];
            if(next==fa[x] || next==son[x])
                continue;
            Add(next,1); 
        }
        
        Modify(cnt[c[x]],-1);
        cnt[c[x]]++;
        Modify(cnt[c[x]],1);
        
        for(int i=0;i<q[x].size();i++)
        {
            int rnk=q[x][i].first,id=q[x][i].second;
            ans[id]=Query(1,rnk,0,SZ-1);
        }
        
        if(!keep)
            Add(x,-1);
    }
    
    int main()
    {
    //    freopen("input.txt","r",stdin);
        scanf("%d%d",&n,&m);
        while(SZ<n+1)
            SZ<<=1;
        for(int i=1;i<=n;i++)
            scanf("%d",&c[i]);
        for(int i=1;i<n;i++)
        {
            int x,y;
            scanf("%d%d",&x,&y);
            v[x].push_back(y);
            v[y].push_back(x);
        }
        
        dfs(1,0);
        
        for(int i=1;i<=m;i++)
        {
            int x,y;
            scanf("%d%d",&x,&y);
            q[x].push_back(pii(y,i));
        }
        
        Solve(1,1);
        
        for(int i=1;i<=m;i++)
            printf("%d
    ",ans[i]);
        return 0;
    }
    View Code

    也是一种树上统计的技巧吧

    以后遇到题目再补在这里

    Luogu P1600 (天天爱跑步,$NOIP2016$)

    题解写在树上差分里面了

    CF Gym 259514K  ($Tree$,$2019ICPC$南昌)

    有点类似树上背包的思想,在solve轻儿子时也将信息合并到当前节点上去

    至于其他的就是pbds NB

    #include <cstdio>
    #include <vector>
    #include <cstring>
    #include <algorithm>
    using namespace std;
    
    #include <ext/pb_ds/tree_policy.hpp>
    #include <ext/pb_ds/assoc_container.hpp>
    using namespace __gnu_pbds;
    
    typedef long long ll;
    typedef pair<int,int> pii;
    typedef tree<pii,null_type,less<pii>,rb_tree_tag,tree_order_statistics_node_update> rbtree;
    
    const int INF=1<<30;
    const int N=200005;
    
    int n,k;
    int val[N];
    vector<int> v[N];
    
    int dep[N],sz[N],son[N];
    
    void dfs(int x,int fa)
    {
        dep[x]=dep[fa]+1;
        sz[x]=1;
        
        for(int i=0;i<v[x].size();i++)
        {
            int nxt=v[x][i];
            dfs(nxt,x);
            
            sz[x]+=sz[nxt];
            if(!son[x] || sz[nxt]>sz[son[x]])
                son[x]=nxt;
        }
    }
    
    ll ans;
    rbtree t[N];
    
    void add(int x,int dlt)
    {
        if(dlt)
            t[val[x]].insert(pii(dep[x],x));
        else
            t[val[x]].erase(pii(dep[x],x));
        
        for(int i=0;i<v[x].size();i++)
            add(v[x][i],dlt);
    }
    
    void calc(int x,int lca)
    {
        int rem=val[lca]*2-val[x];
        if(rem>=0)
        {
            int ord=t[rem].order_of_key(pii(2*dep[lca]-dep[x]+k,INF));
            ans+=ord;
        }
        
        for(int i=0;i<v[x].size();i++)
            calc(v[x][i],lca);
    }
    
    void solve(int x,int lca,int keep)
    {
        for(int i=0;i<v[x].size();i++)
        {
            int nxt=v[x][i];
            if(nxt!=son[x])
                solve(nxt,nxt,0);
        }
        
        if(son[x])
            solve(son[x],son[x],1);
        
        for(int i=0;i<v[x].size();i++)
        {
            int nxt=v[x][i];
            if(nxt!=son[x])
            {
                calc(nxt,lca);
                add(nxt,1);
            }
        }
        
        t[val[x]].insert(pii(dep[x],x));
        if(!keep)
            add(x,0);
    }
    
    int main()
    {
        scanf("%d%d",&n,&k);
        for(int i=1;i<=n;i++)
            scanf("%d",&val[i]);
        for(int i=2;i<=n;i++)
        {
            int x;
            scanf("%d",&x);
            v[x].push_back(i);
        }
        
        dfs(1,0);
        solve(1,1,1);
        
        printf("%lld
    ",ans*2);
        return 0;
    }
    View Code

    (完)

  • 相关阅读:
    mysql 中 group_concat()用法
    MySQL行转列与列转行
    mysql中find_in_set()函数的使用(转载)
    多线程中的线程安全关键字
    架构师的特征
    算法复杂度的定义
    1.ArrayList和linkedList区别
    Plsql查询clob类型字段数据
    数据库的特性与隔离级别和spring事务的传播机制和隔离级别
    java中的线程
  • 原文地址:https://www.cnblogs.com/LiuRunky/p/DSU_on_Tree.html
Copyright © 2011-2022 走看看