1、什么是后缀自动机
详细知识参考以下$blog$:
https://www.cnblogs.com/zjp-shadow/p/9218214.html
2、后缀自动机构造
仍然参考上文$blog$
3、后缀自动机例题:
最长公共子串系列:
SPOJ LCS
题意:
求两个长度是$250000$的串的最长公共子串。
题解:
我们知道,后缀自动机从任意结点开始经转移函数到任意结点停止,形成的字符串都是这个字符串的子串。所以我们先把一个字符串建成$SAM$,然后对于另一个字符串。我们从第一个字符开始,在$SAM$上匹配。假设现在到达某个状态,匹配的长度是$len$,如果加上匹配串的下一个字符,其有对应的状态可以转移,则转移,同时$len+=1$,如果没有,我们只能舍弃一些前缀,所以沿$link$函数跳转,如果跳转到初始状态,则说明加上一个字符后,匹配串已经匹配的部分和原串没有公共子串,则$len=0$,如果没到,那么$len = longest[curst]+1$,同时转移,因为跳后缀的时候,匹配的部分一定是转移前这个状态的最长子串,再加上增加的一个字符。直到字符串匹配完。这里实际上是尺取的思想。
AC代码:
#include <bits/stdc++.h> using namespace std; const int N = 5e5 + 5, M = 26; struct SAM { int nxt[N][M], link[N], maxl[N]; int last, size, root; SAM() : last(1), size(1), root(1) {} void extend(int ch) { ch -= 'a'; int cur = (++size), p; maxl[cur] = maxl[last] + 1; for (p = last; p && !nxt[p][ch]; p = link[p]) nxt[p][ch] = cur; if (!p) link[cur] = root; else { int q = nxt[p][ch]; if (maxl[q] == maxl[p] + 1) link[cur] = q; else { int tmp = ++size; maxl[tmp] = maxl[p] + 1; for (int i = 0; i < M; ++i) nxt[tmp][i] = nxt[q][i]; link[tmp] = link[q]; for (; nxt[p][ch] == q; p = link[p]) nxt[p][ch] = tmp; link[cur] = link[q] = tmp; } } last = cur; } int lcs(char *s) { int ans = 0, now = root, len = 0; for (int i = 0; s[i]; ++i) { int ch = s[i] - 'a'; if (nxt[now][ch]) now = nxt[now][ch], ++len; else { while (now && !nxt[now][ch]) now = link[now]; if (!now) { len = 0; now = root; } else { len = maxl[now] + 1; now = nxt[now][ch]; } } ans = max(ans, len); } return ans; } }; SAM sam; char s1[N], s2[N]; int main() { scanf("%s%s", s1, s2); for (int i = 0; s1[i]; ++i) sam.extend(s1[i]); printf("%d ", sam.lcs(s2)); return 0; }
SPOJ LCS2
题意:
求最多$10$个子串的最长公共子串,单个串最长$1e5$。
题解:
上一题我们做了两个串。多个串的话,就需要保存一些信息了,首先我们需要知道一个性质,如果一个状态被匹配到,那么它的$parent树$的祖先一定会被匹配到。这样子,我们需要记录一下其余的模板串匹配到某个状态时的最大匹配长度,然后沿着$link$函数上推这个最小的这大匹配长度,显然这个是所有匹配串经过这个状态时的最大匹配长度,然后再对这些最小值取最大,即为所求。
问题来了,怎么上推?
对$link$函数反向建树即可,但是这样子实现的常数比较大,有没有更快的方法?
有的!
对这些状态基数排序!然后求出长度对应的排名。(这个排名就是后缀数组!)
然后我们上推的时候就直接排名从后往前上推即可(实际这就是$link$函数构成的$DAG$拓扑排序的过程)
AC代码:
#include <bits/stdc++.h> using namespace std; const int N = 2e5 + 10; const int M = 26; char s[N]; struct SAM { int nxt[N][M], link[N], maxl[N]; int last, root, size, n; SAM() : last(1), size(1), root(1) {} void extend(int ch) { ++n; ch -= 'a'; int cur = (++size), p; maxl[cur] = maxl[last] + 1; for (p = last; p && !nxt[p][ch]; p = link[p]) nxt[p][ch] = cur; if (!p) link[cur] = root; else { int q = nxt[p][ch]; if (maxl[q] == maxl[p] + 1) link[cur] = q; else { int tmp = ++size; maxl[tmp] = maxl[p] + 1; for (int i = 0; i < M; ++i) nxt[tmp][i] = nxt[q][i]; link[tmp] = link[q]; for (; nxt[p][ch] == q; p = link[p]) nxt[p][ch] = tmp; link[cur] = link[q] = tmp; } } last = cur; } int c[N], rnk[N]; void pre() { for (int i = 1; i <= size; ++i) ++c[maxl[i]]; for (int i = 1; i <= n; ++i) c[i] += c[i - 1]; for (int i = 1; i <= size; ++i) rnk[c[maxl[i]]--] = i; } int maxn[N], minn[N]; int work() { pre(); for (int i = 1; i <= size; ++i) minn[i] = maxl[i]; while (~scanf("%s", s)) { memset(maxn, 0, sizeof(maxn)); int now = root, len = 0; for (int i = 0; s[i]; ++i) { int ch = s[i] - 'a'; if (nxt[now][ch]) ++len; else { while (now && !nxt[now][ch]) now = link[now]; if (!now) len = 0; else len = maxl[now] + 1; } now = now ? nxt[now][ch] : root; maxn[now] = max(maxn[now], len); } for (int i = size; i; --i) { int x = rnk[i], fa = link[x]; maxn[fa] = max(maxn[fa], maxn[x]); } for (int i = 1; i <= size; ++i) minn[i] = min(minn[i], maxn[i]); } int ans = 0; for (int i = 1; i <= size; ++i) ans = max(ans, minn[i]); return ans; } }; SAM sam; int main() { scanf("%s", s); for (int i = 0; s[i]; ++i) sam.extend(s[i]); printf("%d ", sam.work()); return 0; }
最小(最大)循环串系列:
洛谷P1368 工艺
题意:
求一个环形串的断开位置,使得断开之后,这个串的字典序是所有断开位置中最小的。
题解:
这类题有一个特殊的解法,叫最小(最大)表示法,详见oi-wiki,但是我们今天就不用,用$SAM$搞定这个题。首先我们要构造出一个包含所有循环串的字符串,显然只要两个这样的字符串接在一起就可以了,然后我们构建$SAM$。然后贪心地转移即可,转移次数为串长。注:由于这里是数字,不是小写字母,所以就不能直接开数组,会$MLE$,同时要知道字典序最小的转移,使用$map$存状态即可。
AC代码:
#include <bits/stdc++.h> using namespace std; const int N = 1.2e6 + 5; struct SAM { map<int, int> nxt[N]; int link[N], maxl[N]; int last, size, root; SAM() : last(1), size(1), root(1) {} void extend(int val) { int cur = ++size, p; maxl[cur] = maxl[last] + 1; for (p = last; p && nxt[p].find(val) == nxt[p].end(); p = link[p]) nxt[p][val] = cur; if (!p) link[cur] = root; else { int q = nxt[p][val]; if (maxl[q] == maxl[p] + 1) link[cur] = q; else { int tmp = ++size; maxl[tmp] = maxl[p] + 1; for (auto i : nxt[q]) nxt[tmp].insert(i); link[tmp] = link[q]; for (;; p = link[p]) { auto it = nxt[p].find(val); if (it != nxt[p].end() && (*it).second == q) nxt[p].erase(it), nxt[p][val] = tmp; else break; } link[cur] = link[q] = tmp; } } last = cur; } void get_ans(int n) { int now = 1; for (int i = 1; i <= n; ++i) { printf("%d ", (*nxt[now].begin()).first); now = (*nxt[now].begin()).second; } printf(" "); } }; SAM sam; int s[N]; int main() { int n; scanf("%d", &n); for (int i = 1; i <= n; ++i) scanf("%d", &s[i]); for (int j = 1; j <= 2; ++j) for (int i = 1; i <= n; ++i) sam.extend(s[i]); sam.get_ans(n); return 0; }
子串数量和子串排名系列:
洛谷P3804【模板】后缀自动机
题意:
求出出现次数不为$1$的子串中的出现次数和子串长度的乘积的最大值。
题解:
子串的出现次数,就是其$endpos$集合的大小,怎么求呢?由于相同的$endpos$属于同一个状态,所以某个状态的子串的数量,就是这个状态中对应的子串的出现次数,且所有子串的出现次数是相同的。所以问题转化成求出某状态的$endpos$的大小。因为$parent$树中,祖先是某结点的后缀,且兄弟之间不构成后缀。所以兄弟的$endpos$交集是空集,所以满足加法。结果就是以这个结点为根的子树大小。求子树问题的,直接按照上文的方法求拓扑序上推即可。
$AC$代码是建出树然后$dfs$,不影响结果。
AC代码:
#include <bits/stdc++.h> using namespace std; typedef long long ll; const int N = 2e6 + 5; const int M = 26; vector<int> G[N]; int sz[N]; ll ans = 0; struct SAM { int maxlen[N], link[N], size, last; int nxt[N][M]; SAM() : size(1), last(1) {} void extend(int ch) { ch -= 'a'; int cur = (++size), p; sz[size] = 1; maxlen[cur] = maxlen[last] + 1; for (p = last; p && !nxt[p][ch]; p = link[p]) nxt[p][ch] = cur; if (!p) link[cur] = 1; else { int q = nxt[p][ch]; if (maxlen[q] == maxlen[p] + 1) link[cur] = q; else { int tmp = ++size; maxlen[tmp] = maxlen[p] + 1; for (int i = 0; i < M; ++i) nxt[tmp][i] = nxt[q][i]; link[tmp] = link[q]; for (; p && nxt[p][ch] == q; p = link[p]) nxt[p][ch] = tmp; link[cur] = link[q] = tmp; } } last = cur; } }; SAM sam; void dfs(int u) { for (auto i : G[u]) { dfs(i); sz[u] += sz[i]; } if (sz[u] != 1 && u > 1) ans = max(ans, 1ll * sz[u] * sam.maxlen[u]); } char s[N]; int main() { scanf("%s", s + 1); int len = strlen(s + 1); for (int i = 1; i <= len; ++i) sam.extend(s[i]); for (int i = 2; i <= sam.size; ++i) G[sam.link[i]].push_back(i); dfs(1); printf("%lld ", ans); return 0; }
$parent$树系列:
洛谷P4248 [AHOI2013]差异
题意:
求 $sum _{1 leq i < j leq n} len(T_i) + len(T_j) - 2 * lcp(T_i, T_j)$,其中$lcp$为最长公共前缀,$T_i$是原串以$i$为左端点的后缀。
题解:
我们由上题$SPOJ$ $LCS$可以知道,$SAM$可以求前缀的最长公共后缀。现在需要求后缀的最长公共前缀,则只需要将字符串反着建即可。两个后缀的$lcp$实际上就是$parent$树上对应的结点的$lca$,考虑这个题是$sum _{1 leq i<j leq n} lcp(T_i, T_j)$,且编号大的结点对应的后缀越长。所以,我们只需要从编号大的结点开始算贡献即可。它自己是$i$,它的已经被统计的结点为$j$,对于一个点,它的贡献就是$size[link[i]]*size[i]*longest[link[i]]$。($link[i]$实际上是$i$在$parent$树上的父亲节点,所以下文称为父亲)
如果这个父亲还没有被累加,其$size$可能是$0$或者$1$,我们只考虑$1$的。对于一个点,贡献的公式的意义就是$link[i]$这个点,和下面子树结点的$lcp$都是$link[i]$这个结点的$longest$所以可以这么写,然后把$size[i]$的影响累加到$size[link[i]]$之中。如果被累加了,则这个是已经被统计$j$的数量上推到$link[i]$中。显然这个$link[i]$就是$i$和$j$的$lcp$,数量就是$size[i]*size[link[i]]$。
AC代码:
#include <bits/stdc++.h> using namespace std; typedef long long ll; const int N = 1e6 + 5; const int M = 26; struct SAM { int nxt[N][M], link[N], maxl[N]; int last, root, size; int sz[N]; SAM() : last(1), root(1), size(1) {} void extend(int ch) { ch -= 'a'; int cur = ++size, p; maxl[cur] = maxl[last] + 1; sz[size] = 1; for (p = last; p && !nxt[p][ch]; p = link[p]) nxt[p][ch] = cur; if (!p) link[cur] = root; else { int q = nxt[p][ch]; if (maxl[q] == maxl[p] + 1) link[cur] = q; else { int tmp = ++size; maxl[tmp] = maxl[p] + 1; for (int i = 0; i < M; ++i) nxt[tmp][i] = nxt[q][i]; link[tmp] = link[q]; for (; nxt[p][ch] == q; p = link[p]) nxt[p][ch] = tmp; link[q] = link[cur] = tmp; } } last = cur; } int c[N], id[N]; void sort() //基数排序 { for (int i = 1; i <= size; ++i) ++c[maxl[i]]; for (int i = 1; i <= size; ++i) c[i] += c[i - 1]; for (int i = 1; i <= size; ++i) id[c[maxl[i]]--] = i; } ll get_ans() { sort(); sz[1] = 1; ll ans = 0; for (int i = size; i; --i) { int cur = id[i]; ans += 1ll * sz[link[cur]] * sz[cur] * maxl[link[cur]]; sz[link[cur]] += sz[cur]; } return ans; } }; SAM sam; char s[N]; int main() { scanf("%s", s); int len = strlen(s); for (int i = 0; i < len; ++i) sam.extend(s[i]); printf("%lld ", 1ll * len * (len - 1) * (len + 1) / 2 - 2 * sam.get_ans()); return 0; }