zoukankan      html  css  js  c++  java
  • [BZOJ3277/BZOJ3473] 串

    [BZOJ3277] 串

    Description

    现在给定你n个字符串,询问每个字符串有多少子串(不包括空串)是所有n个字符串中至少k个字符串的子串(注意包括本身)。

    Solution

    首先将所有串连接起来,预处理出后缀数组和高度数组。

    显然直接主席树可以很容易做到 (O(n log^2 n)) 。对于每一个后缀的位置,二分一个 LCP 长度,找到这个 LCP 长度对应的区间,检查这个区间是否合法来调节二分边界。

    注意在这个做法里,瓶颈不在于主席树,因为主席树的功能完全可以用双指针预处理一个数组来替代。瓶颈在于,实质上使用了一个二分套二分的做法。

    但我们有更好的做法。

    引理:按照原始顺序,如果第 (i) 个后缀有 (x) 个前缀能被 (k) 个串包含,那么第 (i+1) 个后缀至少有 (x-1) 个前缀能被 (k) 个串包含。

    那么我们先用双指针预处理 (jmp[i]) 代表按照后缀排序,最大的 (j) 使得 ([j,i]) 这个后缀区间合法。

    到第 (i) 个后缀的时候我们就从后缀 (i-1) 的答案开始向上枚举,用二分+ST表找出它左右边第一个高度比当前枚举值小的位置,判断这个区间的合法性来决定是否继续枚举,均摊时间复杂度 (O(nlogn))

    (O(log n)) 解法

    #include <bits/stdc++.h>
    using namespace std;
    
    #define int long long
    const int N = 400005;
    
    int n,m=N/2,sa[N],y[N],u[N],v[N],o[N],r[N],h[N],T,nstr,k;
    int str[N],Log2[N],bel[N],buf[N],bcnt,jmp[N],mx[N],ans[N],tow[N];
    char tstr[N];
    
    struct St {
        int a[N][21];
        void build(int *src,int n) {
            for(int i=1;i<=n;i++) a[i][0]=src[i];
            for(int i=1;i<=20;i++)
                for(int j=1;j<=n-(1<<i)+1;j++)
                    a[j][i]=min(a[j][i-1],a[j+(1<<(i-1))][i-1]);
        }
        int query(int l,int r) {
            if(l>r) return 0;
            int j=Log2[r-l+1];
            return min(a[l][j],a[r-(1<<j)+1][j]);
        }
    } st;
    
    int lbound(int cen,int val) {
        int l=1,r=cen;
        while(r>l) {
            int mid=(l+r)/2;
            if(st.query(mid+1,cen)>=val) r=mid;
            else l=mid+1;
        }
        return l;
    }
    
    int rbound(int cen,int val) {
        int l=cen+1,r=n+1;
        while(r>l) {
            int mid=(l+r)/2;
            if(st.query(cen+1,mid)>=val) l=mid+1;
            else r=mid;
        }
        return l-1;
    }
    
    signed main(){
        for(int i=1;i<=200000;i++) Log2[i]=log2(i);
        scanf("%lld%lld",&nstr,&k);
        for(int i=1;i<=nstr;i++) {
            scanf("%s",tstr);
            int len=strlen(tstr);
            for(int j=0;j<len;j++) str[j+n+1]=tstr[j],bel[j+n+1]=i,tow[j+n+1]=n+len;
            n+=len+1;
            str[n]=127+i;
        }
    
        for(int i=1;i<=n;i++) u[str[i]]++;
        for(int i=1;i<=m;i++) u[i]+=u[i-1];
        for(int i=n;i>=1;i--) sa[u[str[i]]--]=i;
        r[sa[1]]=1;
        for(int i=2;i<=n;i++) r[sa[i]]=r[sa[i-1]]+(str[sa[i]]!=str[sa[i-1]]);
    
        for(int l=1;r[sa[n]]<n;l<<=1) {
            memset(u,0,sizeof u);
            memset(v,0,sizeof v);
            memcpy(o,r,sizeof r);
            for(int i=1;i<=n;i++) u[r[i]]++, v[r[i+l]]++;
            for(int i=1;i<=n;i++) u[i]+=u[i-1], v[i]+=v[i-1];
            for(int i=n;i>=1;i--) y[v[r[i+l]]--]=i;
            for(int i=n;i>=1;i--) sa[u[r[y[i]]]--]=y[i];
            r[sa[1]]=1;
            for(int i=2;i<=n;i++) r[sa[i]]=r[sa[i-1]]+((o[sa[i]]!=o[sa[i-1]])||(o[sa[i]+l]!=o[sa[i-1]+l]));
        }
        {
            int i,j,k=0;
            for(int i=1;i<=n;h[r[i++]]=k)
                for(k?k--:0,j=sa[r[i]-1];str[i+k]==str[j+k];k++);
        }
    
        st.build(h,n);
    
        bcnt=1;
        buf[bel[sa[n]]]++;
        for(int i=n,j=n;i>=1;--i) {
            while(bcnt<k && j>0) {
                --j;
                if(buf[bel[sa[j]]]==0) ++bcnt;
                buf[bel[sa[j]]]++;
            }
            jmp[i]=j;
            if(buf[bel[sa[i]]]==1) --bcnt;
            buf[bel[sa[i]]]--;
        }
       // for(int i=1;i<=n;i++) cout<<jmp[i]<<" "; cout<<endl;
        for(int i=1;i<=n;i++) {
            for(int j=max(1ll,mx[i-1]);j<=n;j++) {
                int lb=lbound(r[i],j), rb=rbound(r[i],j);
                //cout<<i<<" "<<r[i]<<" "<<j<<" "<<lb<<" "<<rb<<endl;
                if(jmp[rb]<lb || j>tow[i]-i+1) {
                    mx[i]=j-1;
                    break;
                }
            }
        }
        //for(int i=1;i<=n;i++) cout<<mx[i]<<" ";
        //cout<<endl;
        for(int i=1;i<=n;i++) {
            ans[bel[i]]+=mx[i];
        }
        for(int i=1;i<=nstr;i++) printf("%lld ",ans[i]);
    }
    

    (O(log^2 n)) 解法 (TLE)

    #include <bits/stdc++.h>
    using namespace std;
    
    #define int long long
    const int N = 400005;
    
    int n,m=N/2,sa[N],y[N],u[N],v[N],o[N],r[N],h[N],jmp[N],buf[N],bel[N],bcnt;
    int nstr,k;
    int str[N],ans[N],tow[N],LOG2[N];
    char tstr[N];
    
    struct St {
        int a[N][21];
        void build(int *src,int n) {
            for(int i=1;i<=n;i++) a[i][0]=src[i];
            for(int i=1;i<=20;i++)
                for(int j=1;j<=n-(1<<i)+1;j++)
                    a[j][i]=min(a[j][i-1],a[j+(1<<(i-1))][i-1]);
        }
        int query(int l,int r) {
            if(l>r) return 0;
            int j=LOG2[r-l+1];
            return min(a[l][j],a[r-(1<<j)+1][j]);
        }
    } st;
    
    int lbound(int cen,int val) {
        int l=1,r=cen;
        while(r-l) {
            int mid=(l+r)/2;
            if(st.query(mid+1,cen)>=val) r=mid;
            else l=mid+1;
        }
        return l;
    }
    
    int rbound(int cen,int val) {
        int l=cen+1,r=n+1;
        while(r-l) {
            int mid=(l+r)/2;
            if(st.query(cen+1,mid)>=val) l=mid+1;
            else r=mid;
        }
        return l-1;
    }
    
    signed main(){
        for(int i=1;i<=200000;i++) LOG2[i]=log2(i);
        scanf("%d%d",&nstr,&k);
        for(int i=1;i<=nstr;i++) {
            scanf("%s",tstr);
            int tstrlength = strlen(tstr);
            for(int j=0;j<tstrlength;j++)
                str[n+j+1]=tstr[j],bel[n+j+1]=i,tow[n+j+1]=n+tstrlength;
            n+=tstrlength+1;
            str[n]=127+i;
        }
    
        for(int i=1;i<=n;i++) u[str[i]]++;
        for(int i=1;i<=m;i++) u[i]+=u[i-1];
        for(int i=n;i>=1;i--) sa[u[str[i]]--]=i;
        r[sa[1]]=1;
        for(int i=2;i<=n;i++) r[sa[i]]=r[sa[i-1]]+(str[sa[i]]!=str[sa[i-1]]);
    
        for(int l=1;r[sa[n]]<n;l<<=1) {
            memset(u,0,sizeof u);
            memset(v,0,sizeof v);
            memcpy(o,r,sizeof r);
            for(int i=1;i<=n;i++) u[r[i]]++, v[r[i+l]]++;
            for(int i=1;i<=n;i++) u[i]+=u[i-1], v[i]+=v[i-1];
            for(int i=n;i>=1;i--) y[v[r[i+l]]--]=i;
            for(int i=n;i>=1;i--) sa[u[r[y[i]]]--]=y[i];
            r[sa[1]]=1;
            for(int i=2;i<=n;i++) r[sa[i]]=r[sa[i-1]]+((o[sa[i]]!=o[sa[i-1]])||(o[sa[i]+l]!=o[sa[i-1]+l]));
        }
        {
            int i,j,k=0;
            for(int i=1;i<=n;h[r[i++]]=k)
                for(k?k--:0,j=sa[r[i]-1];str[i+k]==str[j+k];k++);
        }
        st.build(h,n);
        buf[bel[sa[n]]]=1; bcnt++;
        for(int i=n,j=n;i>=1;--i) {
            while(bcnt<k && j>0) {
                --j;
                if(buf[bel[sa[j]]]==0) bcnt++;
                buf[bel[sa[j]]]++;
            }
            jmp[i]=j;
            buf[bel[sa[i]]]--;
            if(buf[bel[sa[i]]]==0) bcnt--;
        }
    
        for(int i=1;i<=n;i++) {
            int l=1,r=tow[sa[i]]-sa[i]+2;
            while(r>l) {
                int mid=(l+r)/2;
                int lb=lbound(i,mid),rb=rbound(i,mid);
                if(jmp[rb]>=lb) l=mid+1;
                else r=mid;
            }
            //cout<<i<<" "<<l-1<<endl;
            ans[bel[sa[i]]]+=l-1;
        }
        for(int i=1;i<=nstr;i++) printf("%lld ",ans[i]);
    }
    
  • 相关阅读:
    自定义Python枚举
    解决Django跨域访问的问题
    BBS项目细节总结
    面向对象进阶
    面向对象
    三级菜单
    常用模块
    内置函数与匿名函数及递归
    迭代器和生成器
    函数
  • 原文地址:https://www.cnblogs.com/mollnn/p/11791950.html
Copyright © 2011-2022 走看看