zoukankan      html  css  js  c++  java
  • 洛谷 5291 [十二省联考2019]希望(52分)——思路+树形DP

    题目:https://www.luogu.org/problemnew/show/P5291

    考场上写了 16 分的。不过只得了 4 分。

    对于一个救援范围,其中合法的点集也是一个连通块。 2n 枚举一个救援范围,然后换根 DP 一下范围内的每个点开始的最长链,那些最长链 <=L 的点就是该范围的合法点集。

    这样得到每个合法点集出现的方案, 与卷积 k 次即可。卷积的时候先 FWT 成点值,然后快速幂一样乘 k 次,再 FWT 回来即可。

    但只有 4 分。过不了大样例。

    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #define ll long long
    using namespace std;
    int rdn()
    {
      int ret=0;bool fx=1;char ch=getchar();
      while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
      while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
      return fx?ret:-ret;
    }
    int Mx(int a,int b){return a>b?a:b;}
    int Mn(int a,int b){return a<b?a:b;}
    const int N=1e6+5,mod=998244353;
    int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;}
    
    int n,L,k,hd[N],xnt,to[N<<1],nxt[N<<1];
    void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;}
    namespace S1{
      const int K=20,M=(1<<16)+5;
      int bin[K],dp[K],pr[K],sc[K],nd[K],tot;
      int ts,len,f[M],g[M]; bool vis[K],col[K];
      void chk_dfs(int cr,int fa)
      {
        vis[cr]=1;
        for(int i=hd[cr],v;i;i=nxt[i])
          if(col[v=to[i]]&&v!=fa)chk_dfs(v,cr);
      }
      void dfs(int cr,int fa)
      {
        dp[cr]=0;
        for(int i=hd[cr],v;i;i=nxt[i])
          if(col[v=to[i]]&&v!=fa)
        dfs(v,cr), dp[cr]=Mx(dp[cr],dp[v]+1);
      }
      void dfsx(int cr,int fa,int tmp)
      {
        if(Mx(dp[cr],tmp)<=L)ts|=bin[cr-1];
        int l=tot;
        for(int i=hd[cr],v;i;i=nxt[i])
          if(col[v=to[i]]&&v!=fa) nd[++tot]=v;
        int r=tot; if(l==r)return;
        pr[l+1]=dp[nd[l+1]]+1;
        for(int i=l+2;i<=r;i++)pr[i]=Mx(pr[i-1],dp[nd[i]]+1);
        sc[r]=dp[nd[r]]+1;
        for(int i=r-1;i>l;i--)sc[i]=Mn(sc[i+1],dp[nd[i]]+1);
        for(int i=l+1;i<=r;i++)
          {
        int tp=tmp;//=tmp
        if(i>l+1)tp=pr[i-1];if(i<r)tp=Mx(tp,sc[i+1]);
        dfsx(nd[i],cr,tp+1);
          }
      }
      void fwt(int *a,bool fx)
      {
        for(int R=2;R<=len;R<<=1)
          for(int i=0,m=R>>1;i<len;i+=R)
        for(int j=0;j<m;j++)
          {
            if(!fx)a[i+j]=upt(a[i+j]+a[i+m+j]);
            else a[i+j]=upt(a[i+j]-a[i+m+j]);
          }
      }
      void solve()
      {
        bin[0]=1;
        for(int i=1;i<=n;i++)bin[i]=bin[i-1]<<1;
        for(int s=1;s<bin[n];s++)
          {
        for(int i=1;i<=n;i++)
          {
            vis[i]=0;
            if(s&bin[i-1])col[i]=1; else col[i]=0;
          }
        int cr=0;
        for(int i=1;i<=n;i++)
          if(col[i]){chk_dfs(i,0);cr=i;break;}
        bool fg=0;
        for(int i=1;i<=n;i++)
          if(col[i]&&!vis[i]){fg=1;break;}
        if(fg)continue;
        ts=tot=0; dfs(cr,0); dfsx(cr,0,0);
        if(ts){f[ts]++; g[ts]++;}
          }
        k--; len=bin[n]; fwt(g,0); fwt(f,0);
        while(k)
          {
        if(k&1)
          {
            for(int i=0;i<len;i++)f[i]=(ll)f[i]*g[i]%mod;
          }
        for(int i=0;i<len;i++)g[i]=(ll)g[i]*g[i]%mod;
        k>>=1;
          }
        int ans=0; fwt(f,1);
        for(int s=1;s<bin[n];s++)ans=upt(ans+f[s]);
        printf("%d
    ",ans);
      }
    }
    int main()
    {
      freopen("hope.in","r",stdin);
      freopen("hope.out","w",stdout);
      n=rdn();L=rdn();k=rdn();
      for(int i=1,u,v;i<n;i++)
        u=rdn(),v=rdn(),add(u,v),add(v,u);
      if(n<=16){S1::solve();return 0;}
      return 0;
    }

    后来发现两个地方写错了:

    1.换根的时候做了前缀 max 和后缀 max ,其中后缀取 max 写成取 min 了;

    2.往孩子换根的时候用了一个 tp 对父亲来的 tmp 、前缀 max 、后缀 max 取 max ,结果 tp=tmp 之后写成 tp = pr[ ] 而非 tp = Mx( tp , pr[ ] ) 。

    改了这两个地方就有 16 分了。

    希望以后写代码的时候更仔细。别走神或不集中之类的。

    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #define ll long long
    using namespace std;
    int rdn()
    {
      int ret=0;bool fx=1;char ch=getchar();
      while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
      while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
      return fx?ret:-ret;
    }
    int Mx(int a,int b){return a>b?a:b;}
    int Mn(int a,int b){return a<b?a:b;}
    const int N=1e6+5,mod=998244353;
    int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;}
    
    int n,L,k,hd[N],xnt,to[N<<1],nxt[N<<1];
    void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;}
    namespace S1{
      const int K=20,M=(1<<16)+5;
      int bin[K],dp[K],pr[K],sc[K],nd[K],tot;
      int ts,len,f[M],g[M]; bool vis[K],col[K];
      void chk_dfs(int cr,int fa)
      {
        vis[cr]=1;
        for(int i=hd[cr],v;i;i=nxt[i])
          if(col[v=to[i]]&&v!=fa)chk_dfs(v,cr);
      }
      void dfs(int cr,int fa)
      {
        dp[cr]=0;
        for(int i=hd[cr],v;i;i=nxt[i])
          if(col[v=to[i]]&&v!=fa)
        dfs(v,cr), dp[cr]=Mx(dp[cr],dp[v]+1);
      }
      void dfsx(int cr,int fa,int tmp)
      {
        if(Mx(dp[cr],tmp)<=L)ts|=bin[cr-1];
        int l=tot;
        for(int i=hd[cr],v;i;i=nxt[i])
          if(col[v=to[i]]&&v!=fa) nd[++tot]=v;
        int r=tot; if(l==r)return;
        pr[l+1]=dp[nd[l+1]]+1;
        for(int i=l+2;i<=r;i++)pr[i]=Mx(pr[i-1],dp[nd[i]]+1);
        sc[r]=dp[nd[r]]+1;
        for(int i=r-1;i>l;i--)sc[i]=Mx(sc[i+1],dp[nd[i]]+1);////mx not mn!!!
        for(int i=l+1;i<=r;i++)
          {
        int tp=tmp;//=tmp
        if(i>l+1)tp=Mx(tp,pr[i-1]);if(i<r)tp=Mx(tp,sc[i+1]);//mx!!!
        dfsx(nd[i],cr,tp+1);
          }
      }
      void fwt(int *a,bool fx)
      {
        for(int R=2;R<=len;R<<=1)
          for(int i=0,m=R>>1;i<len;i+=R)
        for(int j=0;j<m;j++)
          {
            if(!fx)a[i+j]=upt(a[i+j]+a[i+m+j]);
            else a[i+j]=upt(a[i+j]-a[i+m+j]);
          }
      }
      void solve()
      {
        bin[0]=1;
        for(int i=1;i<=n;i++)bin[i]=bin[i-1]<<1;
        for(int s=1;s<bin[n];s++)
          {
        for(int i=1;i<=n;i++)
          {
            vis[i]=0;
            if(s&bin[i-1])col[i]=1; else col[i]=0;
          }
        int cr=0;
        for(int i=1;i<=n;i++)
          if(col[i]){chk_dfs(i,0);cr=i;break;}
        bool fg=0;
        for(int i=1;i<=n;i++)
          if(col[i]&&!vis[i]){fg=1;break;}
        if(fg)continue;
        ts=tot=0; dfs(cr,0); dfsx(cr,0,0);
        if(ts){f[ts]++; g[ts]++;}
          }
        k--; len=bin[n]; fwt(g,0); fwt(f,0);
        while(k)
          {
        if(k&1)
          {
            for(int i=0;i<len;i++)f[i]=(ll)f[i]*g[i]%mod;
          }
        for(int i=0;i<len;i++)g[i]=(ll)g[i]*g[i]%mod;
        k>>=1;
          }
        int ans=0; fwt(f,1);
        for(int s=1;s<bin[n];s++)ans=upt(ans+f[s]);
        printf("%d
    ",ans);
      }
    }
    int main()
    {
      freopen("hope.in","r",stdin);
      freopen("hope.out","w",stdout);
      n=rdn();L=rdn();k=rdn();
      for(int i=1,u,v;i<n;i++)
        u=rdn(),v=rdn(),add(u,v),add(v,u);
      if(n<=16){S1::solve();return 0;}
      return 0;
    }
    View Code

    然后参照题解写了 52 分的。

    很重要的转化是令 ( f[i] ) 表示 i 是合法点的救援范围个数,那么 k 个救援范围包含 i 的方案就是 ( f[i]^k ) ;考虑到一个方案的合法点集是连通块,即点数比边数大一,所以令 ( g[i] ) 表示边 i 的两端点是合法点的救援范围个数,答案就是 ( sumlimits_{i=1}^{n}f[i]^k - sumlimits_{i=1}^{n-1}g[i]^k ) 。

    然后就可以写 n*L 的 DP 了。再把链和 L=n 的部分做一下就有 52 分。

    不太会 k=1 时候的长链剖分。

    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #define ll long long
    using namespace std;
    int rdn()
    {
      int ret=0;bool fx=1;char ch=getchar();
      while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
      while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
      return fx?ret:-ret;
    }
    int Mx(int a,int b){return a>b?a:b;}
    int Mn(int a,int b){return a<b?a:b;}
    const int N=1e6+5,mod=998244353;
    int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;}
    int pw(int x,int k)
    {int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;}
    
    int n,L,k,hd[N],xnt=1,to[N<<1],nxt[N<<1],rd[N],f[N],g[N];
    void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;rd[y]++;}
    namespace S1{
      const int N=1005;
      int dfs(int cr,int fa,int lm)
      {
        int ret=1; if(!lm)return ret;
        for(int i=hd[cr],v;i;i=nxt[i])
          if((v=to[i])!=fa)
        ret=(ll)ret*(dfs(v,cr,lm-1)+1)%mod;
        return ret;
      }
      void dfsx(int cr,int fa)
      {
        for(int i=hd[cr],v;i;i=nxt[i])
          if((v=to[i])!=fa)
        {
          int ret=dfs(cr,v,L-1);
          ret=(ll)ret*dfs(v,cr,L-1)%mod;
          g[i>>1]=ret; dfsx(v,cr);
        }
      }
      void solve()
      {
        for(int i=1;i<=n;i++) f[i]=dfs(i,0,L);
        dfsx(1,0); int ans=0;
        for(int i=1;i<=n;i++)ans=upt(ans+pw(f[i],k));
        for(int i=1;i<n;i++)ans=upt(ans-pw(g[i],k));
        printf("%d
    ",ans);
      }
    }
    namespace S2{
      const int N=1e5+5,M=105;
      int nd[N],tot;
      struct Node{
        int v[M],s[M],cd;
        void init(){v[0]=s[0]=1;}
        void frs()
        {
          for(int i=1;i<=cd;i++)
        s[i]=upt(s[i-1]+v[i]);
        }
        void cz()
        {
          cd=Mn(cd+1,L);
          for(int i=cd;i;i--)v[i]=v[i-1];
          frs();
        }
      }dp[N],pr[N],sc[N],up[N];
      void mrg(Node &d0,Node d1)
      {
        int yc=d0.cd, lm=d1.cd, tc=Mn(L,Mx(yc,lm+1));
        d0.cd=tc;
        for(int j=yc+1;j<=tc;j++)
          d0.v[j]=0, d0.s[j]=d0.s[yc];//0 not 1
        for(int j=1;j<=tc;j++)
          {
        int tp;
        if(j-1<=lm)tp=d1.s[j-1]; else tp=d1.s[lm];
        tp++;///for choosen't
        d0.v[j]=(ll)d0.v[j]*tp%mod;
        if(j-1<=lm)
          d0.v[j]=(d0.v[j]+(ll)d0.s[j-1]*d1.v[j-1])%mod;
          }
        d0.frs();
      }
      void mg2(Node &d0,Node d1)
      {
        int yc=d0.cd, lm=d1.cd, tc=Mn(L,Mx(yc,lm));
        d0.cd=tc;
        for(int j=yc+1;j<=tc;j++)
          d0.v[j]=0, d0.s[j]=d0.s[yc];//0 not 1
        for(int j=0;j<=tc;j++)
          {
        int tp;
        if(j<=lm)tp=d1.s[j]; else tp=d1.s[lm];
        tp++;///for choosen't
        d0.v[j]=(ll)d0.v[j]*tp%mod;
        if(j&&j<=lm)
          d0.v[j]=(d0.v[j]+(ll)d0.s[j-1]*d1.v[j])%mod;
          }
        d0.frs();
      }
      void dfs(int cr,int fa)
      {
        dp[cr].init();
        for(int i=hd[cr],v;i;i=nxt[i])
          if((v=to[i])!=fa)
        {
          dfs(v,cr);
          mrg(dp[cr],dp[v]);
        }
      }
      void dfsx(int cr,int fa)
      {
        int tp=up[cr].cd;
        f[cr]=(tp>=L?up[cr].s[L]:up[cr].s[tp]);
        tp=dp[cr].cd;
        f[cr]=(ll)f[cr]*(tp>=L?dp[cr].s[L]:dp[cr].s[tp])%mod;
        int l=tot;
        for(int i=hd[cr],v;i;i=nxt[i])
          if((v=to[i])!=fa)
        {
          nd[++tot]=i;
          if(tot==l+1)pr[tot].init();
          else pr[tot]=pr[tot-1];
          mrg(pr[tot],dp[v]);
        }
        int r=tot;
        for(int i=r;i>l;i--)
          {
        if(i==r)sc[i].init();
        else sc[i]=sc[i+1];
        mrg(sc[i],dp[to[nd[i]]]);
          }
        for(int i=l+1;i<=r;i++)
          {
        pr[i].v[0]=pr[i].s[0]=0;pr[i].frs();
        sc[i].v[0]=sc[i].s[0]=0;sc[i].frs();
          }
        for(int i=l+1;i<=r;i++)
          {
        int v=to[nd[i]],bh=nd[i]>>1;
        up[v]=up[cr];
        if(i>l+1) mg2(up[v],pr[i-1]);
        if(i<r) mg2(up[v],sc[i+1]);
        int tp=up[v].cd;
        g[bh]=(tp>=L-1?up[v].s[L-1]:up[v].s[tp]);
        tp=dp[v].cd;
        g[bh]=(ll)g[bh]*(tp>=L-1?dp[v].s[L-1]:dp[v].s[tp])%mod;
        up[v].cz();
        dfsx(v,cr);
          }
      }
      void solve()
      {
        dfs(1,0); up[1].init(); dfsx(1,0);
        int ans=0;
        for(int i=1;i<=n;i++)
          ans=upt(ans+pw(f[i],k));
        for(int i=1;i<n;i++)
          ans=upt(ans-pw(g[i],k));
        printf("%d
    ",ans);
      }
    }
    namespace S3{
      const int N=2e5+5;
      int dp[N],nd[N],pr[N],sc[N],tot;
      void dfs(int cr,int fa)
      {
        dp[cr]=1;
        for(int i=hd[cr],v;i;i=nxt[i])
          if((v=to[i])!=fa)
        {
          dfs(v,cr); dp[cr]=(ll)dp[cr]*(dp[v]+1)%mod;
        }
      }
      void dfsx(int cr,int fa,int tmp)
      {
        f[cr]=(ll)dp[cr]*(tmp+1)%mod;
        int l=tot;
        for(int i=hd[cr],v;i;i=nxt[i])
          if((v=to[i])!=fa)
        {
          nd[++tot]=i;
          if(tot==l+1)pr[tot]=1;
          else pr[tot]=pr[tot-1];
          pr[tot]=(ll)pr[tot]*(dp[v]+1)%mod;
        }
        int r=tot;
        for(int i=r;i>l;i--)
          {
        if(i==r)sc[i]=1;
        else sc[i]=sc[i+1];
        sc[i]=(ll)sc[i]*(dp[to[nd[i]]]+1)%mod;
          }
        for(int i=l+1;i<=r;i++)
          {
        int v=to[nd[i]], tp=tmp+1, bh=nd[i]>>1;
        if(i>l+1)tp=(ll)tp*pr[i-1]%mod;
        if(i<r)tp=(ll)tp*sc[i+1]%mod;
        g[bh]=(ll)tp*dp[v]%mod;
        dfsx(v,cr,tp);
          }
      }
      void solve()
      {
        dfs(1,0); dfsx(1,0,0); int ans=0;
        for(int i=1;i<=n;i++)
          ans=upt(ans+pw(f[i],k));
        for(int i=1;i<n;i++)
          ans=upt(ans-pw(g[i],k));
        printf("%d
    ",ans);
      }
    }
    namespace S4{
      void solve()
      {
        int ans=0;
        for(int i=1;i<=n;i++)
          {
        int t0=Mn(L+1,i), t1=Mn(L+1,n-i+1);
        ans=upt(ans+pw((ll)t0*t1%mod,k));
          }
        for(int i=1;i<n;i++)
          {
        int t0=Mn(L,i), t1=Mn(L,n-i);
        ans=upt(ans-pw((ll)t0*t1%mod,k));
          }
        printf("%d
    ",ans);
      }
    }
    int main()
    {
      n=rdn();L=rdn();k=rdn();
      for(int i=1,u,v;i<n;i++)
        { u=rdn();v=rdn();add(u,v);add(v,u);}
      if(n<=1000){S1::solve();return 0;}
      if((ll)n*L<=1e7){S2::solve();return 0;}
      if(L==n){S3::solve();return 0;}
      bool fg=0;
      for(int i=1;i<=n;i++)if(rd[i]>2){fg=1;break;}
      if(!fg){S4::solve();return 0;}
      return 0;
    }
    View Code
  • 相关阅读:
    spring boot 启动后执行初始化方法
    Linux CentOS 7 下 JDK 安装与配置
    Linux rpm 命令参数使用详解[介绍和应用]
    异常处理: 重载Throwable.fillInStackTrace方法已提高Java性能
    dubbo 配置属性
    centos7 操作防火墙
    springBoot 打包 dubbo jar包
    直播中聊天场景的用例分享
    解决在安装Fiddler4.6版本后,在手机上安装证书出现的问题解决方法
    系统调优方案思路分享
  • 原文地址:https://www.cnblogs.com/Narh/p/10679616.html
Copyright © 2011-2022 走看看