模板1. 给出模式串和文本串,文本串长度小于1e6,模式串长度之和小于1e6,求文本串中有多少模式串出现。
题目链接:https://www.luogu.org/problem/P3808
AC code:
/* luoguP3808 (AC自动机模板题) 求文本串中有多少模式串出现 */ #include<cstdio> #include<queue> #include<cstring> #include<algorithm> #include<cstdlib> using namespace std; const int maxn=1e6+5; //key表示有多少个单词以该结点结尾 int fail[maxn],trie[maxn][30],key[maxn]; int n,cnt; char s[maxn]; //在Trie树中插入结点 void build(char *s){ int len=strlen(s),u=0; for(int i=0;i<len;++i){ int t=s[i]-'a'; if(!trie[u][t]){ ++cnt; memset(trie[cnt],0,sizeof(trie[cnt])); fail[cnt]=0,key[cnt]=0; trie[u][t]=cnt; } u=trie[u][t]; } key[u]+=1; } //构建fail指针 void get_fail(){ queue<int> que; for(int i=0;i<26;++i){ //提前处理第二层 if(trie[0][i]){ fail[trie[0][i]]=0; //指向根节点 que.push(trie[0][i]); } } while(!que.empty()){ //bfs求所有子结点 int u=que.front(); que.pop(); for(int i=0;i<26;++i){ if(trie[u][i]){ fail[trie[u][i]]=trie[fail[u]][i]; que.push(trie[u][i]); } else{ trie[u][i]=trie[fail[u]][i]; } } } } //AC自动机匹配 int query(char *s){ int len=strlen(s),u=0,ans=0; for(int i=0;i<len;++i){ int t=s[i]-'a'; u=trie[u][t]; for(int j=u;j&&key[j]!=-1;j=fail[j]){ //循环求解 ans+=key[j]; key[j]=-1; } } return ans; } int main(){ memset(trie[0],0,sizeof(trie[0])); key[0]=0; cnt=0; scanf("%d",&n); for(int i=1;i<=n;++i){ scanf("%s",s); build(s); } fail[0]=0; //结束标志 get_fail(); //求失配指针 scanf("%s",s); printf("%d ",query(s)); return 0; }
模板2. luoguP5357(二次增强),输出模式串在文本串中出现的次数(有相同的模式串)(拓扑排序优化AC自动机)
与增强版有两个改进地方。
第一:有相同的模式串,所以每个结点不能直接用key存储以该结点结束的模式串下标。这里给结点一个标记key[u],这个标记是以该结点为结束的模式串的最小下标key[u]=k,之后以该结点结束的其它字符串都记录此标记ind[k]=key[u],输出时按标记输出即可ans[ind[i]]。
第二:如果数据毒瘤,比如aaaaa...aaa,那么这样暴力去跳fail的最坏时间复杂度是O(模式串长度 · 文本串长度),因为对于每一次跳fail我们都只使深度减1,那样深度(深度最深是模式串长度)是多少,每一次跳的时间复杂度就是多少。那么还要乘上文本串长度,就几乎是 O(模式串长度 · 文本串长度)的了。
但是模板为什么复杂度是O(模式串总长)呢?因为每一个Trie上的点都只会经过一次(打了标记fail=-1),但这里每一个点就不止经过一次了,所以时间复杂度就爆炸了。
解决办法是应用拓扑排序,也就是对结点计数时先只是添加计数值res,并不暴力去跳fail,等query结束之后,利用拓扑排序去将计数值上传。详见代码(有注释)
AC code:
/* luoguP5357(二次增强) 输出模式串在文本串中出现的次数 拓扑排序优化AC自动机 */ #include<cstdio> #include<iostream> #include<cstring> #include<queue> #include<algorithm> #include<cstdlib> #include<string> using namespace std; const int maxn=2e6+5; char s[maxn]; //ind[i]记录模式串i所在结点的标记 int n,ans[200005],ind[200005],cnt=0; //key[u]记录结点u的标记,该标记可能对应多个模式串(可能有相同的模式串) //in[u]记录结点u的入度,res[u]记录结点u被加的次数 int fail[maxn],trie[maxn][30],key[maxn],in[maxn],res[maxn]; void build(char *s,int k){ int len=strlen(s),u=0; for(int i=0;i<len;++i){ int t=s[i]-'a'; if(!trie[u][t]){ ++cnt; memset(trie[cnt],0,sizeof(trie[cnt])); fail[cnt]=0,key[cnt]=0,in[cnt]=0,res[cnt]=0; trie[u][t]=cnt; } u=trie[u][t]; } //每个结点最多被标记一次 if(!key[u]) key[u]=k; ind[k]=key[u]; //模式串k所在结点的标记为key[u] } void get_fail(){ queue<int> que; for(int i=0;i<26;++i){ if(trie[0][i]){ fail[trie[0][i]]=0; que.push(trie[0][i]); } } while(!que.empty()){ int u=que.front();que.pop(); for(int i=0;i<26;++i){ if(trie[u][i]){ fail[trie[u][i]]=trie[fail[u]][i]; ++in[fail[trie[u][i]]]; //入度加一 que.push(trie[u][i]); } else{ trie[u][i]=trie[fail[u]][i]; } } } } void query(char *s){ int len=strlen(s),u=0; for(int i=0;i<len;++i){ int t=s[i]-'a'; u=trie[u][t]; ++res[u]; //不用循环查询,直接计数加的次数 } } void topu(){ //拓扑排序 queue<int> que; for(int i=1;i<=cnt;++i) if(!in[i]) que.push(i); //入队 while(!que.empty()){ int u=que.front();que.pop(); ans[key[u]]=res[u]; //将结点u的计数值加到ans数组上 int v=fail[u];--in[v]; //减入度 res[v]+=res[u]; //将计数值传递 if(!in[v]) que.push(v); } } int main(){ scanf("%d",&n); memset(trie[0],0,sizeof(trie[0])); cnt=0,key[0]=0,in[0]=0,res[0]=0; for(int i=1;i<=n;++i){ scanf("%s",s); build(s,i); } fail[0]=0; // 结束标志 get_fail(); //得到fail数组 scanf("%s",s); query(s); topu(); for(int i=1;i<=n;++i) //输出 printf("%d ",ans[ind[i]]); return 0; }