zoukankan      html  css  js  c++  java
  • LOJ 3055 「HNOI2019」JOJO—— kmp自动机+主席树

    题目:https://loj.ac/problem/3055

    先写了暴力。本来想的是 n<=300 的那个在树上暴力维护好整个字符串, x=1 的那个用主席树维护好字符串和 nxt 数组。但 x=1 的部分会 TLE ,而且似乎不太对的样子。

    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #include<vector>
    #define ll long long
    #define pb push_back
    #define ls Ls[cr]
    #define rs Rs[cr]
    using namespace std;
    int rdn()
    {
      int ret=0;bool fx=1;char ch=getchar();
      while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
      while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
      return fx?ret:-ret;
    }
    int n;
    namespace S1{
      const int N=305,M=N*N;
      int fa[N],q[N],s[M],nxt[M]; ll sm[N];
      vector<int> c[N],nt[N];
      void add(int x,int y,int m=0,int ch=0)
      {
        sm[y]=sm[x]; fa[y]=x; if(!m)return;
        int top=0, cr=x;
        while(cr)q[++top]=cr,cr=fa[cr];
        int tot=0;
        for(int i=top;i;i--)
          {
        cr=q[i];
        for(int j=0,lm=c[cr].size();j<lm;j++)
          s[++tot]=c[cr][j], nxt[tot]=nt[cr][j];
          }
        c[y].resize(m); nt[y].resize(m);
        int i,j;
        if(!tot){ s[1]=c[y][0]=ch;i=2;j=2;} else { i=tot+1;j=1;}
        for(;j<=m;j++,i++)
          {
        s[i]=ch; cr=nxt[i-1];
        while(cr&&s[cr+1]!=ch)cr=nxt[cr];
        if(s[cr+1]==ch)nxt[i]=cr+1; else nxt[i]=0;
        c[y][j-1]=ch; nt[y][j-1]=nxt[i]; sm[y]+=nxt[i];
          }
      }
      void solve()
      {
        int op,x; char ch[5];
        for(int i=1;i<=n;i++)
          {
        op=rdn();x=rdn();
        if(op==1)
          { scanf("%s",ch); add(i-1,i,x,ch[0]-'a'+1);}
        else add(x,i);
        printf("%lld
    ",sm[i]);
          }
      }
    }
    namespace S2{
      const int N=1e5+5,M=2e6+5;
      int rt[N],tot,Ls[M],Rs[M],cd[N]; ll sm[N];
      struct Node{ int c,nxt;}a[M];
      int ins(int l,int r,int &cr,int pr,int p,int ch)
      {
        if(!cr){cr=++tot;ls=Ls[pr];rs=Rs[pr];}
        if(l==r){a[cr].c=ch;return cr;}
        int mid=l+r>>1;
        if(p<=mid)return ins(l,mid,ls,Ls[pr],p,ch);
        return ins(mid+1,r,rs,Rs[pr],p,ch);
      }
      Node qry(int l,int r,int cr,int p)
      {
        if(l==r)return a[cr]; int mid=l+r>>1;
        if(p<=mid)return qry(l,mid,ls,p);
        return qry(mid+1,r,rs,p);
      }
      void add(int cr,int pr,int m,int ch)
      {
        sm[cr]=sm[pr]; cd[cr]=cd[pr];
        for(int i=1,d;i<=m;i++)
          {
        cd[cr]++; d=ins(1,n,rt[cr],rt[pr],cd[cr],ch);
        int p=qry(1,n,rt[cr],cd[cr]-1).nxt;
        while(p&&qry(1,n,rt[cr],p+1).c!=ch)
          p=qry(1,n,rt[cr],p).nxt;
        if(p+1!=cd[cr]&&qry(1,n,rt[cr],p+1).c==ch)//!=
          a[d].nxt=p+1;
        else a[d].nxt=0;
        sm[cr]+=a[d].nxt;
          }
      }
      void solve()
      {
        int op,x; char ch[5];
        for(int i=1;i<=n;i++)
          {
        op=rdn();x=rdn();
        if(op==1)
          { scanf("%s",ch); add(i,i-1,x,ch[0]-'a'+1);}
        else {sm[i]=sm[x];rt[i]=rt[x];cd[i]=cd[x];}
        printf("%lld
    ",sm[i]);
          }
      }
    }
    int main()
    {
      n=rdn();
      if(n<=300){S1::solve();return 0;}
      if(n<=1e5){S2::solve();return 0;}
      return 0;
    }
    View Code

    然后看了题解。

    因为有 “加入的字符和上一个不同” 的限制,所以考虑一段 x 的末尾后面能续上 x 的 nxt 数组,只有自己的 nxt 跳到了另一段 y 的末尾,满足 x 和 y 的字符与长度均相同。

    那个 nxt 就是把一段看做一个字符、相同看做两段的字符与长度均相同的 nxt 数组。

    一边跳 nxt 一边累计答案,方法是记录一个 lst 表示当前段已经有前 lst 个字符贡献过答案;如果遇到 c[ p+1 ] == c[ cr ] ( c[ ] 表示字符, p 表示跳到的 nxt ),那么能匹配上的是当前段的前 min( len[ p+1 ] , len[ cr ] ) 个字符(len 表示段长);其中之前没贡献过答案的就是本次要贡献答案的,贡献是 ( s[ p ] + lst + 1 ) 到 ( s[ p ] + min( len[ p+1 ] , len[ cr ] ) ) 的等差数列求和。然后把 lst 更新成 min( len[ p+1 ] , len[ cr ] ) 。

    如果第一段的字符和自己相同,而第一段的长度比自己小(大于等于自己的话,在跳 nxt 的时候已经用等差数列加过了。所以跳 nxt 的 break 条件放在贡献答案之后),那么还可以给答案贡献 ( len[ cr ] - lst ) 倍的 len[ 1 ] 。(注意是 ( len[ cr ] - lst ) 而不是 ( len[ cr ] - len[ 1 ] ) )并且这种情况的 nxt[ cr ] 应该等于 1 而不是 0 。

    把询问离线,在树上用全局变量维护当前的 c[ ] 和 nxt[ ] , dfs 一遍即可。这样复杂度不对,但可过。目前只写了这样。

    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #define ll long long
    using namespace std;
    int rdn()
    {
      int ret=0;bool fx=1;char ch=getchar();
      while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
      while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
      return fx?ret:-ret;
    }
    int Mn(int a,int b){return a<b?a:b;}
    int Mx(int a,int b){return a>b?a:b;}
    const int N=1e5+5,mod=998244353;
    int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;}
    
    int n,c[N],len[N],tc[N],tl[N],s[N],nt[N];
    int hd[N],xnt,to[N],nxt[N],ans[N];
    int cz(int l,int r)
    {
      if(l>r)return 0;
      return (ll)(l+r)*(r-l+1)/2%mod;
    }
    void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;}
    void dfs(int cr,int pr,int cd)
    {
      ans[cr]=pr;
      if(len[cr])
        {
          cd++; nt[cd]=0;/////
          tc[cd]=c[cr]; tl[cd]=len[cr]; s[cd]=s[cd-1]+len[cr];
          if(cd==1)
        {
          ans[cr]=cz(0,len[cr]-1); nt[cd]=0;
        }
          else
        {
          int p=nt[cd-1],lst=0;
          while(1)
            {
              if(tc[p+1]==c[cr])
            {
              int tp=Mn(tl[p+1],len[cr]);
              if(tp>lst)
                { ans[cr]=upt(ans[cr]+cz(s[p]+lst+1,s[p]+tp)); lst=tp;}
            }
              if(!p||(tc[p+1]==c[cr]&&tl[p+1]==len[cr]))break;
              p=nt[p];
            }
          if(tc[p+1]==c[cr]&&tl[p+1]==len[cr])
            nt[cd]=p+1;
          else if(!p&&tc[1]==c[cr]&&tl[1]<len[cr])
            ans[cr]=(ans[cr]+(ll)tl[1]*(len[cr]-lst))%mod,nt[cd]=1;
          //-lst not len[1]//nxt=1 not 0
        }
        }
      for(int i=hd[cr];i;i=nxt[i])
        dfs(to[i],ans[cr],cd);
    }
    int main()
    {
      n=rdn(); char ch[5];
      for(int i=1;i<=n;i++)
        {
          int op=rdn();
          if(op==2){ int d=rdn();add(d,i);continue;}
          len[i]=rdn(); scanf("%s",ch); c[i]=ch[0]-'a'+1;
          add(i-1,i);
        }
      dfs(0,0,0);
      for(int i=1;i<=n;i++)printf("%d
    ",ans[i]);
      return 0;
    }
    View Code

    然后参考这里的题解(和代码):https://www.cnblogs.com/zhoushuyu/p/10680094.html

    复杂度不对的原因是暴力跳 nxt 。可以建 “kmp自动机” ,就是 pr[ i ][ j ] 表示 i 位置后面接 j 字符的话 nxt 会跳到哪个位置。新的位置 i 继承它的 nxt 的 pr[ ][ ] ,i-1 的某个 pr[ ][ ] 值改为 i 。

    根据接上来的长度不同,即使字符一样, nxt 仍可能跳到不同的位置。所以每个位置开 26 个主席树维护接上各种字符的不同长度, nxt 会跳到哪个位置。

    边跳还要边统计答案。把这个信息也放到主席树上。

      答案由两部分构成。设当前段能匹配的长度为 len , 一部分答案是 1 ~ len 的等差数列求和,另一部分是 1 ~ len 对应的 nxt 位置的前缀长度求和。

      考虑已经做完当前段,让它给上一个位置的主席树一些更新。设当前段长为 cd , prs 表示到上一个位置为止的前缀段长。

      考虑原来的暴力,跳到一个字符相同的位置,可以给当前段的一个前缀的每个位置提供一种可能的贡献,这里需要把 1~cd 位置的 “可能贡献” 改成当前的 prs 。这样一定最优。

      所以把主席树上 1~cd 位置的值都改成 prs 。把 cd 位置的 nxt 改成当前段。求答案的时候,假设要匹配的段的长度是 cd2 ,那么它的 nxt 就是主席树 cd2 处记录的 nxt ,它的过程中答案就是主席树 1~cd2 位置的值的和。

    注意处理与第一段匹配的情况。需要记录 “当前段最长能匹配多长” 。这个顺便记录即可。就是每次要修改的时候,对应值都可以对当前段长 cd 取 max 。

    代码里 rt[ top ][ ch ] 表示 “通过 ch 的边进入 top 之后的种种可能” 。所以往下走的时候,是把 rt[ pr+1 ] 赋值给 rt[ top+1 ] ,用的就是 “通过当前字符从 pr 进入 pr+1 ” 的信息。(pr 表示当前位置的 nxt )

    注意主席树新开节点的时候把原来的信息搬过来。

    注意在外面枚举 0 点的出边,进入的时候把 rt[ 0 ][ ] 之类的改成初值。

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #define ll long long
    #define ls Ls[cr]
    #define rs Rs[cr]
    using namespace std;
    int rdn()
    {
      int ret=0;bool fx=1;char ch=getchar();
      while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
      while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
      return fx?ret:-ret;
    }
    int Mx(int a,int b){return a>b?a:b;}
    int Mn(int a,int b){return a<b?a:b;}
    const int N=1e5+5,M=5e6+5,K=30,mod=998244353;
    int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;}
    
    int n,m,cnt,hd[N],xnt,to[N],nxt[N],w[N],c[N],ps[N],ans[N];
    int rt[N][K],mxl[N][K],prs[N],top,tc,tl;
    int tot,Ls[M],Rs[M],sm[M],nt[M],tg[M],tim,dfn[M];
    void add(int x,int y,int cd,int ch)
    {to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;w[xnt]=cd;c[xnt]=ch;}
    int cal(int x){return (ll)(1+x)*x/2%mod;}
    int nwnd(int cr)
    {
      if(dfn[cr]==tim)return cr; tot++; dfn[tot]=tim;
      Ls[tot]=ls; Rs[tot]=rs;
      sm[tot]=sm[cr]; nt[tot]=nt[cr]; tg[tot]=tg[cr];return tot;
    }
    void cz(int &cr,int len,int k){ cr=nwnd(cr); sm[cr]=(ll)len*k%mod;}
    void pshd(int cr,int l,int mid,int r)
    {
      if(!tg[cr])return; int k=tg[cr]; tg[cr]=0;
      cz(ls,mid-l+1,k); cz(rs,r-mid,k); tg[ls]=tg[rs]=k;
    }
    void mdfy(int l,int r,int &cr,int R,int k,int p)
    {
      if(!cr||dfn[cr]!=tim)cr=nwnd(cr);//
      if(r<R){cz(cr,r-l+1,k);tg[cr]=k;return;}
      if(l==r){cz(cr,r-l+1,k);nt[cr]=p;return;}
      int mid=l+r>>1; pshd(cr,l,mid,r);
      mdfy(l,mid,ls,R,k,p);
      if(mid<R)mdfy(mid+1,r,rs,R,k,p);
      sm[cr]=upt(sm[ls]+sm[rs]);
    }
    int qry(int l,int r,int cr,int R,int &p)
    {
      if(!cr)return 0; if(r<R)return sm[cr];
      if(l==r){p=nt[cr]; return sm[cr];}
      int mid=l+r>>1; pshd(cr,l,mid,r);
      int ret=qry(l,mid,ls,R,p);
      if(mid<R)ret=upt(ret+qry(mid+1,r,rs,R,p));
      return ret;
    }
    void dfs(int cr,int cd,int ch)
    {
      prs[++top]=prs[top-1]+cd; int pr=0;
      if(top==1)ans[cr]=upt(ans[cr]+cal(cd-1)),tc=ch,tl=cd;
      else
        {
          ans[cr]=upt(ans[cr]+qry(1,m,rt[top][ch],cd,pr));
          ans[cr]=upt(ans[cr]+cal(Mn(cd,mxl[top][ch])));//Mn!!
          if(!pr&&tc==ch&&tl<cd)
        {
          if(cd>mxl[top][ch])
            ans[cr]=(ans[cr]+(ll)tl*(cd-mxl[top][ch]))%mod;
          pr=1;///////
        }
        }
      mxl[top][ch]=Mx(mxl[top][ch],cd);
      tim++; mdfy(1,m,rt[top][ch],cd,prs[top-1],top);
      for(int i=hd[cr];i;i=nxt[i])
        {
          memcpy(rt[top+1],rt[pr+1],sizeof rt[pr+1]);//pr+1!!!
          memcpy(mxl[top+1],mxl[pr+1],sizeof mxl[pr+1]);
          ans[to[i]]=ans[cr]; dfs(to[i],w[i],c[i]);
        }
      top--;
    }
    int main()
    {
      n=rdn(); char ch;
      for(int i=1;i<=n;i++)
        {
          int op=rdn(), x=rdn();
          if(op==1)
        {
          cin>>ch; ps[i]=++cnt; m=Mx(m,x);
          add(ps[i-1],ps[i],x,ch-'a'+1);
        }
          else ps[i]=ps[x];
        }
      for(int i=hd[0];i;i=nxt[i])
        {
          memset(rt[1],0,sizeof rt[1]);/////
          memset(mxl[1],0,sizeof mxl[1]);
          dfs(to[i],w[i],c[i]);
        }
      for(int i=1;i<=n;i++)printf("%d
    ",ans[ps[i]]);
      return 0;
    }
    View Code
  • 相关阅读:
    HDU 5059 Help him
    HDU 5058 So easy
    HDU 5056 Boring count
    HDU 5055 Bob and math problem
    HDU 5054 Alice and Bob
    HDU 5019 Revenge of GCD
    HDU 5018 Revenge of Fibonacci
    HDU 1556 Color the ball
    CodeForces 702D Road to Post Office
    CodeForces 702C Cellular Network
  • 原文地址:https://www.cnblogs.com/Narh/p/10708362.html
Copyright © 2011-2022 走看看