比赛链接:https://www.luogu.com.cn/contest/35817#problems
题目链接:https://www.luogu.com.cn/problem/U135768?contestId=35817
出题人题解链接:https://mivik.blog.luogu.org/mivik-string-open-contest-solution-book
思路
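求以给定串 s 结尾(即以 s 为后缀)的本质不同子串个数:在原串的(广义)后缀自动机上沿 s 走到状态 u,答案为 u 在 fail 树(后缀链接树)上的子树中各节点 maxlen[v] − maxlen[link[v]] 之和,但 u 自身只计长度不小于 |s| 的串,即贡献 maxlen[u] − |s| + 1;若 s 不是原串的子串,答案为 0。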
AC代码
#include <bits/stdc++.h>
using ll = long long;
using namespace std;
// State capacity of the automaton (a SAM over n characters needs fewer than 2n states).
constexpr int MAXN = 2000005;
// Alphabet size: characters are mapped with `c - 'a'`, i.e. lowercase 'a'..'z'.
constexpr int MAXC = 26;
class Suffix_Automaton {
public:
int rt, link[MAXN], maxlen[MAXN], trans[MAXN][MAXC];
int val[MAXN]; // 用于统计某一串出现的次数
void init() {
rt = 1;
link[1] = maxlen[1] = 0;
memset(trans[1], 0, sizeof(trans[1]));
}
Suffix_Automaton() { init(); }
inline int insert(int ch, int last) { // main: last = 1
if (trans[last][ch]) {
int p = last, x = trans[p][ch];
if (maxlen[p] + 1 == maxlen[x]) { // 特判1:这个节点已经存在于SAM中
val[x]++; // 统计在整颗字典树上出现次数
return x;
} else {
int y = ++rt;
maxlen[y] = maxlen[p] + 1;
for (int i = 0; i < MAXC; i++) trans[y][i] = trans[x][i];
while (p && trans[p][ch] == x) trans[p][ch] = y, p = link[p];
link[y] = link[x], link[x] = y;
val[y]++; // 统计在整颗字典树上出现次数
return y;
}
}
int z = ++rt, p = last;
val[z] = 1; // 统计在整颗字典树上出现次数
memset(trans[z], 0, sizeof(trans[z]));
maxlen[z] = maxlen[last] + 1;
while (p && !trans[p][ch]) trans[p][ch] = z, p = link[p];
if (!p) link[z] = 1;
else {
int x = trans[p][ch];
if (maxlen[p] + 1 == maxlen[x]) link[z] = x;
else {
int y = ++rt;
maxlen[y] = maxlen[p] + 1;
for (int i = 0; i < MAXC; i++) trans[y][i] = trans[x][i];
while (p && trans[p][ch] == x) trans[p][ch] = y, p = link[p];
link[y] = link[x], link[z] = link[x] = y;
}
}
return z;
}
ll sum[MAXN];
struct Edge {
int to, nex;
} e[MAXN << 1];
int head[MAXN], tol;
void addEdge(int u, int v) {
e[tol].to = v;
e[tol].nex = head[u];
head[u] = tol;
tol++;
}
void dfs(int u) {
for (int i = head[u]; ~i; i = e[i].nex) {
int v = e[i].to;
dfs(v);
sum[u] += sum[v];
}
sum[u] += maxlen[u] - maxlen[link[u]];
}
int build() {
tol = 0;
for (int i = 0; i <= rt; i++) head[i] = -1;
for (int i = 2; i <= rt; i++) addEdge(link[i], i); // 建fail树
dfs(1);
}
void debug(int u) {
for (int i = 0; i < 26; i++) {
if (trans[u][i]) {
printf("%d %c %d
", u, i + 'a', trans[u][i]);
debug(trans[u][i]);
}
}
}
ll qwq(char s[], int slen) {
int u = 1;
for (int i = 1; i <= slen; i++) {
int ch = s[i] - 'a';
if (!trans[u][ch])return 0;
u = trans[u][ch];
}
return sum[u] - (maxlen[u] - maxlen[link[u]]) + maxlen[u] - slen + 1;
}
} sa;
char str[MAXN], s[MAXN];
int main() {
scanf("%s", str + 1);
int len = strlen(str + 1);
sa.init();
int last = 1;
for (int i = 1; i <= len; i++) {
last = sa.insert(str[i] - 'a', last);
}
sa.build();
// sa.debug(1);
int T;
scanf("%d", &T);
while (T--) {
scanf("%s", s + 1);
int slen = strlen(s + 1);
printf("%lld
", max((ll) 0, sa.qwq(s, slen)));
}
}