题目链接: https://www.luogu.com.cn/problem/P4770
SAM好题.
(I)首先我们考虑l = 1,r = |S|的情况怎么做
我们要求的是本质不同的子串str的数量,满足str是T的子串,且str不是$S_{l,r}$的子串
容易用补集转化成T本质不同的子串数减去S和T本质不同子串数
第一个问题很平凡,我们考虑第二个问题
我们对S,T分别建自动机,令T在S上面跑匹配,同时按着S的跑法在T自己上面跑匹配(因为T的每个子串都为$SAM_{T}$所接受,所以一定能跑)
对于每个前缀我们都可以求出它和S的最长公共后缀l,及在T上的节点,容易发现这个节点以上的长度<=l的都是本质不同的公共子串,因为可能算重所以先打标记然后Treedp统计(这也是为什么要在T上跑的原因,
因为在S上面跑,每次都要遍历S的parent tree时间复杂度不对)
(II)接下来才是难点,如果l,r任意怎么做
显然对于每个子串都建后缀自动机是不可能的,我们思考我们这个后缀自动机到底干了什么呢?
1.判断有没有tran(p,c)的转移边.
2.判断p这个节点的maxlen和minlen
我们可以发现,只要用线段树合并维护出endpos集合,就可以完成区间的上诉两个问题.
int u = get(sam[p].ch[c],l + len,r);
if(u){
len++;
p = sam[p].ch[c];
x = sam[x].ch[c];
}
else{
while(len != -1 && !get(sam[p].ch[c],l + len,r)){
len--;
if(len == sam[sam[p].fa].len) p = sam[p].fa;
}
其中get(p,l,r)表示p这个节点的endpos集合在[l,r]范围内的最大值
设正在匹配的最长公共子串为s
我们发现我们原本要做的事情是判断s在p这个节点上能不能添上'c'这个字符,即判断 if(sam[p].ch[c] != 0),但是因为有区间限制我们应判断是否存在一个位置x可以接上s+'c',即在[l,r]区间内,是否存在一个endpos(x)满足x - len(s+'c') + 1>= l,即x >= l + len(s+'c') - 1也即x >= len(s) + l,于是只要判断[l+len,r]区间内是否存在endpos集合的元素即可
注意我们若失配此时不应该直接跳fa,而应该先让len自减,要记住这个后缀自动机只是一个框架,是$S_{1,n}$而不是$S_{l,r}$的SAM.
有人可能会问:怎么暴力while怎么能过?
因为数据水? 其实这个时间复杂度是正确的,我们考虑势能分析法,容易发现每次while,len最多减少1,外面for循环每次最多增加1,所以单次匹配时间复杂度是O(|T|logn)的
有很多细节,看代码吧
/*NOI2018[你的名字]*/
#include<bits/stdc++.h>
using namespace std;
#define ll long long
int read(){
char c = getchar();
int x = 0;
while(c < '0' || c > '9') c = getchar();
while(c >= '0' && c <= '9') x = x * 10 + c - 48,c = getchar();
return x;
}
const int N = 2e6 + 10;
struct SegmentTree{
int lc,rc;
int mx;
}t[N<<4];/*线段树维护endpos集合*/
int Rt[N],num,n;
void pushup(int p){
t[p].mx = max(t[t[p].lc].mx,t[t[p].rc].mx);
}
void Insert(int &p,int l,int r,int pos){
if(!p) p = ++num;
if(l == r){
t[p].mx = max(t[p].mx,pos);
return;
}
int mid = (l + r) >> 1;
if(pos <= mid) Insert(t[p].lc,l,mid,pos);
else Insert(t[p].rc,mid+1,r,pos);
pushup(p);
}
int merge(int p,int q,int l,int r){
if(!p || !q) return p | q;
int u = ++num;
int mid = (l + r) >> 1;
t[u].lc = merge(t[p].lc,t[q].lc,l,mid);
t[u].rc = merge(t[p].rc,t[q].rc,mid + 1,r);
pushup(u);
return u;
}
int query(int p,int l,int r,int a,int b){
if(a <= l && b >= r) return t[p].mx;
int mid = (l + r) >> 1;
int ans = 0;
if(a <= mid) ans = max(ans,query(t[p].lc,l,mid,a,b));
if(b > mid) ans = max(ans,query(t[p].rc,mid+1,r,a,b));
return ans;
}
struct SAM{
int ch[26],len,fa;
}sam[N<<1];
int lst = 1,cnt = 1;
void ins(int c,int rt){
int p = lst,np = ++cnt;lst = np;
sam[np].len = sam[p].len + 1;
for(; !sam[p].ch[c]; p = sam[p].fa) sam[p].ch[c] = np;
if(!p) sam[np].fa = rt;
else{
int q = sam[p].ch[c];
if(sam[q].len == sam[p].len + 1) sam[np].fa = q;
else{
int nq = ++cnt;
sam[nq] = sam[q];
sam[nq].len = sam[p].len + 1;
sam[np].fa = sam[q].fa = nq;
for(; sam[p].ch[c] == q; p = sam[p].fa) sam[p].ch[c] = nq;
}
}
}
int head[N<<1];
int f[N<<1],tot;
struct Edge{
int nxt,point;
}edge[N<<1];
void add_edge(int u,int v){
edge[++tot].nxt = head[u];
edge[tot].point = v;
head[u] = tot;
}
char S[N],T[N];
void dfs(int u){
for(int i = head[u]; i ; i = edge[i].nxt){
int v = edge[i].point;
dfs(v);
f[u] = max(f[u],f[v]);
}
f[u] = min(f[u],sam[u].len);
}
void getpos(int u){
for(int i = head[u]; i ; i = edge[i].nxt){
int v = edge[i].point;
getpos(v);
Rt[u] = merge(Rt[u],Rt[v],1,n);
}
}
bool valid(int u,int len){
return len >= sam[sam[u].fa].len + 1 && len <= sam[u].len;
}
int get(int u,int l,int r){
if(l > r || !u) return 0;
return query(Rt[u],1,n,l,r);
}
int getlen(int u,int l,int r){
int x = get(u,l,r);
return min(sam[u].len,x - l + 1);
}
ll work(char *s,int rt,int l,int r){
int m = strlen(s+1);
int p = 1,len = 0,x = rt;
for(int i = rt + 1; i <= cnt; ++i){
add_edge(sam[i].fa,i);
}
for(int i = 1; i <= m; ++i){
int c = s[i] - 'a';
int u = get(sam[p].ch[c],l + len,r);
if(u){
len++;
p = sam[p].ch[c];
x = sam[x].ch[c];
}
else{
while(len != -1 && !get(sam[p].ch[c],l + len,r)){
len--;
if(len == sam[sam[p].fa].len) p = sam[p].fa;
}
if(len == -1){
p = 1;
len = 0;
x = rt;
}
else{
len++;
p = sam[p].ch[c];
while((!sam[x].ch[c] || !valid(sam[x].ch[c],len)) && x) x = sam[x].fa;
if(!x) x = rt;
x = sam[x].ch[c];
}
}
// cout<<i<<' '<<len<<endl;
f[x] = max(f[x],len);
}
dfs(rt);
ll ans = 0;
for(int i = rt + 1; i <= cnt; ++i){/*!!!attention*/
if(f[i] > sam[sam[i].fa].len){
// assert(f[i] > sam[sam[i].fa].len);
ans += f[i] - sam[sam[i].fa].len;
}
}
for(int i = rt; i <= cnt; ++i) f[i] = 0;
return ans;
}
int main(){
freopen("name.in","r",stdin);
freopen("name.out","w",stdout);
scanf("%s",S+1);
n = strlen(S+1);
for(int i = 1; i <= n; ++i){
ins(S[i]-'a',1);
Insert(Rt[lst],1,n,i);
}
for(int i = 2; i <= cnt; ++i){
add_edge(sam[i].fa,i);
}
getpos(1);
int q = read();
while(q--){
scanf("%s",T+1);
int l = read(),r = read();
int m = strlen(T+1);
int rt = ++cnt;
lst = rt;
for(int i = 1; i <= m; ++i){
ins(T[i]-'a',rt);
}
ll ans = 0;
for(int i = rt + 1; i <= cnt; ++i){
ans += sam[i].len - sam[sam[i].fa].len;
}
ans -= work(T,rt,l,r);
printf("%lld
",ans);
}
return 0;
}