题目链接:HDU - 2896
题意:给你n个模式串,对应每个模式串由编号,给出m个文本串,然后你要输出对应所匹配出模式串序号,以及有多少个文本串中有模式串
思路:比起比较基本统计个数,这里我们可以用set或者map来统计模式串序号
这里总结一下新学习的last优化。 参考博客:(1)
last相当于一个超级fail指针:
因为我们只有到根节点时才会重新匹配一个字母,所以我们此时直接记录一个last ,直接结束当前匹配过程.直接省去原 Fail 指针到可以匹配的节点之间的距离.
同时结合路径压缩,在匹配时可以完全不使用原 Fail.可以看下参考博客里面的图片。
code: 详细注解
#include<bits/stdc++.h> const int N=100000+5; using namespace std; struct AC_automaton{ int trie[N][128];//字典树 int val[N];//字符串结尾标记 int fail[N];//失配指针 int last[N];//last[i]=j表j节点表示的单词是i节点单词的后缀,且j节点是单词节点(last优化) int tot;//编号 void init(){//初始化0号点 tot=1; val[0]=fail[0]=last[0]=0; memset(trie[0],0,sizeof(trie[0])); } void insert(char *s,int v){//构造trie与val数组,v需非0,表示一个单词节点 int len=strlen(s); int root=0; for(int i=0;i<len;i++){ int id=s[i]; if(trie[root][id]==0){ trie[root][id]=tot; memset(trie[tot],0,sizeof(trie[tot])); val[tot++]=0; } root=trie[root][id]; } val[root]=v; //编号 //val[root]++; 个数 } void build(){//构造fail与last queue<int> q; last[0]=fail[0]=0; //先把第0个部分放进去 for(int i=0;i<128;i++){ int root=trie[0][i]; if(root!=0){ //初始化 fail[root]=0; last[root]=0; q.push(root); } } while(!q.empty()){//bfs求fail int k=q.front(); q.pop(); //ASCII编码范围 for(int i=0;i<128;i++){ int u=trie[k][i];//被取出结点k的子结点 if(u==0) continue; q.push(u); int v=fail[k];//k位置的失配指针 //把子节点改成fail节点的子节点形成一个Trie图 while(v && trie[v][i]==0) v=fail[v]; fail[u]=trie[v][i];//得到其儿子的失配结点 //last指针表示“在它顶上的fail边所指向的一串节点中,第一个真正的结束节点” last[u]=val[fail[u]]?fail[u]:last[fail[u]]; } } } void print(int i,set<int> &st){//递归找到存在结点i后缀相同的前缀节点编号 if(val[i]){ if( st.find(i)==st.end() ) st.insert(val[i]); print(last[i],st); } } void query(char *s,set<int> &st){//匹配 int len=strlen(s); int j=0; for(int i=0;i<len;i++){ int id=s[i]; while(j && trie[j][id]==0) j=fail[j]; j=trie[j][id]; if(val[j]) print(j,st); else if(last[j]) print(last[j],st); } } }ac; char P[N]; char T[N]; set<int>::iterator it; int main(){ int n; scanf("%d",&n); ac.init(); for(int i=1;i<=n;i++){ scanf("%s",P); ac.insert(P,i); } ac.build(); int m; scanf("%d",&m); int total=0; for(int i=1;i<=m;i++){ scanf("%s",&T); set<int> st;//保存文本串已经匹配到的模式串编号 ac.query(T,st); if(!st.empty()){ total++; printf("web %d:",i); for(set<int>::iterator it=st.begin();it!=st.end();it++) printf(" %d",(*it)); printf("% "); } } printf("total: %d ",total); return 0; }