zoukankan      html  css  js  c++  java
  • 虚树入门

    虚树,顾名思义,就是假的树.

    在树形dp中有很大的优化作用.

    虚树主要针对于树中关键点的询问.我们仅仅对关键点及其lca建一棵树.这样只要保证sigmak在时间复杂度内即可.

    以下是建树的模板

    q=read();
    for(int i=1;i<=q;++i)
    {
        num=read();
        for(int j=1;j<=num;++j) b[j]=read(),vis[b[j]]=true;//标记关键点. 
        sort(b+1,b+num+1,cmp);//按照dfn排序 
        stak[top=1]=b[1];//强行加入第一个点. 
        for(int j=2;j<=num;++j)
        {
            int now=b[j];
            int lc=lca(now,stak[top]);
            while(1)
            {
                if(deep[lc]>=deep[stak[top-1]])//如果lca为top,或top-1,或在两者之间 
                {
                    if(lc!=stak[top])//不等于top 
                    {
                        add2(lc,stak[top]);//先连边 
                        if(lc!=stak[top-1]) stak[top]=lc;//如果在两者之间,去掉top,加入lca 
                        else --top;//否则为top-1,直接去掉top即可. 
                    }
                    break;
                }
                else {add2(stak[top-1],stak[top]);top--;}//lca在top-1之上,top-1向top连边,去掉top1 
            }
            stak[++top]=now;//最后把now加入栈中. 
        }
        while(--top) add2(stak[top],stak[top+1]);//最后将最右链加入加入虚树 
        dfs(stak[1]);//从最上面的点开始dfs 

    这里用栈维护了虚树的最右链,dfs中记得将虚树的信息清空即可.

    我觉得最难得不是虚树的建立,毕竟这就是一个模板,而是建立虚树后的dp转移...头大...

    [SDOI2011]消耗战

    这个题要求所有的关键点都不能到达1号点的最小代价.

    看到sigma(ki)<=500000,就知道要用到虚树(要养成好习惯).

    我们先考虑从普通的dp入手,再探索虚树上应该如何dp.

    我们设f[i]表示以i为根的子树内的关键点都不与1联通的最小代价.

    考虑当前x的状态如何转移.

    首先如果x是关键点,那f[x]只能等于v(fa[x],x).也就是必须切断x的父亲与x的联系。这样x及其子树都不可能与1联通.

    倘若x不是关键点,那f[x]=min(sum[x],v(fa[x],x)).sum[x]=sigmaf[y].(y=x.son)

    好了,这样普通的dp就只能达到这种地步了.

    如果我们把这种dp放到虚树上会是什么样呢?由于我们将许多没用的点都抽离出去了,所以如果一个点是关键带你的话,我们无法做到查询

    v(fa[x],x)的值.那我们思考当想要将x的关键点拦截的话,付出他的最小代价究竟是什么,是点x到1的最小的边权.

    那我们在之前的dfs中预处理出来这个东西.之后按照上面的转移即可.

    #include<bits/stdc++.h>
    #define ll long long
    #define min(a,b) a<b?a:b
    using namespace std;
    const int N=500500;
    int link1[N],tot1,link2[N],tot2,n,deep[N],f[N][25],q;
    int b[N],num,dfn[N],stak[N],top;
    ll minv[N];
    bool vis[N];
    struct edge{int y,next;ll v;}a1[N<<1],a2[N<<1]; 
    inline int read()
    {
        int x=0,ff=1;
        char ch=getchar();
        while(!isdigit(ch)) {if(ch=='-') ff=-1;ch=getchar();}
        while(isdigit(ch)) {x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
        return x*ff;
    }
    inline void add1(int x,int y,int v)
    {
        a1[++tot1].y=y;
        a1[tot1].v=v;
        a1[tot1].next=link1[x];
        link1[x]=tot1;
    }
    inline void add2(int x,int y)
    {
        a2[++tot2].y=y;
        a2[tot2].next=link2[x];
        link2[x]=tot2;
    }
    inline void dfs1(int x,int fa)
    {
        dfn[x]=++num;
        for(int i=link1[x];i;i=a1[i].next)
        {
            int y=a1[i].y;
            if(y==fa) continue;
            deep[y]=deep[x]+1;
            f[y][0]=x;
            for(int j=1;j<=20;++j) f[y][j]=f[f[y][j-1]][j-1];
            minv[y]=min(minv[x],a1[i].v);
            dfs1(y,x);
        }
    }
    inline int lca(int a,int b)
    {
        if(deep[a]>deep[b]) swap(a,b);
        for(int i=20;i>=0;--i) 
            if(deep[f[b][i]]>=deep[a]) b=f[b][i];
        if(a==b) return a;
        for(int i=20;i>=0;--i)
            if(f[a][i]!=f[b][i]) a=f[a][i],b=f[b][i];
        return f[a][0];        
    }
    inline bool cmp(int a,int b) {return dfn[a]<dfn[b];}
    inline ll dfs2(int x)
    {
        ll sum=0,dp;
        for(int i=link2[x];i;i=a2[i].next)
        {
            int y=a2[i].y;
            sum+=dfs2(y);
        }
        if(vis[x]) dp=minv[x];
        else dp=min(minv[x],sum);
        if(vis[x]) vis[x]=false;
        link2[x]=0;
        return dp;
    }
    int main()
    {
    //    freopen("1.in","r",stdin);
        n=read();
        for(int i=1;i<n;++i)
        {
            int x=read(),y=read(),v=read();
            add1(x,y,v);add1(y,x,v);
        }
        minv[1]=1e18;
        dfs1(1,0);q=read();
        while(q--)
        {
            num=read();
            for(int i=1;i<=num;++i)
            {
                b[i]=read();
                vis[b[i]]=true;
            }
            sort(b+1,b+num+1,cmp);
            stak[top=1]=b[1];
            for(int i=2;i<=num;++i)
            {
                int now=b[i];
                int lc=lca(now,stak[top]);
                while(1)
                {
                    if(deep[lc]>=deep[stak[top-1]])
                    {
                        if(lc!=stak[top]) 
                        {
                            add2(lc,stak[top]);
                            if(lc!=stak[top-1]) stak[top]=lc;
                            else top--;
                        }
                        break;
                    }
                    else {add2(stak[top-1],stak[top]);top--;}
                }
                stak[++top]=now;
            }
            while(--top) add2(stak[top],stak[top+1]);
            cout<<dfs2(stak[1])<<endl;
            tot2=0;
        }
        return 0;
    }
    View Code

    [HEOI2014]大工程

    这种题真的一搞一上午啊,还是我太菜了.....

    我们看到k的范围自然就想到了虚树.

    那就让我们先考虑普通的dp:

    第一问,是所有关键点两两匹配的总长度之和.二三问分别是最长和最小长度.

    第一问直接统计每条边的贡献,第二三问用求直径的思想。

    我们设sum[x],mx[x],mn[x],size[x]分别表示以x为根的树中,所有关键点到x的路径和,最大值,最小值,和个数.

    对于ans1,我们考虑当前处理到y这个儿子.

    ans1+=sum[x]*size[y]+(sum[y]+dis(x,y)*size[y])*size[x].这个意思就是之前的子树中每条边都出来与y中的子树匹配.

    mx,与mn就不加述说了.

    我之前一直在思考如果是关键点的话,怎么特殊处理.因为我们的做法其实枚举了每一个lca,将两端拼接起来的.

    可是观察上面的转移,如果我们将关键点的size[x]初始化为1,那size[x]里就为累计一下(sum[y]+dis(x,y)的代价,其实就等同于x与所有关键点的匹配.

    在普通树里,dis(x,y)是1,而在虚树里dis(x,y)是deep[y]-deep[x]。之后将其转移即可.

    #include<bits/stdc++.h>
    #define ll long long
    using namespace std;
    const int N=1000010;
    int n,q,link1[N],tot1,link2[N],tot2,deep[N],f[N][25],b[N],num,dfn[N];
    int stak[N],top;
    ll ans1,ans2,ans3,sum[N],mx[N],mn[N],size[N];
    bool vis[N];
    struct edge{int y,next;}a1[N<<1],a2[N<<1]; 
    inline int read()
    {
        int x=0,ff=1;
        char ch=getchar();
        while(!isdigit(ch)) {if(ch=='-') ff=-1;ch=getchar();}
        while(isdigit(ch)) {x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
        return x*ff;
    }
    inline bool cmp(int x,int y){return dfn[x]<dfn[y];}
    inline void add1(int x,int y)
    {
        a1[++tot1].y=y;
        a1[tot1].next=link1[x];
        link1[x]=tot1;
    }
    inline void add2(int x,int y)
    {
        a2[++tot2].y=y;
        a2[tot2].next=link2[x];
        link2[x]=tot2;
    }
    inline void dfs1(int x)
    {
        dfn[x]=++num;
        for(int i=link1[x];i;i=a1[i].next)
        {
            int y=a1[i].y;
            if(y==f[x][0]) continue;
            deep[y]=deep[x]+1;
            f[y][0]=x;
            for(int j=1;j<=20;++j) f[y][j]=f[f[y][j-1]][j-1];
            dfs1(y);
        }
    }
    inline int lca(int a,int b)
    {
        if(deep[a]>=deep[b]) swap(a,b);
        for(int i=20;i>=0;--i)
            if(deep[f[b][i]]>=deep[a]) b=f[b][i];
        if(a==b) return a;
        for(int i=20;i>=0;--i)
            if(f[a][i]!=f[b][i]) a=f[a][i],b=f[b][i];
        return f[a][0];        
    }
    inline void dfs2(int x)
    {
        sum[x]=0;mx[x]=0;mn[x]=(vis[x]?0:1e18);size[x]=(vis[x]?1:0);
        for(int i=link2[x];i;i=a2[i].next)
        {
            int y=a2[i].y;
            dfs2(y);
            ll dis=deep[y]-deep[x];
            ans1+=(sum[y]+dis*size[y])*size[x]+sum[x]*size[y];
            ans2=max(ans2,mx[x]+mx[y]+dis);
            ans3=min(ans3,mn[x]+mn[y]+dis);
            sum[x]+=sum[y]+dis*size[y];
            mx[x]=max(mx[x],mx[y]+dis);
            mn[x]=min(mn[x],mn[y]+dis);
            size[x]+=size[y];
        }
        if(vis[x]) vis[x]=false;
        link2[x]=0;
    }
    int main()
    {
        freopen("1.in","r",stdin);
        n=read();
        for(int i=1;i<n;++i)
        {
            int x=read(),y=read();
            add1(x,y);add1(y,x);
        }
        deep[1]=1;dfs1(1);
        q=read();
        for(int i=1;i<=q;++i)
        {
            num=read();
            for(int j=1;j<=num;++j) b[j]=read(),vis[b[j]]=true;
            sort(b+1,b+num+1,cmp);
            stak[top=1]=b[1];
            for(int j=2;j<=num;++j)
            {
                int now=b[j];
                int lc=lca(now,stak[top]);
                while(1)
                {
                    if(deep[lc]>=deep[stak[top-1]])
                    {
                        if(lc!=stak[top])
                        {
                            add2(lc,stak[top]);
                            if(lc!=stak[top-1]) stak[top]=lc;
                            else --top;
                        }
                        break;
                    }
                    else {add2(stak[top-1],stak[top]);top--;}
                }
                stak[++top]=now;
            }
            while(--top) add2(stak[top],stak[top+1]);
            ans1=0;ans2=0;ans3=1e18;
            dfs2(stak[1]);
            printf("%lld %lld %lld
    ",ans1,ans3,ans2);
            tot2=0; 
        }
        return 0;
    }
    View Code
  • 相关阅读:
    开源IDS系列--解决barnyard2 停止运行 libmysqlclient.so.16.0.0
    开源IDS系列--snorby 2.6.2 undefined method `run_daily_report' for Event:Class (NoMethodError)
    开源IDS系列--snorby 进程正常,但是worker无法启动 The Snorby worker is not currently running
    大数据之路:阿里巴巴大数据实践小记
    LRU算法的应用
    Bitmap的巧用
    impala和presto
    常用sql
    wget rpm yum
    WSGI uwsgi uWSGI
  • 原文地址:https://www.cnblogs.com/gcfer/p/12491427.html
Copyright © 2011-2022 走看看