[NOI2018] 你的名字
该来的总归还是会来的……
题意:
一句话题意:给出一个字符串(T),(Q)组询问,每次询问字符串(S)中有多少个本质不同的子串在(T[l..r])中没有出现。(我们用(T[l..r])表示截取(T)中([l, r])区间内的字符得到的字符串)
题解:
真的是对(SAM)的科技一无所知……
首先,做出这道题需要对(SAM)了解的足够深。我们先从部分分开始分析,假设(l = 1, r = | T |),即每次都查询的是整个(T),我们该怎么做呢?考虑每个(S)中的每一个右端点(i),都会存在一个(l_i)作为一个分界线,即(forall j < l_i,)都有(S[j..i])不是(T)的子串,(forall j geq i,)都有(S[j..i])是(T)的子串,这个还是比较明显的。并且容易发现的是,这个(l_i)是非递减的。于是我们就可以用(two pointers)来解决这个问题。
我们记(L)表示(l_{i - 1}),(len)表示当前匹配串长,(o)表示当前区间子串(S[L..i - 1])在(T)的(SAM)中的位置,当(i)增大一时,我们尝试继续在(SAM)中转移,如果(SAM)中存在这条转移边,那么我们就把(o)移向这个转移节点,然后让匹配串长加一。否则我们不断的让匹配串长减一,即让(L)右移,然后继续判断(SAM)上是否存在转移边,直至(len = 0)或者存在转移边为止。注意当(len = t[t[o].fa].len)时,即当前串对应的节点已经是(o)的父亲节点时,需要移动(o)至父亲节点。这样我们就可以(O(n))的求出每个点的(l_i)了。但是如果直接用(l_i)计算答案,会发现一些相同的子串会被重复计算贡献。
考虑如何去重,我们建出了(SAM)之后,尝试用(SAM)进行去重,我们记在(SAM)上的每个节点多记一个(id),表示其右端点在(S)中的位置,那么由于(SAM)中的节点所表示的子串都是本质不同的,那么我们枚举(SAM)上的节点计算答案是不会重复计算贡献的。记(SAM)的节点个数为(tot),答案就是(sum_{i = 1} ^ { tot } max(0, t[i].len - max(t[t[i].fa].len, l[t[i].id])))。这个还是比较好理解的,由于(SAM)中的一个节点里,所有串都是该节点长度最长的串的一个后缀,那么根据我们对于(l_i)的定义,该节点中长度小于等于(l[t[i].id])的串都是(T)的子串,所以能够造成贡献的串就是长度在(l[t[i].id] + 1)到(t[i].len)之间的串,注意和父亲节点的(len)取个(max)。这样我们就可以得到(68)的好成绩。
接下来考虑如何处理(l eq 1, r eq | T |)的情况,分析一下,唯一不同的就是处理(l_i)的时候会出现问题。我们之前是如果存在转移边,那么就往转移边的方向走。但是实际情况是,这个转移边并不属于([l..r])这个区间的,也就是说这个转移实际上是不合法的。所以我们需要处理这一类不合法的转移。实际上只需要维护(SAM)的(right)集合即可,我们用线段树来记录一个节点的(right)集合。由于一个节点的(right)集合是其子节点的并,所以在建完(SAM)之后,用线段树合并将子节点的(right)集合合并至父亲节点的就行了。那么我们在存在转移边时,判断一下转移至的节点是的(right)集合是否有值在([l + len, r])这个区间内即可,其余的照做就行了。
Code:
#include <bits/stdc++.h>
using namespace std;
const int N = 1e6 + 50;
typedef long long ll;
int n, m, q;
char s[N], ss[N];
namespace Seg {
int tot;
int ls[N * 30], rs[N * 30];
void Insert(int &o, int l, int r, int p) {
if(!o) o = ++tot;
if(l == r) return ;
int mid = (l + r) >> 1;
if(mid >= p) Insert(ls[o], l, mid, p);
else Insert(rs[o], mid + 1, r, p);
}
int Merge(int x, int y) {
if(!x || !y) return x | y;
int o = ++tot;
ls[o] = Merge(ls[x], ls[y]);
rs[o] = Merge(rs[x], rs[y]);
return o;
}
int Query(int o, int l, int r, int L, int R) {
if(!o) return 0;
if(L <= l && r <= R) return 1;
int mid = (l + r) >> 1;
if(mid >= L && Query(ls[o], l, mid, L, R)) return 1;
if(mid < R && Query(rs[o], mid + 1, r, L, R)) return 1;
return 0;
}
}
namespace SAM {
struct node {
int ch[27];
int len, fa, rt, is;
}t[N << 1];
int tot = 1, las = 1;
int id[N << 1], sa[N << 1];
void Insert(int c) {
int p = las, np = ++tot;
t[np].len = t[p].len + 1; t[np].is = 1;
while(p && !t[p].ch[c]) t[p].ch[c] = np, p = t[p].fa;
if(!p) t[np].fa = 1;
else {
int q = t[p].ch[c];
if(t[q].len == t[p].len + 1) t[np].fa = q;
else {
int nq = ++tot;
t[nq] = t[q];
t[nq].len = t[p].len + 1; t[nq].is = 0;
t[q].fa = t[np].fa = nq;
while(p && t[p].ch[c] == q) t[p].ch[c] = nq, p = t[p].fa;
}
}
las = np;
}
void Build() {
for(int i = 1; i <= n; i++) Insert(s[i] - 'a');
for(int i = 1; i <= tot; i++) id[t[i].len] ++;
for(int i = 1; i <= n; i++) id[i] += id[i - 1];
for(int i = tot; i > 1; i--) sa[id[t[i].len] --] = i;
for(int i = tot; i > 1; i--) {
int p = sa[i];
if(t[p].is) Seg::Insert(t[p].rt, 1, n, t[p].len);
t[t[p].fa].rt = Seg::Merge(t[t[p].fa].rt, t[p].rt);
}
}
}
namespace Solver {
struct node {
int ch[27];
int fa, len, id;
}t[N << 1];
int tot = 1, las = 1;
int l[N << 1];
void Clear() {
for(int i = 0; i <= tot; i++) memset(t[i].ch, 0, sizeof t[i].ch), t[i].fa = t[i].len = t[i].id = 0;
tot = las = 1;
}
void Insert(int c) {
int p = las, np = ++tot;
t[np].len = t[p].len + 1; t[np].id = t[np].len;
while(p && !t[p].ch[c]) t[p].ch[c] = np, p = t[p].fa;
if(!p) t[np].fa = 1;
else {
int q = t[p].ch[c];
if(t[q].len == t[p].len + 1) t[np].fa = q;
else {
int nq = ++tot;
t[nq] = t[q];
t[nq].len = t[p].len + 1;
t[q].fa = t[np].fa = nq;
while(p && t[p].ch[c] == q) t[p].ch[c] = nq, p = t[p].fa;
}
}
las = np;
}
void main() {
Clear();
int L, R;
scanf("%s", ss + 1);
scanf("%d%d", &L, &R);
m = strlen(ss + 1);
int len = 0, o = 1;
for(int i = 1; i <= m; i++) {
int c = ss[i] - 'a';
Insert(c);
while(1) {
if(SAM::t[o].ch[c] && Seg::Query(SAM::t[SAM::t[o].ch[c]].rt, 1, n, L + len, R)) {
o = SAM::t[o].ch[c];
len ++;
break;
}
if(len == 0) break;
len--;
if(len == SAM::t[SAM::t[o].fa].len) o = SAM::t[o].fa;
}
l[i] = len;
}
ll ans = 0;
for(int i = 2; i <= tot; i++) {
ans += max(0, t[i].len - max(t[t[i].fa].len, l[t[i].id]));
}
printf("%lld
", ans);
}
}
int main() {
scanf("%s", s + 1);
n = strlen(s + 1);
SAM::Build();
scanf("%d", &q);
while(q--) Solver::main();
return 0;
}