题意:给你n个字符串,求出在超过一半的字符串中出现的所有子串中最长的子串,按字典序输出。
这道题算是我的一个黑历史了吧,以前我的做法是对这n个字符串建广义后缀自动机,然后在自动机上dfs,交上去AC了,然而事后发现算法假了,出了个数据把自己给hack了...
之前写的太烂了,决定重写一遍。
正确的操作是对n个串倒序建广义后缀自动机,建好以后把每个串放到自动机上跑一遍,把所有覆盖到的状态结点打上标记(每个串只标记一次,用vis判重),记录每个状态在多少个串中出现过,然后在后缀树(fail树)上按字典序dfs一遍就好了。
注意每添加一个字符串,需要把last指向根节点,而且在每次往后添加结点的时候判断当前结点是否存在过,如果存在则需要特殊处理(源自洛谷zcysky大神的思路)
复杂度$O(nsqrt n)$,但上界很松,跑起来速度还是很快滴~
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 const int N=2e5+10,M=26; 5 int n,ka; 6 char s[105][1010]; 7 struct SAM { 8 int fa[N],go[N][M],mxl[N],last,tot,ch[N][M],pos[N],cc[N],nc,vis[N],cnt[N],mx; 9 int newnode(int l,int p) { 10 int u=++tot; 11 mxl[u]=l,pos[u]=p,cnt[u]=0; 12 memset(go[u],0,sizeof go[u]); 13 memset(ch[u],0,sizeof ch[u]); 14 return u; 15 } 16 void init() {tot=nc=0,last=newnode(0,-1);} 17 void add(int ch) { 18 cc[++nc]=ch; 19 int p=last; 20 if(go[p][ch]) { 21 int q=go[p][ch]; 22 if(mxl[q]==mxl[p]+1)last=q; 23 else { 24 int nq=newnode(mxl[p]+1,pos[q]); 25 memcpy(go[nq],go[q],sizeof go[q]); 26 fa[nq]=fa[q],fa[q]=nq; 27 for(; p&&go[p][ch]==q; p=fa[p])go[p][ch]=nq; 28 last=nq; 29 } 30 } else { 31 int np=last=newnode(mxl[p]+1,nc); 32 for(; p&&!go[p][ch]; p=fa[p])go[p][ch]=np; 33 if(!p)fa[np]=1; 34 else { 35 int q=go[p][ch]; 36 if(mxl[q]==mxl[p]+1)fa[np]=q; 37 else { 38 int nq=newnode(mxl[p]+1,pos[q]); 39 memcpy(go[nq],go[q],sizeof go[q]); 40 fa[nq]=fa[q],fa[q]=fa[np]=nq; 41 for(; p&&go[p][ch]==q; p=fa[p])go[p][ch]=nq; 42 } 43 } 44 } 45 } 46 void dfs(int u) { 47 if(mxl[u]==mx&&cnt[u]>n/2) { 48 for(int i=pos[u]; i>pos[u]-mxl[u]; --i)printf("%c",cc[i]+'a'); 49 puts(""); 50 } 51 for(int i=0; i<M; ++i)if(ch[u][i])dfs(ch[u][i]); 52 } 53 void run() { 54 for(int i=0; i<n; ++i) { 55 last=1; 56 int l=strlen(s[i]); 57 reverse(s[i],s[i]+l); 58 for(int j=0; j<l; ++j)add(s[i][j]-'a'); 59 } 60 memset(vis,-1,sizeof vis); 61 for(int i=0; i<n; ++i) 62 for(int j=0,u=1; s[i][j]; u=go[u][s[i][j]-'a'],++j) 63 for(int v=go[u][s[i][j]-'a']; v!=1&&vis[v]!=i; v=fa[v])vis[v]=i,++cnt[v]; 64 mx=-1; 65 for(int i=2; i<=tot; ++i)if(cnt[i]>n/2)mx=max(mx,mxl[i]); 66 for(int i=2; i<=tot; ++i)ch[fa[i]][cc[pos[i]-mxl[fa[i]]]]=i; 67 if(!~mx)puts("?"); 68 else { 69 memset(vis,0,sizeof vis); 70 dfs(1); 71 } 72 } 73 } sam; 74 int main() { 75 while(scanf("%d",&n),n) { 76 ka?puts(""):++ka; 77 sam.init(); 78 for(int i=0; i<n; ++i)scanf("%s",s[i]); 79 sam.run(); 80 } 81 return 0; 82 }
还有一种做法是利用后缀数组。
把这n个串用不同的字符连接在一起求后缀数组,并给每个后缀i赋一个值a[i]表示它是哪个字符串里的。然后对排好序的后缀进行尺取并维护区间不同值的个数,一旦区间不同值的个数>n/2,就输出长度为左右端点lcp的字符串。(需要尺取两次,第一次求出最大长度,第二次输出)
但是这样做可能会有重复的串被输出,怎么去重呢?用哈希固然可以,可有没有优雅一点的做法呢?当然。只要每次输出的时候记录一下当前子串的左端点la,下次准备输出的时候和la求一次lcp,如果lcp=最大长度的话,就跳过。
复杂度$O(nlogn+n)$
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 const int N=1e5+1000,mod=998244353; 5 char buf[N]; 6 int s[N],sa[N],buf1[N],buf2[N],c[N],n,m,k,rnk[N],ht[N],ST[N][20],Log[N],a[N],cnt,ka; 7 void Sort(int* x,int* y,int m) { 8 for(int i=0; i<m; ++i)c[i]=0; 9 for(int i=0; i<n; ++i)++c[x[i]]; 10 for(int i=1; i<m; ++i)c[i]+=c[i-1]; 11 for(int i=n-1; i>=0; --i)sa[--c[x[y[i]]]]=y[i]; 12 } 13 void da(int* s,int n,int m=1000) { 14 int *x=buf1,*y=buf2; 15 x[n]=y[n]=-1; 16 for(int i=0; i<n; ++i)x[i]=s[i],y[i]=i; 17 Sort(x,y,m); 18 for(int k=1; k<n; k<<=1) { 19 int p=0; 20 for(int i=n-k; i<n; ++i)y[p++]=i; 21 for(int i=0; i<n; ++i)if(sa[i]>=k)y[p++]=sa[i]-k; 22 Sort(x,y,m),p=1,y[sa[0]]=0; 23 for(int i=1; i<n; ++i)y[sa[i]]=x[sa[i-1]]==x[sa[i]]&&x[sa[i-1]+k]==x[sa[i]+k]?p-1:p++; 24 if(p==n)break; 25 swap(x,y),m=p; 26 } 27 } 28 void getht() { 29 for(int i=0; i<n; ++i)rnk[sa[i]]=i; 30 ht[0]=0; 31 for(int i=0,k=0; i<n; ++i) { 32 if(k)--k; 33 if(!rnk[i])continue; 34 for(; s[i+k]==s[sa[rnk[i]-1]+k]; ++k); 35 ht[rnk[i]]=k; 36 } 37 } 38 void initST() { 39 for(int i=1; i<n; ++i)ST[i][0]=ht[i]; 40 for(int j=1; (1<<j)<=n; ++j) 41 for(int i=1; i+(1<<j)-1<n; ++i) 42 ST[i][j]=min(ST[i][j-1],ST[i+(1<<(j-1))][j-1]); 43 } 44 int lcp(int l,int r) { 45 if(l==r)return n-sa[l]; 46 if(l>r)swap(l,r); 47 l++; 48 int k=Log[r-l+1]; 49 return min(ST[l][k],ST[r-(1<<k)+1][k]); 50 } 51 void add(int x,int f) { 52 if(!x)return; 53 if(!c[x])++cnt; 54 if(!(c[x]-=f))--cnt; 55 } 56 int main() { 57 Log[0]=-1; 58 for(int i=1; i<N; ++i)Log[i]=Log[i>>1]+1; 59 while(scanf("%d",&m),m) { 60 if(ka++)puts(""); 61 memset(a,0,sizeof a); 62 n=0; 63 for(int i=0; i<m; ++i) { 64 if(i)s[n++]='z'+i; 65 scanf("%s",buf),k=strlen(buf); 66 for(int j=0; j<k; ++j)a[n]=i+1,s[n++]=buf[j]; 67 } 68 s[n]=0; 69 da(s,n),getht(),initST(); 70 memset(c,0,sizeof c); 71 cnt=0; 72 int mx=0; 73 for(int i=0,j=0; i<n; ++i) { 74 if(!a[sa[i]])break; 75 for(; j<n&&cnt<=m/2; ++j)add(a[sa[j]],1); 76 add(a[sa[i]],-1); 77 mx=max(mx,lcp(i,j-1)); 78 } 79 if(!mx)puts("?"); 80 else { 81 for(int i=0,j=0,k,la=-1; i<n; ++i) { 82 if(!a[sa[i]])break; 83 for(; j<n&&cnt<=m/2; ++j)add(a[sa[j]],1); 84 if(lcp(i,j-1)==mx) { 85 if(!~la||lcp(la,j-1)!=mx) { 86 for(k=0; k<lcp(i,j-1); ++k)printf("%c",s[sa[i]+k]); 87 puts(""); 88 } 89 la=i; 90 } 91 add(a[sa[i]],-1); 92 } 93 } 94 } 95 return 0; 96 }