zoukankan      html  css  js  c++  java
  • 树上启发式合并

    树上启发式合并,一种美妙的黑科技,可以用普通的优化让你$n^2$变成严格$n log$,解决一些类似(树上数颜色,树上查众数)这样的问题

    首先你要知道暴力为什么是$n^2$的

    以这个图为例

    每次你从一个节点开始向下搜,你从1节点搜到3,搜完这个子树然后你需要把3存的col等信息删去再遍历另一个子树才是正确的

    那么我们每次遍历这个节点一个子树,每次搜完这棵子树都要清空当前子树储存信息这样(最差)复杂度$n^2$

    我们可以发现清空最后一个遍历的子树是没有意义的,那么我们人为把最后一个子树放到最后不就是最优的吗

    所以,首先我们先找出来重链,轻链,对于轻链我们求出子树答案,再清除子树贡献,.然后求出重链上子树答案,不清除贡献.最后我们再算一遍子树对当前节点贡献即可

    你可能会认为,这不就是一个简单的优化吗,怎么就是$n log$了

    我不知道

    它并没有优化最优复杂度而是避免了最差复杂度

    以给一棵根为1的树,每次询问子树颜色种类数为例

    代码大致如下

    #include<bits/stdc++.h>
    using namespace std;
    #define ll int
    #define r register 
    #define A 1001010
    ll head[A],nxt[A],ver[A],size[A],col[A],cnt[A],ans[A],son[A];
    ll tot=0,num,sum,nowson,n,m,xx,yy;
    inline void add(ll x,ll y){
        nxt[++tot]=head[x],head[x]=tot,ver[tot]=y;
    }
    inline ll read(){
        ll f=1,x=0;char c=getchar();
        while(!isdigit(c)){
            if(c=='-') f=-1;
            c=getchar();
        }
        while(isdigit(c))
            x=(x<<1)+(x<<3)+(c^48),c=getchar();
        return f*x;
    }
    void dfs(ll x,ll fa){
        size[x]=1;
        for(ll i=head[x];i;i=nxt[i]){
            ll y=ver[i];
            if(y==fa) continue;
            dfs(y,x);
            size[x]+=size[y];
            if(size[son[x]]<size[y])
                son[x]=y;
        }
    }
    void cal(ll x,ll fa,ll val){
        if(!cnt[col[x]]) ++sum;
        cnt[col[x]]+=val;
        for(ll i=head[x];i;i=nxt[i]){
            ll y=ver[i];
            if(y==fa||y==nowson) continue;
            cal(y,x,val); 
        }
    }
    void dsu(ll x,ll fa,bool op){
        for(ll i=head[x];i;i=nxt[i]){
            ll y=ver[i];
            if(y==fa||y==son[x])
                continue;
            dsu(y,x,0);
            //从轻儿子出发
        }
        if(son[x])
            dsu(son[x],x,1),nowson=son[x];
        cal(x,fa,1);nowson=0;
        ans[x]=sum;
        if(!op){
            cal(x,fa,-1);
            sum=0;
        }
    }
    int main(){
        n=read();
        for(ll i=1;i<=n-1;i++){
            xx=read(),yy=read();
            add(xx,yy),add(yy,xx);
        }
        for(ll i=1;i<=n;i++)
            col[i]=read();
        dfs(1,0);
        dsu(1,0,1);
        m=read();
        for(ll i=1;i<=m;i++){
            xx=read();
            printf("%d
    ",ans[xx]);
        }
    }

    另一种打法

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #include<cmath>
    using namespace std;
    #define R register
    #define ll long long
    inline ll read(){
        ll aa=0;R int bb=1;char cc=getchar();
        while(cc<'0'||cc>'9')
            {if(cc=='-')bb=-1;cc=getchar();}
        while(cc>='0'&&cc<='9')
            {aa=(aa<<1)+(aa<<3)+(cc^48);cc=getchar();}
        return aa*bb;
    }
    const int N=1e5+3;
    struct edge{
        int v,last;
    }ed[N<<1];
    int first[N],tot;
    inline void add(int x,int y)
    {
        ed[++tot].v=y;
        ed[tot].last=first[x];
        first[x]=tot;
    }
    int n,m,c[N],son[N],cnt[N],ans[N],siz[N];
    void dfsi(int x,int fa)
    {
        siz[x]=1;
        for(R int i=first[x],v;i;i=ed[i].last){
            v=ed[i].v;
            if(v==fa)continue;
            dfsi(v,x);
            siz[x]+=siz[v];
            if(siz[v]>siz[son[x]])son[x]=v;
        }
        return;
    }
    int dfsj(int x,int fa,int bs,int kep)
    {
        if(kep){
            for(R int i=first[x],v;i;i=ed[i].last){
                v=ed[i].v;
                if(v!=fa&&v!=son[x])
                    dfsj(v,x,0,1);
            }
        }
        int res=0;
        if(son[x])res+=dfsj(son[x],x,1,kep);
        for(R int i=first[x],v;i;i=ed[i].last){
            v=ed[i].v;
            if(v!=fa&&v!=son[x])
                res+=dfsj(v,x,0,0);
        }
        if(!cnt[c[x]])res++;
        cnt[c[x]]++;
        if(kep){
            ans[x]=res;
            if(!bs)memset(cnt,0,sizeof(cnt));
        }
        return res;
    }
    int main()
    {
        n=read();
        for(R int i=1,x,y;i<n;++i){
            x=read();y=read();
            add(x,y);add(y,x);
        }
        for(R int i=1;i<=n;++i)c[i]=read();
        dfsi(1,0); dfsj(1,0,1,1);
        m=read();
        for(R int i=1,x;i<=m;++i){
            x=read();
            printf("%d
    ",ans[x]);
        }
        return 0;
    }

    虽然好像没什么区别

    然后再看一道例题

    有一棵 n 个节点的以 1 号节点为根的树,每个节点上有一个小桶,节点u上的小桶可以容纳${k_u}$ 个小球,ljh每次可以给一个节点到根路径上的所有节点的小桶内放一个小球,如果这个节点的小桶满了则不能放进这个节点,最后多次询问某个节点值

    首先暴力不能过

    直接权值线段树+线段树合并很难维护,树链剖分也难以维护,但我们直接树上启发式合并+线段树暴力修改可以维护。

    首先单纯线段树暴力修改可以维护,但这会超时。于是我们用启发式合并作为时间复杂度保证,莫名奇妙AC了这个题

    #include<bits/stdc++.h>
    using namespace std;
    #define ll long long
    #define A 1001010
    ll head[A],nxt[A],ver[A],size[A],son[A],tong[A],col[A],getfa[A],isbigson[A],ans[A],al[A];
    vector<pair<ll,ll> >v[A];
    map<ll,ll>mp;
    ll n,m,tot=0,Q,wwb=0;
    struct tree{
        ll l,r,f,x,t,c;
    }tr[A];
    void add(ll x,ll y){
        nxt[++tot]=head[x],head[x]=tot,ver[tot]=y;
    }
    void prdfs(ll x,ll fa){
        size[x]=v[x].size()+1;
        for(ll i=head[x];i;i=nxt[i]){
            ll y=ver[i];
            if(y==fa) continue;
            prdfs(y,x);
            size[x]+=size[y];
            if(size[son[x]]<size[y])
                isbigson[son[x]]=0,son[x]=y,isbigson[y]=1;
        }
    }
    void built(ll p,ll l,ll r){
        tr[p].l=l,tr[p].r=r;
        if(tr[p].l==tr[p].r){
            return ;
        }
        ll mid=(l+r)>>1;
        built(p<<1,l,mid);
        built(p<<1|1,mid+1,r);
    }
    ll ask(ll p,ll pos){
        if(pos>=tr[p].t) return tr[p].c;
        return (pos>=tr[p<<1].t?tr[p<<1].c+ask(p<<1|1,pos-tr[p<<1].t):ask(p<<1,pos));
    }
    void insert(ll p,ll pos,ll t,ll c){
        if(tr[p].l==tr[p].r)
            {tr[p].t+=t;tr[p].c+=c;return;}
        if(pos<=tr[p<<1].r)
            insert(p<<1,pos,t,c);
        else 
            insert(p<<1|1,pos,t,c);
        tr[p].t=tr[p<<1].t+tr[p<<1|1].t;
        tr[p].c=tr[p<<1].c+tr[p<<1|1].c;
    }
    void up(ll x,ll fa){
        if(v[getfa[x]].size()<v[getfa[fa]].size()){
            for(ll i=0;i<v[getfa[x]].size();i++)
                v[getfa[fa]].push_back(v[getfa[x]][i]);
            v[getfa[x]].clear();
            getfa[x]=getfa[fa];
        }
        else{
            for(ll i=0;i<v[getfa[fa]].size();i++)
                v[getfa[x]].push_back(v[getfa[fa]][i]);
            v[getfa[fa]].clear();
            getfa[fa]=getfa[x];
        }
    }
    void dfs(ll x,ll fa){
    
        for(ll i=head[x];i;i=nxt[i]){
            ll y=ver[i];
            if(y==fa||y==son[x])    continue;
            dfs(y,x);
        }
        if(son[x]) dfs(son[x],x);
        for(ll i=0;i<v[getfa[x]].size();i++){
            ll tim=v[getfa[x]][i].first,col=v[getfa[x]][i].second;
            if(!al[col])    al[col]=tim,insert(1,tim,1,1);
            else if(al[col]>tim){
                insert(1,al[col],0,-1);
                insert(1,tim,1,1);
                al[col]=tim;
            }
            else insert(1,tim,1,0);
        }
    //    printf("t=%lld tong=%lld
    ",tr[1].t,tong[x]);
        ans[x]=ask(1,min(tr[1].t,tong[x]));
        if(son[x])
            up(son[x],x);
        if(!isbigson[x]){
            for(ll i=0;i<v[getfa[x]].size();i++){
                ll tim=v[getfa[x]][i].first,col=v[getfa[x]][i].second;
                if(al[col]==tim)
                    insert(1,tim,-1,-1),al[col]=0;
                else 
                    insert(1,tim,-1,0);
            }
            up(x,fa);
        }    
    /*    for(ll i=1;i<=5;i++){
            printf("ans=%lld ",ans[i]);
        }
    *//*    cout<<endl;*/
    }
    int main(){
        scanf("%lld",&n);
        for(ll i=1;i<n;i++){
            ll xx,yy;
            scanf("%lld%lld",&xx,&yy);
            add(xx,yy),add(yy,xx);
        }
        for(ll i=1;i<=n;i++){
            scanf("%lld",&tong[i]);
            getfa[i]=i;
        }
        prdfs(1,0);
        scanf("%lld",&m);built(1,1,m);
        for(ll i=1,x,c;i<=m;i++){
            scanf("%lld%lld",&x,&c);
            if(!mp[c])
                mp[c]=++wwb;
            //离散化
            v[x].push_back(make_pair(i,mp[c]));
        }
        dfs(1,0);
        scanf("%lld",&Q);
        for(ll i=1,x;i<=Q;i++){
            scanf("%lld",&x);
            printf("%lld
    ",ans[x]);
        }
    }
    我已没有下降的余地
  • 相关阅读:
    python字符串
    Python问题:SyntaxError: Non-ASCII character 'xe2' in file
    windows 运行库与dll文件
    sublime python 配置内容
    sublime ctrl b突然不能用解决方法
    c++ primer 的 textquery 例子。
    虚函数表
    理解各种数据类型和简单类在内存中的存在形式。
    最短路径纯贪心算法。
    中缀表达式生成二叉树
  • 原文地址:https://www.cnblogs.com/znsbc-13/p/11272999.html
Copyright © 2011-2022 走看看