题面
https://www.luogu.com.cn/problem/P4770
题解
前置知识:
题目多次给定T,询问字符串T中,有多少个不同的子串不与S[l..r]的任何一个子串相等。
首先建出S的后缀自动机,并预处理出fail树上的(2^k)代祖先。
对于每次询问,如果我们将T在S的后缀自动机上跑,就可以求出(T[1..1],T[1..2],…,T[1..|T|])在S中的最长匹配长度(lcp[1],lcp[2],…,lcp[|T|]),以及该匹配对应的SAM节点(loc[1],loc[2],…loc[|T|])。然后求出T的“前缀数组”,枚举题目所求的字符串的右端点(cur_r)。
考虑T中,以(cur_r)为右端点且不与(S[l..r])的任何一个子串相等的子串有哪些。发现它们的左端点一定是从1开始连续的一段(若(T[lim..cur_r])不合法,那么(T[lim+1..cur_r],T[lim+2..cur_r]…T[cur_r..cur_r])都不符合条件)。先不考虑细节((T[1..cur_r])到(T[cur_r..cur_r])全都合法以及全都不合法),一定存在(lim)使得(T[lim+1..cur_r])不合法而(T[lim..cur_r])合法,那么我们可以二分查找lim。
先不谈怎么二分查找,先以(T[cur_r-lcp[cur_r]+1..cur_r])为例,说一下怎么判断合法性。我们知道字符串S的SAM的fail树就是(S^R)的后缀树。以下只画出SAM的fail树。假设(S[1..cur_r])在SAM上所能匹配到的最长部分即(T[cur_r-lcp[cur_r]+1..cur_r])是图中橙色部分,蓝色点是(loc[cur_r])。
如果(T[cur_r-lcp[cur_r]+1..curr])不合法,假设它在S中的(S[l_0..r_0](l{leq}l_0{leq}r_0{leq}r,r_0-l_0+1=lcp[cur_r]))出现了。那么(S[1..r_0])对应的节点(粉色箭头的开头)一定在蓝色点的子树中(因为(S[1..r_0])是(S[l_0..r_0])的前缀)。
因此可以看出,(T[cur_r-lcp+1..curr])不合法的充要条件就是,在(loc)的子树内存在一个点(r_0),使得(l{leq}r_0-lcp+1{leq}r_0{leq}r)。移项得到(l+lcp-1{leq}r_0{leq}r)。这就成了一个在线二维数点问题,可以用主席树解决。
二分查找就基于这一算法。设置两个变量,u和len。初始时,将(u)赋值为(loc[cur_r]),(len)赋值为(lcp[cur_r])。接下来,由于我们已预处理了此树上每一个点的(2^k)代祖先,可以倍增地尝试:
for(int i = 20;i >= 0;i--){
if(fail[u][i] == -1)continue;
if(legal(fail[u][i],S.len[fail[u][i]]))u = fail[u][i],len = S.len[u];
}
- 这里的S.len[u]表示的是SAM节点u表示的字符串中最长的那个的长度,建SAM时可以一并算出的。
- legal是判断是否合法的函数
由此算出的u就满足,我们要找的lim(还记得lim是什么吗www)一定在u到(fail[u][0])的线段上(u含,(fail[u][0])不含)
接下来再进行一次二分查找,这次找的是lim具体在u到(fail[u][0])的线段上的哪个位置。
int L = S.len[fail[u][0]] + 1,R = len;
while(L < R){
int mid = (L + R) >> 1;
if(legal(u,mid))R = mid;
else L = mid + 1;
}
最后(cur_r-L+1)就是我们要找的lim啦。
注意要对所有(1{leq}cur_r{leq}len_t)的(cur_r)都求一遍lim,所以处理一个询问的时间是(O(|T| log |T|))
统计答案时还需要注意一个地方:我们求的答案是不能重复的。这和用SA求不同子串个数是一样的,height之后的部分是重复的,不能算进去。
for(int i = 1;i <= len_t;i++){
int cur_r = sa[i];//其实应该叫pa,因为是前缀数组www
//ans += lim[cur_r];-------------------------------- wrong!!!
ans += min(lim[cur_r],curr - h[i]);//--------------- right √
}
还有一些细节见代码吧。
总时间复杂度(O(|S| log |S|+sum|T| log |T| ))。
代码
P.S.对于这类(sum|T|=1e6)之类的题目,每次询问千万不能随手memset到底,时间复杂度立刻就不对了
#include<bits/stdc++.h>
using namespace std;
#define rg register
#define ll long long
#define In inline
const int LT = 1e6;
const int LS = 5e5;
const int TN = 2 * LS * 20;
In int read(){
int s = 0,ww = 1;
char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-')ww = -1;ch = getchar();}
while('0' <= ch && ch <= '9'){s = 10 * s + ch - '0';ch = getchar();}
return s * ww;
}
In void write(ll x){
if(x < 0)putchar('-'),x = -x;
if(x > 9)write(x / 10);
putchar('0' + x % 10);
}
int ls;
struct CMTree{ //按dfn排序,按编号询问的主席树
int rt[2*LS+5],c[TN+5][2],num[TN+5];
int cnt,rn;
In void pushup(int u){
num[u] = num[c[u][0]] + num[c[u][1]];
}
void ud(int u1,int u2,int l,int r,int x){
if(l == r){
num[u2]++;
return;
}
int m = (l + r) >> 1;
if(x <= m){
int v = ++cnt;
num[v] = num[c[u1][0]];
c[u2][0] = v,c[u2][1] = c[u1][1];
ud(c[u1][0],c[u2][0],l,m,x);
}
else{
int v = ++cnt;
num[v] = num[c[u1][1]];
c[u2][1] = v,c[u2][0] = c[u1][0];
ud(c[u1][1],c[u2][1],m + 1,r,x);
}
pushup(u2);
}
In void insert(int x){
rt[++rn] = ++cnt;
ud(rt[rn-1],rt[rn],0,ls,x);
}
int query(int u1,int u2,int l,int r,int ql,int qr){
if(l == ql && r == qr)return num[u2] - num[u1];
int m = (l + r) >> 1;
if(qr <= m)return query(c[u1][0],c[u2][0],l,m,ql,qr);
else if(ql > m)return query(c[u1][1],c[u2][1],m + 1,r,ql,qr);
else return query(c[u1][0],c[u2][0],l,m,ql,m) + query(c[u1][1],c[u2][1],m + 1,r,m + 1,qr);
}
In int sum(int dfnl,int dfnr,int ql,int qr){ //查询dfn在[dfnl,dfnr],数值在[ql,qr]中的数u有多少个
return query(rt[dfnl-1],rt[dfnr],0,ls,ql,qr);
}
}T;
int loc[LT+5],lcp[LT+5]; //lcp[i]表示t[1..i]在SAM中匹配到的最长长度,loc表示最长匹配对应的SAM节点
char s[LS+5],t[LT+5];
int lt,l,r;
int dfn[2*LS+5],sz[2*LS+5],dn;
struct SAM{
int nx[2*LS+5][26],fail[2*LS+5][21],len[2*LS+5],flag[2*LS+5];
int cnt,last;
void init(){
fail[0][0] = -1;
}
void extend(char c,int n){
int id = c - 'a';
int cur = ++cnt,p;
flag[cur] = n;
len[cur] = len[last] + 1;
for(p = last;p != -1 && !nx[p][id];p = fail[p][0])nx[p][id] = cur;
if(p == -1)fail[cur][0] = 0;
else{
int q = nx[p][id];
if(len[q] == len[p] + 1)fail[cur][0] = q;
else{
int clone = ++cnt;
len[clone] = len[p] + 1;
fail[clone][0] = fail[q][0];
memcpy(nx[clone],nx[q],sizeof(nx[clone]));
fail[cur][0] = fail[q][0] = clone;
for(;p != -1 && nx[p][id] == q;p = fail[p][0])nx[p][id] = clone;
}
}
last = cur;
}
struct edge{
int next,des;
}e[4*LS+5];
int head[2*LS+5],Cnt;
In void addedge(int a,int b){
Cnt++;
e[Cnt].des = b;
e[Cnt].next = head[a];
head[a] = Cnt;
}
void dfs(int u){
dfn[u] = ++dn;
sz[u] = 1;
T.insert(flag[u]);
for(rg int i = head[u];i;i = e[i].next){
int v = e[i].des;
dfs(v);
sz[u] += sz[v];
}
}
void build(){
for(rg int i = 1;i <= cnt;i++)addedge(fail[i][0],i);
for(rg int j = 1;j <= 20;j++)
for(rg int i = 1;i <= cnt;i++)
if(fail[i][j-1] == -1)fail[i][j] = -1;
else fail[i][j] = fail[fail[i][j-1]][j-1];
dfs(0);
}
void prepro(){
for(rg int i = 1,u = 0,l = 0;i <= lt;i++){
int id = t[i] - 'a';
while(u && !nx[u][id])u = fail[u][0],l = len[u];
if(nx[u][id])u = nx[u][id],l++;
lcp[i] = l,loc[i] = u;
}
}
}S;
int lim[LT+5]; //lim[i]表示以i作为结尾的、最短的合法子串的左端点
struct SA{ //其实是前缀数组
int sa[LT+5],rk[LT+5],temp[LT+5],num[LT+5],h[LT+5],m;
void qsort(){
memset(num,0,sizeof(int) * (m+1)); //要是sizeof(num)就爆掉了,后面的几个也是类似
for(rg int i = 1;i <= lt;i++)num[rk[i]]++;
for(rg int i = 2;i <= m;i++)num[i] += num[i-1];
for(rg int i = lt;i >= 1;i--)sa[num[rk[temp[i]]]--] = temp[i];
}
void calch(){
int k = 0;
for(rg int i = lt;i >= 1;i--){
if(rk[i] == 1)h[rk[i]] = k = 0;
else{
if(k)k--;
int j = sa[rk[i]-1];
while(t[i-k] == t[j-k])k++;
h[rk[i]] = k;
}
}
}
void init(){
for(rg int i = 1;i <= lt;i++)rk[i] = t[i] - 'a' + 1;
for(rg int i = 1;i <= lt;i++)temp[i] = i;
m = 26;
qsort();
for(rg int n = 1;n <= lt;n <<= 1){
int cnt = 0;
for(rg int i = 1;i <= n;i++)temp[++cnt] = i;
for(rg int i = 1;i <= lt;i++)if(sa[i] + n <= lt)temp[++cnt] = sa[i] + n;
qsort();
memcpy(temp,rk,sizeof(int) * (lt+1));
cnt = rk[sa[1]] = 1;
for(rg int i = 2;i <= lt;i++){
if(temp[sa[i-1]] != temp[sa[i]] || temp[sa[i-1]-n] != temp[sa[i]-n])cnt++;
rk[sa[i]] = cnt;
}
if(cnt == lt)break;
m = cnt;
}
calch();
}
In bool legal(int u,int len){
if(!len)return 0;
if(len > r - l + 1)return 1;
return !T.sum(dfn[u],dfn[u] + sz[u] - 1,l + len - 1,r);
}
int calclim(int curr){
int u = loc[curr],len = lcp[curr];
if(!len)return curr;
if(!legal(u,len))return curr - len;
for(rg int i = 20;i >= 0;i--){
if(S.fail[u][i] == -1)continue;
if(legal(S.fail[u][i],S.len[S.fail[u][i]]))u = S.fail[u][i],len = S.len[u];
}
int L = S.len[S.fail[u][0]] + 1,R = len;
while(L < R){
int mid = (L + R) >> 1;
if(legal(u,mid))R = mid;
else L = mid + 1;
}
return curr - L + 1;
}
void count(){
for(rg int i = 1;i <= lt;i++)lim[i] = calclim(i);
ll ans = 0;
for(rg int i = 1;i <= lt;i++){
int curr = sa[i];
ans += (ll)min(lim[curr],curr - h[i]);
}
write(ans);putchar('
');
memset(rk,0,sizeof(int) * (lt+1));
memset(temp,0,sizeof(int) * (lt+1));
memset(sa,0,sizeof(int) * (lt+1));
}
}A;
int main(){
scanf("%s",s + 1);
ls = strlen(s + 1);
S.init();
for(rg int i = 1;i <= ls;i++)S.extend(s[i],i);
S.build();
int q = read();
while(q--){
scanf("%s",t + 1);
lt = strlen(t + 1);
l = read(),r = read();
S.prepro();
A.init();
A.count();
}
return 0;
}