首先看一个广义SAM的经典应用:
BZOJ3277&BZOJ3473
给定n个字符串,询问每个字符串有多少子串(不包括空串)是所有n个字符串中至少k个字符串的子串(包括本身)。
对于多串问题,普通SAM已经无法胜任。有各种应对这类多串问题的方法:
(1) 直接建SAM,每次插入新串时将lst设为1即可,不会证明正确性。
(2) 和SA一样在多个串接在一起中间加入分隔符'#',然后直接建SAM。
(3) 离线做法:对所有串建出Trie,然后BFS遍历Trie,每个点在父亲的基础上extend即可。时空复杂度$O(|A||T|)$。
(4) 在线做法:模板如下,时间复杂度$O(|A||T|+G(T))$,空间复杂度$O(|A||T|)$,其中|T|是Trie的大小,|A|是字符集大小,G(T)是Trie上所有叶子深度之和。
结论:如果题目是以多个模板串形式给出的,建议使用在线做法(离线做法虽然复杂度优秀但需要先建出Trie所以大部分情况常数会较大)。如果是直接以Trie形式给出的,那么直接用离线算法就好了。
1 int work(int p,int c){ 2 int nq=++nd,q=son[p][c]; mx[nq]=mx[p]+1; 3 fa[nq]=fa[q]; fa[q]=nq; 4 memcpy(son[nq],son[q],sizeof(son[q])); 5 while (p && son[p][c]==q) son[p][c]=nq,p=fa[p]; 6 return nq; 7 } 8 9 int ext(int p,int c){ 10 if (son[p][c]){ 11 int q=son[p][c]; 12 if (mx[q]==mx[p]+1) return q; else return work(p,c); 13 }else{ 14 int np=++nd; mx[np]=mx[p]+1; 15 while (p && !son[p][c]) son[p][c]=np,p=fa[p]; 16 if (!p) fa[np]=1; 17 else{ 18 int q=son[p][c]; 19 if (mx[q]==mx[p]+1) fa[np]=q; else fa[np]=work(p,c); 20 } 21 return np; 22 } 23 }
我们已经对给定的n个串所构成的Trie建立了SAM,现在需要得到每个点所代表的子串在多少个字符串中出现过。
这里就是每次暴力跳parent更新(打标记),当然如果发现一个点已经被当前字符串更新过了就不需要再往上跳了。
这样的复杂度据说是$O(n^frac{3}{2})$,因为每次插入一个串涉及到的节点数显然不超过$min(|S|^2,|T|)$,其中|T|为当前SAM大小,|S|为串长。所以当构造$sqrt{n}$个长度为$sqrt{n}$的串时复杂度达到最坏情况$(n^frac{3}{2})$。
这样我们就求得了每个结点所表示的串的出现次数,接着在parent树上从根到叶做一次DP统计答案即可。
1 void solve(){ 2 int u; ll ans; 3 rep(i,1,n){ 4 u=1; 5 rep(j,0,len[i]){ 6 u=son[u][s[i][j]-'a']; int p=u; 7 while (p && vis[p]!=i) tot[p]++,vis[p]=i,p=fa[p]; 8 } 9 } 10 radix(); 11 rep(i,2,nd) u=q[i],f[u]=f[fa[u]]+(tot[u]>=k ? mx[u]-mx[fa[u]] : 0); 12 rep(i,1,n){ 13 u=1; ans=0; 14 rep(j,0,len[i]) u=son[u][s[i][j]-'a'],ans+=f[u]; 15 printf("%lld ",ans); 16 } 17 }
完整代码:
1 #include<cstdio> 2 #include<string> 3 #include<cstring> 4 #include<algorithm> 5 #define rep(i,l,r) for (int i=(l); i<=(r); i++) 6 typedef long long ll; 7 using namespace std; 8 9 const int N=300010; 10 int n,k,lst=1,nd=1,len[N],son[N][27],tot[N],f[N],c[N],q[N],fa[N],mx[N],vis[N]; 11 char ss[N]; 12 string s[N]; 13 14 int work(int p,int c){ 15 int nq=++nd,q=son[p][c]; mx[nq]=mx[p]+1; 16 fa[nq]=fa[q]; fa[q]=nq; 17 memcpy(son[nq],son[q],sizeof(son[q])); 18 while (p && son[p][c]==q) son[p][c]=nq,p=fa[p]; 19 return nq; 20 } 21 22 int ext(int p,int c){ 23 if (son[p][c]){ 24 int q=son[p][c]; 25 if (mx[q]==mx[p]+1) return q; else return work(p,c); 26 }else{ 27 int np=++nd; mx[np]=mx[p]+1; 28 while (p && !son[p][c]) son[p][c]=np,p=fa[p]; 29 if (!p) fa[np]=1; 30 else{ 31 int q=son[p][c]; 32 if (mx[q]==mx[p]+1) fa[np]=q; else fa[np]=work(p,c); 33 } 34 return np; 35 } 36 } 37 38 void radix(){ 39 rep(i,1,nd) c[mx[i]]++; 40 rep(i,1,nd) c[i]+=c[i-1]; 41 for (int i=nd; i; i--) q[c[mx[i]]--]=i; 42 } 43 44 void solve(){ 45 int u; ll ans; 46 rep(i,1,n){ 47 u=1; 48 rep(j,0,len[i]){ 49 u=son[u][s[i][j]-'a']; int p=u; 50 while (p && vis[p]!=i) tot[p]++,vis[p]=i,p=fa[p]; 51 } 52 } 53 radix(); 54 rep(i,2,nd) u=q[i],f[u]=f[fa[u]]+(tot[u]>=k ? mx[u]-mx[fa[u]] : 0); 55 rep(i,1,n){ 56 u=1; ans=0; 57 rep(j,0,len[i]) u=son[u][s[i][j]-'a'],ans+=f[u]; 58 printf("%lld ",ans); 59 } 60 } 61 62 int main(){ 63 freopen("bzoj3277.in","r",stdin); 64 freopen("bzoj3277.out","w",stdout); 65 scanf("%d%d",&n,&k); 66 rep(i,1,n){ 67 scanf("%s",ss); s[i]=string(ss); len[i]=strlen(ss)-1; 68 lst=1; rep(j,0,len[i]) lst=ext(lst,s[i][j]-'a'); 69 } 70 solve(); 71 return 0; 72 }
BZOJ2780
有n个大串和m个询问,每次给出一个字符串s询问在多少个大串中出现过。
1 #include<cstdio> 2 #include<string> 3 #include<cstring> 4 #include<algorithm> 5 #define rep(i,l,r) for (int i=(l); i<=(r); i++) 6 typedef long long ll; 7 using namespace std; 8 9 const int N=300010; 10 int n,m,lst=1,nd=1,len[N],son[N][27],tot[N],f[N],c[N],q[N],fa[N],mx[N],vis[N]; 11 char ss[N]; 12 string s[N]; 13 14 int work(int p,int c){ 15 int nq=++nd,q=son[p][c]; mx[nq]=mx[p]+1; 16 fa[nq]=fa[q]; fa[q]=nq; 17 memcpy(son[nq],son[q],sizeof(son[q])); 18 while (p && son[p][c]==q) son[p][c]=nq,p=fa[p]; 19 return nq; 20 } 21 22 int ext(int p,int c){ 23 if (son[p][c]){ 24 int q=son[p][c]; 25 if (mx[q]==mx[p]+1) return q; else return work(p,c); 26 }else{ 27 int np=++nd; mx[np]=mx[p]+1; 28 while (p && !son[p][c]) son[p][c]=np,p=fa[p]; 29 if (!p) fa[np]=1; 30 else{ 31 int q=son[p][c]; 32 if (mx[q]==mx[p]+1) fa[np]=q; else fa[np]=work(p,c); 33 } 34 return np; 35 } 36 } 37 38 void solve(){ 39 int u; 40 rep(i,1,n){ 41 u=1; 42 rep(j,0,len[i]){ 43 u=son[u][s[i][j]-'a']; int p=u; 44 while (p && vis[p]!=i) tot[p]++,vis[p]=i,p=fa[p]; 45 } 46 } 47 rep(i,1,m){ 48 u=1; scanf("%s",ss); int len=strlen(ss)-1; 49 rep(j,0,len) u=son[u][ss[j]-'a']; 50 printf("%d ",tot[u]); 51 } 52 } 53 54 int main(){ 55 freopen("bzoj2780.in","r",stdin); 56 freopen("bzoj2780.out","w",stdout); 57 scanf("%d%d",&n,&m); 58 rep(i,1,n){ 59 scanf("%s",ss); s[i]=string(ss); len[i]=strlen(ss)-1; 60 lst=1; rep(j,0,len[i]) lst=ext(lst,s[i][j]-'a'); 61 } 62 solve(); 63 return 0; 64 }
BZOJ4566
1 #include<cstdio> 2 #include<cstring> 3 #include<algorithm> 4 #define rep(i,l,r) for (int i=(l); i<=(r); i++) 5 typedef long long ll; 6 using namespace std; 7 8 const int N=1000010; 9 int n,k,lst=1,nd=1,len,son[N][27],c[N],q[N],fa[N],mx[N],d1[N],d2[N]; 10 char s[N]; 11 ll ans; 12 13 int work(int p,int c){ 14 int nq=++nd,q=son[p][c]; mx[nq]=mx[p]+1; 15 fa[nq]=fa[q]; fa[q]=nq; 16 memcpy(son[nq],son[q],sizeof(son[q])); 17 while (p && son[p][c]==q) son[p][c]=nq,p=fa[p]; 18 return nq; 19 } 20 21 int ext(int p,int c){ 22 if (son[p][c]){ 23 int q=son[p][c]; 24 if (mx[q]==mx[p]+1) return q; else return work(p,c); 25 }else{ 26 int np=++nd; mx[np]=mx[p]+1; 27 while (p && !son[p][c]) son[p][c]=np,p=fa[p]; 28 if (!p) fa[np]=1; 29 else{ 30 int q=son[p][c]; 31 if (mx[q]==mx[p]+1) fa[np]=q; else fa[np]=work(p,c); 32 } 33 return np; 34 } 35 } 36 37 void radix(){ 38 rep(i,1,nd) c[mx[i]]++; 39 rep(i,1,nd) c[i]+=c[i-1]; 40 for (int i=nd; i; i--) q[c[mx[i]]--]=i; 41 } 42 43 int main(){ 44 freopen("bzoj4566.in","r",stdin); 45 freopen("bzoj4566.out","w",stdout); 46 scanf("%s",s+1); len=strlen(s+1); lst=1; 47 rep(i,1,len) lst=ext(lst,s[i]-'a'),d1[lst]++; 48 scanf("%s",s+1); len=strlen(s+1); lst=1; 49 rep(i,1,len) lst=ext(lst,s[i]-'a'),d2[lst]++; 50 radix(); 51 for (int i=nd; i; i--) d1[fa[q[i]]]+=d1[q[i]],d2[fa[q[i]]]+=d2[q[i]]; 52 rep(i,1,nd) ans+=1ll*(mx[i]-mx[fa[i]])*d1[i]*d2[i]; 53 printf("%lld ",ans); 54 return 0; 55 }
BZOJ3756
在线:
1 #include<cstdio> 2 #include<cstring> 3 #include<algorithm> 4 #define rep(i,l,r) for (int i=(l); i<=(r); i++) 5 typedef long long ll; 6 using namespace std; 7 8 const int N=1600010,M=8000010; 9 char ch,s[M]; 10 ll ans,sm[N]; 11 int n,nd=1,x,len,pos[N],son[N][26],mx[N],fa[N],R[N],q[N],c[N]; 12 13 int work(int p,int c){ 14 int nq=++nd,q=son[p][c]; mx[nq]=mx[p]+1; 15 fa[nq]=fa[q]; fa[q]=nq; memcpy(son[nq],son[q],sizeof(son[q])); 16 while (p && son[p][c]==q) son[p][c]=nq,p=fa[p]; 17 return nq; 18 } 19 20 int ext(int p,int c){ 21 if (son[p][c]){ 22 int q=son[p][c]; 23 if (mx[q]==mx[p]+1){ R[q]++; return q; } 24 else{ int t=work(p,c); R[t]++; return t; } 25 }else{ 26 int np=++nd; mx[np]=mx[p]+1; R[np]=1; 27 while (p && !son[p][c]) son[p][c]=np,p=fa[p]; 28 if (!p){ fa[np]=1; return np; } 29 int q=son[p][c]; 30 if (mx[q]==mx[p]+1) fa[np]=q; else fa[np]=work(p,c); 31 return np; 32 } 33 } 34 35 int main(){ 36 freopen("bzoj3756.in","r",stdin); 37 freopen("bzoj3756.out","w",stdout); 38 scanf("%d",&n); pos[1]=1; 39 rep(i,2,n) scanf("%d %c",&x,&ch),pos[i]=ext(pos[x],ch-'a'); 40 scanf("%s",s+1); len=strlen(s+1); 41 rep(i,1,nd) c[mx[i]]++; 42 rep(i,1,n) c[i]+=c[i-1]; 43 for (int i=nd; i; i--) q[c[mx[i]]--]=i; 44 for (int i=nd; i; i--) R[fa[q[i]]]+=R[q[i]]; 45 R[0]=R[1]=0; 46 rep(i,1,nd){ int x=q[i]; sm[x]=sm[fa[x]]+1ll*(mx[x]-mx[fa[x]])*R[x]; } 47 int x=1,l=0; 48 rep(i,1,len){ 49 if (son[x][s[i]-'a']) l++,x=son[x][s[i]-'a']; 50 else{ 51 while (x && !son[x][s[i]-'a']) x=fa[x]; 52 if (x) l=mx[x]+1,x=son[x][s[i]-'a']; else l=0,x=1; 53 } 54 if (x>1) ans+=sm[fa[x]]+1ll*(l-mx[fa[x]])*R[x]; 55 } 56 printf("%lld",ans); 57 return 0; 58 }
离线:
1 #include<cstdio> 2 #include<cstring> 3 #include<algorithm> 4 #define rep(i,l,r) for (int i=(l); i<=(r); i++) 5 typedef long long ll; 6 using namespace std; 7 8 const int N=1600010,M=8000010; 9 char ch,s[M]; 10 ll ans,sm[N]; 11 int n,nd=1,x,len,pos[N],pre[N],sn[N][27],son[N][27],mx[N],fa[N],R[N],q[N],c[N]; 12 13 int ext(int p,int c){ 14 int np=++nd; mx[np]=mx[p]+1; R[np]=1; 15 while (!son[p][c] && p) son[p][c]=np,p=fa[p]; 16 if (!p) fa[np]=1; 17 else{ 18 int q=son[p][c]; 19 if (mx[q]==mx[p]+1) fa[np]=q; 20 else{ 21 int nq=++nd; mx[nq]=mx[p]+1; 22 memcpy(son[nq],son[q],sizeof(son[q])); 23 fa[nq]=fa[q]; fa[np]=fa[q]=nq; 24 while (p && son[p][c]==q) son[p][c]=nq,p=fa[p]; 25 } 26 } 27 return np; 28 } 29 30 int main(){ 31 scanf("%d",&n); pos[1]=1; q[1]=1; 32 rep(i,2,n) scanf("%d %c",&x,&ch),sn[x][ch-'a']=i,pre[i]=x,s[i]=ch; 33 for (int st=0,ed=1; st!=ed; ){ 34 int x=q[++st]; pos[x]=ext(pos[pre[x]],s[x]-'a'); 35 rep(i,0,25) if (sn[x][i]) q[++ed]=sn[x][i]; 36 } 37 scanf("%s",s+1); len=strlen(s+1); 38 rep(i,1,nd) c[mx[i]]++; 39 rep(i,1,n) c[i]+=c[i-1]; 40 for (int i=nd; i; i--) q[c[mx[i]]--]=i; 41 for (int i=nd; i; i--) R[fa[q[i]]]+=R[q[i]]; 42 R[0]=R[1]=0; 43 rep(i,1,nd){ int x=q[i]; sm[x]=sm[fa[x]]+1ll*(mx[x]-mx[fa[x]])*R[x]; } 44 int x=1,l=0; 45 rep(i,1,len){ 46 if (son[x][s[i]-'a']) l++,x=son[x][s[i]-'a']; 47 else{ 48 while (x && !son[x][s[i]-'a']) x=fa[x]; 49 if (x) l=mx[x]+1,x=son[x][s[i]-'a']; else l=0,x=1; 50 } 51 if (x>1) ans+=sm[fa[x]]+1ll*(l-mx[fa[x]])*R[x]; 52 } 53 printf("%lld",ans); 54 return 0; 55 }
另外这题由于出题人数据有误导致只有在线做法能过。
CF666E
建出广义后缀树,倍增找到代表询问串的节点,然后就是查询它在那个串中的|Right|最大,用可持久化线段树合并即可。
1 #include<cstdio> 2 #include<cstring> 3 #include<algorithm> 4 #define ls v[x].lc 5 #define rs v[x].rc 6 #define lson ls,L,mid 7 #define rson rs,mid+1,R 8 #define rep(i,l,r) for (int i=(l); i<=(r); i++) 9 using namespace std; 10 11 const int N=100010,M=1000010; 12 char ss[N],s[M]; 13 int n,m,Q,l1,r1,l2,r2,nd=1,cnt,pos[M],L[M]; 14 int mx[M],q[M],fa[M],c[M],son[M][27],rt[M],f[M][22]; 15 struct P{ int x,id; }; 16 struct Tr{ int lc,rc; P p; }v[N*30]; 17 18 bool operator <(const P &a,const P &b){ return a.x==b.x ? a.id>b.id : a.x<b.x; } 19 20 int work(int p,int c){ 21 int nq=++nd,q=son[p][c]; mx[nq]=mx[p]+1; 22 fa[nq]=fa[q]; fa[q]=nq; memcpy(son[nq],son[q],sizeof(son[q])); 23 while (p && son[p][c]==q) son[p][c]=nq,p=fa[p]; 24 return nq; 25 } 26 27 int ext(int p,int c){ 28 if (son[p][c]){ 29 int q=son[p][c]; 30 if (mx[q]==mx[p]+1) return q; else return work(p,c); 31 }else{ 32 int np=++nd; mx[np]=mx[p]+1; 33 while (p && !son[p][c]) son[p][c]=np,p=fa[p]; 34 if (!p){ fa[np]=1; return np; } 35 int q=son[p][c]; 36 if (mx[q]==mx[p]+1) fa[np]=q; else fa[np]=work(p,c); 37 return np; 38 } 39 } 40 41 void upd(int x){ v[x].p=max(v[ls].p,v[rs].p); } 42 43 int merge(int x,int y,int L,int R){ 44 if (!x || !y) return x|y; 45 int k=++cnt; 46 if (L==R){ v[k].p=(P){v[x].p.x+v[y].p.x,L}; return k; } 47 int mid=(L+R)>>1; 48 v[k].lc=merge(v[x].lc,v[y].lc,L,mid); 49 v[k].rc=merge(v[x].rc,v[y].rc,mid+1,R); 50 upd(k); return k; 51 } 52 53 void mdf(int &x,int L,int R,int k){ 54 if (!x) x=++cnt; 55 if (L==R){ v[x].p=(P){v[x].p.x+1,k}; return; } 56 int mid=(L+R)>>1; 57 if (k<=mid) mdf(lson,k); else mdf(rson,k); 58 upd(x); 59 } 60 61 P que(int x,int L,int R,int l,int r){ 62 if (!x) return (P){0,0}; 63 if (L==l && r==R) return v[x].p; 64 int mid=(L+R)>>1; 65 if (r<=mid) return que(lson,l,r); 66 if (l>mid) return que(rson,l,r); 67 return max(que(lson,l,mid),que(rson,mid+1,r)); 68 } 69 70 int get(int x,int k){ for (int i=20; ~i; i--) if (mx[f[x][i]]>=k) x=f[x][i]; return x; } 71 72 int main(){ 73 freopen("666E.in","r",stdin); 74 freopen("666E.out","w",stdout); 75 scanf("%s%d",s+1,&m); n=strlen(s+1); 76 rep(i,1,m){ 77 scanf("%s",ss+1); int len=strlen(ss+1),x=1; 78 rep(j,1,len) x=ext(x,ss[j]-'a'),mdf(rt[x],1,m,i); 79 } 80 int x=1,l=0; 81 rep(i,1,n){ 82 int c=s[i]-'a'; 83 while (x && !son[x][c]) x=fa[x],l=mx[x]; 84 if (son[x][c]) L[i]=++l,x=son[x][c],pos[i]=x; else x=1,L[i]=l=0; 85 } 86 int t=0; 87 rep(i,1,nd) c[mx[i]]++,t=max(t,mx[i]); 88 rep(i,1,t) c[i]+=c[i-1]; 89 for (int i=nd; i; i--) q[c[mx[i]]--]=i; 90 for (int i=nd; i; i--) rt[fa[q[i]]]=merge(rt[fa[q[i]]],rt[q[i]],1,m); 91 rt[1]=0; 92 rep(i,1,nd) f[i][0]=fa[i]; 93 rep(j,1,20) rep(i,1,nd) f[i][j]=f[f[i][j-1]][j-1]; 94 for (scanf("%d",&Q); Q--; ){ 95 scanf("%d%d%d%d",&l1,&r1,&l2,&r2); 96 if (L[r2]<r2-l2+1){ printf("%d 0 ",l1); continue; } 97 P t=que(rt[get(pos[r2],r2-l2+1)],1,m,l1,r1); 98 if (t.x) printf("%d %d ",t.id,t.x); else printf("%d 0 ",l1); 99 } 100 return 0; 101 }