AC自动机学习笔记
前言
由于太菜了, 不知道怎么解释自动机到底是个什么东西
所以大佬轻 D
正篇开始
本文中字符串下标从 1 开始
我们知道 KMP 是用来解决一个模式串在一个母串上匹配的问题的
那多个模式串在一个母串上匹配怎么做呢?
可以考虑 AC自动机
用这些模式串先构出一棵 Trie 树, 然后再在 Trie 树上跑类似于 KMP 的东西
我们可以考虑把一些标记打在 Trie 树上的点来满足题目的询问
好的, 那么我们现在有了一棵由模式串构成的 Trie 树
这里先要引入一个概念, fail 指针
何为 fail 指针
我们假设当前匹配到的是母串中的 (S[i - j]) , 也就是母串中 (i) 到 (j) 的一个子串
然后跟这个串匹配的是模式串 (k_1) 的 (C[1 - (j - i + 1)])
如果 (S[j + 1]) 和 (C[j - i + 2]) 相等, 我们就可以继续匹配
那如果不匹配呢?
如果我们找到一个其他的模式串 (k_2), 满足这个模式串的前缀和 (k_1) 的一个真后缀相等, 并且在所有的满足这个条件的 (k_i) 中, 这个相等的子串的长度最长
即 (k_2) 的 (C[1 - p_1]) 和 (k_1) 的 (C[(j - i + 2 - p_1) - (j - i + 1)]) 相等, 并且找不到一个 (k_3) 满足 (k_3) 中最大的匹配位置 (p_2 > p_1)
那么 (k_1) 的(C[p_1]) 在 Trie 树中的位置, 就是 (k_1) 的 (C[j - i + 1]) 这个点的 fail 指向的点
那我们为什么要引入 fail 指针的一个概念呢?
我们知道, 假如说我们在这个点无法匹配时, 看在这个点的 fail , 能不能匹配
又由于没有比这个 fail 能够匹配更长最长真后缀的, 于是在你跳 fail 的时候不会漏掉可能合法的答案
如果你有一个字符串的前缀能够匹配这个字符, 那这个前缀对应的这个点, 一定可以通过不断跳当前点的 fail 得到
我们可以将跳 fail 的过程看作不断选择这个模式串的后缀, 砍掉前面最短的一部分, 满足其他的在这棵 Trie 树中存在
是不是有一点像 KMP 的不断跳 (nxt)
给段代码理解一下
void get_fail()
{
for(int i = 0; i < 4; i++)
if(t[0].ch[i]) t[t[0].ch[i]].fail = 0, q.push(t[0].ch[i]);
//我们要先排除一个点的 fail 是他的父亲的情况, 由于你 fail 指的是一个最长真后缀, 你当前只有一个字符, 至少砍掉一个不就是没了
while(!q.empty())
{
int u = q.front(); q.pop();
for(int v, i = 0; i < 4; i++)
{
v = t[u].ch[i];
if(v) t[v].fail = t[t[u].fail].ch[i], q.push(v); //存在 u 这个儿子 v , 那这个儿子 v 的 fail 指向的, 就是 u 的 fail 对应的这个点的对应的这个儿子, 可以看作在这两个字符串后面同时加了一个字符
else t[u].ch[i] = t[t[u].fail].ch[i]; //不存在这个儿子 v , 就把 u 的这个儿子指向 u 的 fail 的这个儿子, 这一步并不会影响到 fail 的正确性, 相当于是我们匹配不到最长的就匹配次长的, 然后这个 u 的儿子指到 fail 的儿子相当于你先跳到 fail , 然后跳到 fail 的这个儿子, 事实上是两步, 我们把它并成一步
}
}
}
然后我们就把这个 AC自动机构出来了
那么接下来就直接拿母串在 AC自动机上匹配即可
那么如何维护题目所求的东西呢
我们知道到 u 点可以匹配, 那到 u 的 fail 也必然可以匹配
假如说我们要维护的是是否存在, 那么我们可以不断跳 fail 不断标记, 直到当前这个 fail 被标记过了就可以停止了
这样每个点只会被更新一次, 复杂度是对的
但如果我们要维护的是匹配次数呢
这样不断跳 fail , 必须要跳到底
如果是 (aaaaaaaaaaaaaaaaaaaaaaa) 这种串, 复杂度就不对了
那要怎么做呢
考虑每一次每个点的 fail , 只有一个, 假如说我们把每个点和他的 fail 看作树中的一对父子关系
那么每次跳 fail 可以看作在这棵 fail 树中, 从这个点不断跳他的父亲, 直到跳到根
那么对于一个点, 他被更新的次数是不是就是他子树中所有的点的总次数呢?
所以我们就可以, 在母串匹配的时候只在当前点打上标记, 然后把 fail 树建出来, 树形 DP 一下就行了
差不多就到这里吧, 还有什么东西之后再补充
贴个完整代码吧
Code1
这个是不停的跳 fail 的代码
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <queue>
const int N = 1e6 + 5;
using namespace std;
int n, cnt;
char s[N];
struct node { int fail, cnt, ch[26]; } t[N];
queue<int> q;
template < typename T >
inline T read()
{
T x = 0, w = 1; char c = getchar();
while(c < '0' || c > '9') { if(c == '-') w = -1; c = getchar(); }
while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return x * w;
}
void build()
{
int len = strlen(s + 1);
int u = 0;
for(int tmp, i = 1; i <= len; i++)
{
tmp = s[i] - 'a';
if(!t[u].ch[tmp]) t[u].ch[tmp] = ++cnt;
u = t[u].ch[tmp];
}
t[u].cnt++;
}
void get_fail()
{
for(int i = 0; i < 26; i++)
if(t[0].ch[i]) t[t[0].ch[i]].fail = 0, q.push(t[0].ch[i]);
while(!q.empty())
{
int u = q.front(); q.pop();
for(int v, i = 0; i < 26; i++)
{
v = t[u].ch[i];
if(v)
{
t[v].fail = t[t[u].fail].ch[i];
q.push(v);
}
else t[u].ch[i] = t[t[u].fail].ch[i];
}
}
}
int solve()
{
int u = 0, len = strlen(s + 1), ans = 0;
for(int v, tmp, i = 1; i <= len; i++)
{
v = u = t[u].ch[tmp = s[i] - 'a'];
while(v && t[v].cnt != -1)
{
ans += t[v].cnt;
t[v].cnt = -1;
v = t[v].fail;
}
}
return ans;
}
int main()
{
n = read <int> ();
for(int i = 1; i <= n; i++)
{
scanf("%s", s + 1);
build();
}
scanf("%s", s + 1);
get_fail();
printf("%d
", solve());
return 0;
}
Code2
这是构 fail 树的代码
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <queue>
const int N = 2e6 + 5;
using namespace std;
int n, cnt, cnte, sz[N], head[N], match[N];
char s[N];
struct node { int fail, ch[26]; } t[N];
struct edge { int to, nxt; } e[N];
queue<int> q;
template < typename T >
inline T read()
{
T x = 0, w = 1; char c = getchar();
while(c < '0' || c > '9') { if(c == '-') w = -1; c = getchar(); }
while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return x * w;
}
inline void adde(int u, int v) { e[++cnte] = (edge) { v, head[u] }, head[u] = cnte; }
void get_fail()
{
for(int i = 0; i < 26; i++)
if(t[0].ch[i]) t[t[0].ch[i]].fail = 0, q.push(t[0].ch[i]);
while(!q.empty())
{
int u = q.front(); q.pop();
for(int v, i = 0; i < 26; i++)
{
if(v = t[u].ch[i])
t[v].fail = t[t[u].fail].ch[i], q.push(v);
else t[u].ch[i] = t[t[u].fail].ch[i];
}
}
}
void solve()
{
int u = 0, len = strlen(s + 1);
for(int tmp, i = 1; i <= len; i++)
{
u = t[u].ch[tmp = s[i] - 'a'];
sz[u]++;
}
}
void dfs(int u)
{
for(int v, i = head[u]; i; i = e[i].nxt)
dfs(v = e[i].to), sz[u] += sz[v];
}
int main()
{
n = read <int> ();
for(int u = 0, len, tmp, i = 1; i <= n; i++, u = 0)
{
scanf("%s", s + 1), len = strlen(s + 1);
for(int j = 1; j <= len; j++)
{
if(!t[u].ch[tmp = s[j] - 'a']) t[u].ch[tmp] = ++cnt;
u = t[u].ch[tmp];
}
match[i] = u;
}
scanf("%s", s + 1);
get_fail(), solve();
for(int i = 1; i <= cnt; i++)
adde(t[i].fail, i);
dfs(0);
for(int i = 1; i <= n; i++)
printf("%d
", sz[match[i]]);
return 0;
}