题意
给定 (n) 个模式串 (s_i) 和一个文本串 (t) ,求有多少个不同的模式串在文本串里出现过。
两个模式串不同当且仅当他们编号不同。
题解
AC 自动机,俗称在 Trie 上跑 KMP ,不能否定但也不能完全认同。
本篇题解并不详细,供己用,在板子里打了一点注释。
主要思路
AC自动机的精髓在于分情况记录的 (fail) 数组。由于 (fail) 数组中直接把不合法情况连向虚拟节点 (0) ,所以查找的时候不需要像 KMP 一样判断当前是否失配。
(fail) 的更新大概有两种情况,如下:
for(int i=0;i<26;i++)
{
if(trie[p][i]) fail[trie[p][i]]=trie[fail[p]][i],q.push(trie[p][i]);
//如果trie_p节点存在字符i的儿子, 当前点的fail可以用它父亲p的fail数组更新
else trie[p][i]=trie[fail[p]][i];
//如果trie_p节点不存在为字符i的儿子(失配),就直接连向fail[p]的字符i儿子
//如果fail[p]没有,那就连向了0
//这两条赋值语句后面的东西居然是一样的,是不是很神奇?
}
易错细节
(大概是初写板子的时候出的小坑)
- 容易忘记
q.push(trie[p][i]);
find
函数里忘记now=fail[now]
代码
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define inl inline
const int INF = 0x3f3f3f3f,N = 1e6+10;
inline ll read()
{
ll ret=0;char ch=' ',c=getchar();
while(!(c>='0'&&c<='9')) ch=c,c=getchar();
while(c>='0'&&c<='9') ret=(ret<<1)+(ret<<3)+c-'0',c=getchar();
return ch=='-'?-ret:ret;
}
int n,m;
//头一次封装写函数
struct AC
{
int trie[N][27],fail[N],tot=1;
int cnt[N]; bool vis[N];
queue<int> q;
inl void insert(char s[])
{
int p=1,len=strlen(s+1);
for(int i=1;i<=len;i++)
{
int ch=s[i]-'a';
if(!trie[p][ch]) trie[p][ch]=++tot;
p=trie[p][ch];
}
cnt[p]++;
}
inl void build()//BFS求出fail数组
{
for(int i=0;i<26;i++) trie[0][i]=1;//把位置0的儿子指向1(即0的儿子是根)
q.push(1);
while(!q.empty())
{
int p=q.front(); q.pop();
for(int i=0;i<26;i++)
{
if(trie[p][i]) fail[trie[p][i]]=trie[fail[p]][i],q.push(trie[p][i]);
//如果trie_p节点存在字符i的儿子, 当前点的fail可以用它父亲p的fail数组更新
else trie[p][i]=trie[fail[p]][i];
//如果trie_p节点不存在为字符i的儿子(失配),就直接连向fail[p]的字符i儿子
//如果fail[p]没有,那就连向了0
}
}
}
inl int find(char s[])
{
int p=1,len=strlen(s+1),ret=0;
for(int i=1;i<=len;i++)
{
int ch=s[i]-'a';
int now=trie[p][ch];
while(now)
{
if(vis[now]) break;
vis[now]=1;//每一个模式串只能计算一次
ret+=cnt[now];
now=fail[now];//不断跳fail数组,直到跳回0
}
p=trie[p][ch];
}
return ret;
}
inl void clear()
{
memset(trie,0,sizeof(trie));
memset(fail,0,sizeof(fail));
memset(cnt,0,sizeof(cnt));
memset(vis,0,sizeof(vis));
tot=1;
}
}ac;
char s[N];
int main()
{
int T=read();
while(T--)
{
ac.clear();
n=read();
for(int i=1;i<=n;i++)
{
scanf("%s",s+1);
ac.insert(s);
}
ac.build();
scanf("%s",s+1);
printf("%d
",ac.find(s));
}
return 0;
}