zoukankan      html  css  js  c++  java
  • 树上差分学习笔记

    树上差分利用前缀和的思想,利用树上的前缀和(也就是子树和),记录树上的一些信息,因为它可以进行离线操作,复杂度O(n),时间、空间、代码复杂度都十分优秀。

    最大流
    FJ给他的牛棚的N(2≤N≤50,000)个隔间之间安装了N-1根管道,隔间编号从1到N。所有隔间都被管道连通了。
    FJ有K(1≤K≤100,000)条运输牛奶的路线,第i条路线从隔间si运输到隔间ti。一条运输路线会给它的两个端点处的隔间以及中间途径的所有隔间带来一个单位的运输压力,你需要计算压力最大的隔间的压力是多少。

    这道题需要让我们求出每个点的覆盖次数(这显然可以用树链剖分来做),但这也是一个树上差分的经典题。

    抛开这道题,先考虑对于一维的一个空间,我们如何差分,没错,对于区间(l-r)来说,在l出加上1,在r+1处减去1,再求前缀和即可。

    那在树上如何操作,我们可以把树抽象的想象成自下而上的一个数组,对于一段连续的链,在下面的端点加上1,在上面的端点的父节点减去1,求子树和。

    那么左右端点不在一条直链上怎么办?可以考虑使用lCA,对于两个点的LCA来说,LCA才这条链上只出现了一次,所以在两个端点分别加1,LCA-1,LCA的父亲减1,求子树和,这道题就水过了

     

    #include<iostream>
    #include<cstdio>
    #define N 50009
    using namespace std;
    int ans,head[N],ji[N],tot,deep[N],p[N][22],n,m,a,b,c,fa[N];
    struct de
    {
        int n,to;
    }an[N<<1];
    inline void add(int u,int v)
    {
        an[++tot].n=head[u];
        an[tot].to=v;
        head[u]=tot;
    }
    void dfs(int u,int f)
    {
        deep[u]=deep[f]+1;
        fa[u]=f;
        p[u][0]=f;
        for(int i=1;(1<<i)<=deep[u];++i)
          p[u][i]=p[p[u][i-1]][i-1];
        for(int i=head[u];i;i=an[i].n)
        if(an[i].to!=f)
        {
            int v=an[i].to;
            dfs(v,u);
        }
    }
    inline int getlca(int a,int b)
    {
        if(deep[a]<deep[b])swap(a,b);
        for(int i=20;i>=0;--i)
          if(deep[a]-(1<<i)>=deep[b])a=p[a][i];
        if(a==b)return b;
        for(int i=20;i>=0;--i)
          if(p[a][i]!=p[b][i])a=p[a][i],b=p[b][i];
        return p[a][0];
    } 
    void dfs2(int u)
    {
        for(int i=head[u];i;i=an[i].n)
          if(an[i].to!=fa[u])
          {
              int v=an[i].to;
              dfs2(v);
              ji[u]+=ji[v];
          }
        ans=max(ans,ji[u]);
    }
    int main()
    {
        scanf("%d%d",&n,&m);
        for(int i=1;i<n;++i)
          scanf("%d%d",&a,&b),add(a,b),add(b,a);
        dfs(1,0);
        for(int i=1;i<=m;++i)
        {
            scanf("%d%d",&a,&b);
            c=getlca(a,b);
            ji[c]--;
            ji[fa[c]]--;
            ji[a]++;ji[b]++;
        }
        dfs2(1);
        cout<<ans;
        return 0;
    }

    再来道难一点的。

    NOIP2015运输计划
    公元 20442044 年,人类进入了宇宙纪元。
    L 国有 n 个星球,还有 n-1 条双向航道,每条航道建立在两个星球之间,这 n-1 条航道连通了 L国的所有星球。
    小 P 掌管一家物流公司, 该公司有很多个运输计划,每个运输计划形如:有一艘物流飞船需要从 u 号星球沿最快的宇航路径飞行到 v 号星球去。显然,飞船驶过一条航道是需要时间的,对于航道 j ,任意飞船驶过它所花费的时间为 t ,并且任意两艘飞船之间不会产生任何干扰。
    为了鼓励科技创新, L国国王同意小 P的物流公司参与 L 国的航道建设,即允许小 P 把某一条航道改造成虫洞,飞船驶过虫洞不消耗时间。
    在虫洞的建设完成前小 P 的物流公司就预接了 m 个运输计划。在虫洞建设完成后,这 m 个运输计划会同时开始,所有飞船一起出发。当这 m 个运输计划都完成时,小 P 的物流公司的阶段性工作就完成了。
    如果小 P可以自由选择将哪一条航道改造成虫洞, 试求出小 P 的物流公司完成阶段性工作所需要的最短时间是多少?

    问题来了,刚才我们要求点的覆盖次数,这回要求边,我们都知道求点的覆盖可以用树链剖分来做,
    但把点换成边好像无从下手(不过好像也可以做。但复杂度不是太好看

    那么我们该如何处理?
    还是树上差分,但我们把定义换一下,把边上的信息放到点上,每个子树记录的是这个点向上连的那条边出现的次数,左右端点分别加1,LCA减2,就可以搞了。

    然后咧?

    我们发现直接求很困难,所以就考虑把求最值改成二分答案+验证找最值,先二分最终的答案,那么大于这个答案的链是需要删边的,在把需要删边的链来一波差分,找出这些链的最长公共链,判断把这条链删掉之后答案是否合法,然后这题就做完了

      

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #define N 300009
    #define R register
    using namespace std;
    int tot,head[N],ji[N],jii[N],num,ll,lca[N],u[N],v[N],n,deep[N],d[N],fa[N],m,L[N],p[N][22];
    int a,b,c,l,r;
    struct de
    {
        int n,to,l;
    }an[N<<1];
    inline void add(int u,int v,int l)
    {
        an[++tot].n=head[u];
        an[tot].to=v;
        head[u]=tot;
        an[tot].l=l;
    }
    void dfs2(int u,int fa)
    {
        for(R int i=head[u];i;i=an[i].n)
          if(an[i].to!=fa)
          {
              int v=an[i].to;
              dfs2(v,u);
              ji[u]+=ji[v];
              
          }
        if(ji[u]==num)ll=max(ll,jii[u]);
    }
    bool ch(int pos)
    {
        int pp=0;
        num=0;ll=0;
        memset(ji,0,sizeof(ji));
        for(R int i=1;i<=m;++i)
         if(L[i]>pos)
         {
             ji[lca[i]]-=2;
             ji[u[i]]++;
             ji[v[i]]++;
             pp=max(pp,L[i]);
             num++;
         }
         dfs2(1,0);
         if(pp-ll<=pos)return 1;
         else return 0;
    }
    void dfs(int u,int f)
    {
        fa[u]=f;
        deep[u]=deep[f]+1;
        p[u][0]=f;
        for(R int i=1;(1<<i)<=deep[u];++i)
          p[u][i]=p[p[u][i-1]][i-1];
        for(R int i=head[u];i;i=an[i].n)
        if(an[i].to!=f)
        {
            int v=an[i].to;
            d[v]=d[u]+an[i].l;
            jii[v]=an[i].l;
            dfs(v,u);
        }
    }
    inline int getlca(int a,int b)
    {
        if(deep[a]<deep[b])swap(a,b);
        for(R int i=20;i>=0;--i)
          if(deep[a]-(1<<i)>=deep[b])a=p[a][i];
        if(a==b)return b;
        for(R int i=20;i>=0;--i)
          if(p[a][i]!=p[b][i])a=p[a][i],b=p[b][i];
        return p[a][0];
    } 
    int rd()
    {
        int x=0;
        char c=getchar();
        while(!isdigit(c))c=getchar();
        while(isdigit(c))
        {
            x=(x<<1)+(x<<3)+(c^48);
            c=getchar(); 
        }
        return x;
    }
    int main()
    {
        n=rd();m=rd();
        for(R int i=1;i<n;++i)
          a=rd(),b=rd(),c=rd(),add(a,b,c),add(b,a,c);
        dfs(1,0);
        for(R int i=1;i<=m;++i)
        {
          u[i]=rd();v[i]=rd(); 
          lca[i]=getlca(u[i],v[i]);
          L[i]=d[u[i]]+d[v[i]]-2*d[lca[i]];
          r=max(r,L[i]);
       }
       int ans=0;
       while(l<=r)
       {
           int mid=(l+r)>>1;
           if(ch(mid))
           {
               ans=mid;
               r=mid-1;
           }
           else l=mid+1;
       }
        cout<<ans;
        return 0;
    }


    来看最后一道 NOIP2016 天天爱跑步


    这是我见过的最难的一道树上差分题目,它的解法十分的巧妙。

    在第w[j]秒观察到,难道我还让它动态的往上跳吗?

    当然不用,让我们列一波式子。

    因为这是一颗树,它具有一些非常有用(恶心)的性质,就是链上的LCA,当w在s到LCA的路上时,有以下式子成立

    deep[S]=w[x]+deep[x]

    当w在LCA到t的路上时,有以下式子成立
    deep[s]-2*d[lca(s,t)]=w[x]-deep[x]

    由于二式不等价,所以我们要分开处理,由于在S到LCA的路上一式成立,在LCA到T时二式成立,但我们在处理两种情况时要注意不能重复。

    所以我们开两个映射,第一个表示在一式情况下左边的式子的结果对应了几个x,第二个表示二式下左边的式子的结果对应了几个x。

    然后怎么做?

    还考虑树上差分,我们可以理解为又有一种数在s出现,在lca的父亲处消失,另一种树在t出现,在LCA处消失,这两种数对应了上述两种式子,那我们遍历整棵树时,到达一个节点就把这个位置对应的结果加入映射,对于每个询问的答案,就是遍历以这个点为根的子树前后右边的式子的结果对应的映射的差。

    #include<iostream>
    #include<cstdio>
    #include<vector>
    #include<map>
    #define N 300009
    using namespace std;
    map<int,int>A,B;
    struct pai
    {
        int tag,tag2,num;
    };
    vector<pai>ji[N];
    int n,m,head[N],p[N][22],deep[N],fa[N],tot,a,b,w[N],ans[N];
    struct dwd
    {
        int n,to;
    }an[N<<1];
    inline void add(int u,int v)
    {
        an[++tot].n=head[u];
        an[tot].to=v;
        head[u]=tot;
    }
    void dfs(int u,int f)
    {
        fa[u]=f;
        deep[u]=deep[f]+1;
        p[u][0]=f;
        for(int i=1;(1<<i)<=deep[u];++i)
          p[u][i]=p[p[u][i-1]][i-1];
        for(int i=head[u];i;i=an[i].n)
        {
            int v=an[i].to;
            if(v!=f)dfs(v,u);
        }
    }
    inline int getlca(int a,int b)
    {
        if(deep[a]<deep[b])swap(a,b);
        for(int i=20;i>=0;--i)
          if(deep[a]-(1<<i)>=deep[b])a=p[a][i];
        if(a==b)return b;
        for(int i=20;i>=0;--i)
          if(p[a][i]!=p[b][i])a=p[a][i],b=p[b][i];
        return p[a][0];
    } 
    void dfs2(int u,int fa)
    {   
        int p=A[deep[u]+w[u]],q=B[w[u]-deep[u]];//gai
        for(int i=head[u];i;i=an[i].n)
         if(an[i].to!=fa)
        {
            int v=an[i].to;
            dfs2(v,u);
        }        
       for(int i=0;i<ji[u].size();++i)
        {
          if(ji[u][i].tag==1)A[ji[u][i].num]+=ji[u][i].tag2;
          else B[ji[u][i].num]+=ji[u][i].tag2;
        }
        ans[u]=B[w[u]-deep[u]]+A[deep[u]+w[u]]-q-p;
    }
    int main()
    {
        scanf("%d%d",&n,&m);
        for(int i=1;i<n;++i)
        scanf("%d%d",&a,&b),add(a,b),add(b,a);
       dfs(1,0);
       for(int i=1;i<=n;++i)
         scanf("%d",&w[i]);
        for(int i=1;i<=m;++i)
          {
              int s,t,lca;
              scanf("%d%d",&s,&t);
              lca=getlca(s,t);
              ji[s].push_back(pai{1,1,deep[s]});
              ji[fa[lca]].push_back(pai{1,-1,deep[s]});
              ji[t].push_back(pai{2,1,deep[s]-2*deep[lca]});
              ji[lca].push_back(pai{2,-1,deep[s]-2*deep[lca]});
          }
          dfs2(1,0);
          for(int i=1;i<=n;++i)
            printf("%d ",ans[i]);
        return 0;
    }


  • 相关阅读:
    盘点Spring Boot最核心的27个注解
    一份非常完整的 MySQL 规范
    一份非常完整的 MySQL 规范
    Restful API 中的错误处理方案
    Restful API 中的错误处理方案
    一文总结 CPU 基本知识
    RPM软件管理工具
    yum仓库配置
    spring配置和下载
    spring的IOC 的底层实现原理
  • 原文地址:https://www.cnblogs.com/ZH-comld/p/9384690.html
Copyright © 2011-2022 走看看