zoukankan      html  css  js  c++  java
  • 洛谷3233 HNOI2014(虚树+dp)

    膜拜一发(mts\_246,forever\_shi)

    这两位爷是真的无敌!

    首先来看这个题,一看题目的数据范围和“关键点”字眼,我们就能得知这是一道虚树题

    那就先一如既往的建出来虚树吧
    QWQ
    但是这之后,应该怎么去dp呢。

    首先,我们需要知道在虚树上每个点的从属都是谁,这样才便于我们进一步扩展到虚树之外的点。

    那么怎么求这个东西呢?我们可以先通过一编dfs,求出来子树对父亲的影响,也就是从下到上的答案(先(dfs)到底,再更新)

    void dp1(int x,int flag)
    {
     dis[x]=inf;
     bel[x]=0;
     ans[x]=0;
     if (tag[x]==flag)
     {
      dis[x]=0;
      bel[x]=x;
     }
     for (int i=point[x];i;i=nxt[i])
     {
      int p = to[i];
      int now = val[i];
      dp1(p,flag);
      if ((dis[x]>dis[p]+val[i]) || (dis[x]==dis[p]+val[i] && bel[x]>bel[p]))
      {
       dis[x]=dis[p]+val[i];
       bel[x]=bel[p];
      }
     }
    }
    

    然后呢,因为还存在说通过兄弟更新,或者子树之外的点更新的情况,所以我们还需要重新(dfs)一遍,不过这次是尝试通过用父亲来更新儿子,也就是从上到下(先更新,后(dfs)

    void dp2(int x,int flag)
    {
     for (int i=point[x];i;i=nxt[i])
     {
      int p = to[i];
      int now = val[i];
      if ((dis[p]>dis[x]+val[i]) || (dis[p]==dis[x]+val[i] && bel[p]>bel[x]))
      {
       dis[p]=dis[x]+val[i];
       bel[p]=bel[x];
      }
      dp2(p,flag);
     }
    }
    

    至此,我们就得到了所有虚树上的点的(dis)(bel),那怎么扩展到所有点呢QWQ

    这里就需要一个奇妙的统计答案的技巧了

    我们另(ymh[i])表示与(i)相同议事处的点的个数。

    首先,我们将初值弄成(size[i]),是i在原树的子树大小(这一定是不对的,因为子树中有一些会和他的某个非直系子辈给包含,而他在上面的一片区域,也一定有和他一样的点)

    然后我们进行dfs

    对于这个东西,显然是要从下向上更新的
    所以我们(dfs)到底,对于当前(x->p)这条边,如果说两个点的(bel)是相等的,我们就令(ymh[x]-=size[p]),相当于把原树(x->p)这路径附近部分所有的点,都给了(x),不论是合法还是不合法。

    那么上一种情况里面不合法的情况,就是两个点之间存在(bel)不一样的点,也就是说,会存在一条边(x->p),其中(bel[x]!=bel[p]),那么这条路径之间的东西应该怎么算呢。

    不难发现,一定是会存在说,这段路径中间会有一个点,以上全是属于(bel[x]),以下全是属于(bel[p])的。

    那么我们可以通过倍增的方式来求出这个点(具体求的时候有一些细节,直接写在代码里面了)

    然后假设求出来的点是(lyf),那么$$ymh[p]+=size[lyf]-size[p],ymh[x]-=size[lyf]$$

    原理的话,和上面同理

    这种用ymh数组求解的方式,实际上就是先弄一个初值,然后把不合法的(或者是会算重复的)减掉,然后把少算的加进去

    QWQ总之就是很巧妙!!!!!!!

    既不会算少,也不会算重复

    直接放代码

    #include<iostream>
    #include<cstdio>
    #include<algorithm>
    #include<cstring>
    #include<cmath>
    #include<queue>
    #include<map>
    #include<set>
    #define mk makr_pair
    #define ll long long
    using namespace std;
    inline int read()
    {
      int x=0,f=1;char ch=getchar();
      while (!isdigit(ch)) {if (ch=='-') f=-1;ch=getchar();}
      while (isdigit(ch)) {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
      return x*f;
    }
    const int maxn = 6e5+1e2;
    const int maxm = 2*maxn;
    const int inf = 1e9;
    int point[maxn],nxt[maxm],to[maxm],val[maxm];
    int bel[maxn],dis[maxn],f[maxn][21];
    int num[maxn];
    int size[maxn],deep[maxn],dfn[maxn];
    int cnt,n,m;
    int tot,top;
    int s[maxn];
    int k,a[maxn];
    int ymh[maxn],tag[maxn];
    int ans[maxn];
    void addedge(int x,int y,int w)
    {
     nxt[++cnt]=point[x];
     to[cnt]=y;
     val[cnt]=w;
     point[x]=cnt;
    }
    void dfs(int x,int fa,int dep)
    {
     deep[x]=dep;
     dfn[x]=++tot;
        size[x]=1;
        for (int i=point[x];i;i=nxt[i])
        {
         int p = to[i];
         if (p==fa) continue;
      f[p][0]=x;
         dfs(p,x,dep+1);
         size[x]+=size[p];
     }
    }
    void init()
    {
     for (int j=1;j<=20;j++)
       for (int i=1;i<=n;i++)
       {
         f[i][j]=f[f[i][j-1]][j-1];
       }
    }
    int go_up(int x,int d)
    {
     for (int i=0;i<=20;i++)
     {
      if (d & (1<<i))
        x=f[x][i];
     }
     return x;
    }
    int lca(int x,int y)
    {
     if (deep[x]>deep[y]) x=go_up(x,deep[x]-deep[y]);
     else y=go_up(y,deep[y]-deep[x]); 
     if (x==y) return x;
     for (int i=20;i>=0;i--)
     {
      if (f[x][i]!=f[y][i])
      {
       x=f[x][i];
       y=f[y][i];
      }
     }
     return f[x][0];
    } 
    bool cmp(int a,int b)
    {
     return dfn[a]<dfn[b];
    }
    void solve()
    {
       sort(a+1,a+1+k,cmp);
       cnt=0;
       top=1;
       s[top]=1;
       for (int i=1;i<=k;i++)
       {
          int l = lca(s[top],a[i]);
          if (l!=s[top])
          {
            while (top>1)
       {
         if (dfn[s[top-1]]>dfn[l])
         {
          addedge(s[top-1],s[top],deep[s[top]]-deep[s[top-1]]);
          top--;
        }
        else
        {
         if (dfn[s[top-1]]==dfn[l])
            {
            addedge(s[top-1],s[top],deep[s[top]]-deep[s[top-1]]);
            top--;
            break;
           }
           else
           {
             addedge(l,s[top],deep[s[top]]-deep[l]);
             s[top]=l;
            break;
        }
        }
       } 
       }
       if (s[top]!=a[i]) s[++top]=a[i];
       }
       while (top>1)
       {
          addedge(s[top-1],s[top],deep[s[top]]-deep[s[top-1]]);
       top--;
       }
    }
    void dp1(int x,int flag)
    {
     dis[x]=inf;
     bel[x]=0;
     ans[x]=0;
     if (tag[x]==flag)
     {
      dis[x]=0;
      bel[x]=x;
     }
     for (int i=point[x];i;i=nxt[i])
     {
      int p = to[i];
      int now = val[i];
      dp1(p,flag);
      if ((dis[x]>dis[p]+val[i]) || (dis[x]==dis[p]+val[i] && bel[x]>bel[p]))
      {
       dis[x]=dis[p]+val[i];
       bel[x]=bel[p];
      }
     }
    }
    void dp2(int x,int flag)
    {
     for (int i=point[x];i;i=nxt[i])
     {
      int p = to[i];
      int now = val[i];
      if ((dis[p]>dis[x]+val[i]) || (dis[p]==dis[x]+val[i] && bel[p]>bel[x]))
      {
       dis[p]=dis[x]+val[i];
       bel[p]=bel[x];
      }
      dp2(p,flag);
     }
    }
    int up(int x,int d)
    {
     for (int i=20;i>=0;i--)
       if (deep[f[x][i]]>=d) x=f[x][i];
     return x;
    }
    void dodo(int x)
    {
     ymh[x]=size[x];
     for (int &i=point[x];i;i=nxt[i])
     {
      int p = to[i];
         dodo(p);
         if (bel[x]==bel[p]) ymh[x]-=size[p];
         else
         {
          int now = dis[p]+dis[x]+deep[p]-deep[x]-1; //这里减1的原因是为了后面方便一些,因为偶数的情况,中间那个点的归属不能够直接倍增的时候判断,所以我们需要在后面if的时候,特殊处理一下 
          int st = now/2-dis[p];  //这个是距离to的距离 
          int dd = deep[p]-st; //中间点的深度 
          int lyf = p;
          if(dd>=0) lyf=up(p,dd);
       if ((now&1) && bel[x]>bel[p] && st>=0) lyf = f[lyf][0]; //与上面那个减1相对应,判断中间点的归属 
       ymh[p]+=size[lyf]-size[p]; //把lyf底下的点,都给to 
       ymh[x]-=size[lyf];  //把转折点剩下的部分给fa,由于初值是整个的size,所以直接做减法就好 
      }
      ans[bel[p]]+=ymh[p];
     }
     if (x==1) ans[bel[x]]+=ymh[x];
    }
    int b[maxn];
    int main()
    {
      n=read();
      for (int i=1;i<n;i++)
      {
        int x=read(),y=read();
        addedge(x,y,1);
        addedge(y,x,1);
      }
      dfs(1,0,1);
      init();
      memset(point,0,sizeof(point));
      m=read();
      for (int i=1;i<=m;i++)
      {
        k=read();
        for (int j=1;j<=k;j++) a[j]=read(),tag[a[j]]=i,b[j]=a[j];
        solve();
        dp1(1,i);
        dp2(1,i);
        dodo(1);
        for (int j=1;j<=k;j++) cout<<ans[b[j]]<<" ";
        cout<<"
    ";
        for (int j=1;j<=k;j++) ans[b[j]]=0;
      }
      return 0;
    }
    
    
  • 相关阅读:
    异步初体验
    ASPNET登陆总结
    14年最后一天了
    个人阅读作业
    软工个人博客-week7
    软工结对编程作业-人员
    软工结对编程作业-(附加题)
    软工结对编程作业-(非附加题)
    个人博客作业Week3
    软工个人作业-博客作业-WEEK2
  • 原文地址:https://www.cnblogs.com/yimmortal/p/10161515.html
Copyright © 2011-2022 走看看