题意简述
给定长度为 (n) 的字符串 (S,T) ,求有多少个不同的 (T) 的子串 (t) ,满足 (t) 是 (S) 的一个子序列。
(1le nle 3000)
算法分析
子串的个数是 (mathcal{O}(n^2)) 的,子序列的个数是 (mathcal{O}(2^n)) 的,因此考虑枚举所有子串,判断是否是 (S) 的子序列。
如何快速判断一个字符串是母串的子序列?直接上子序列自动机就好了。由于枚举过程是增量枚举的,因此总复杂度为 (mathcal{O}(n^2log n)) 或者 (mathcal{O}(n^2+n|Sigma|)) 的,取决于子序列自动机的实现方法。
但是我们枚举的子串可能有相同的,需要去重,hash即可,因为字符串总量比较大,用 双模hash
比较保险。
熟悉子序列自动机的可以跳过下面一段:
子算法1 子序列自动机
由名称,不难得出其用途。子序列自动机可以判断一个串是否是母串的子序列。
下设询问串为 (P) ,母串为 (S) 。
考虑这个询问串在母串上匹配的过程,假设当前询问串的前 (i) 位都是母串的子序列,且在母串中匹配到 (cur) 。形式化的讲, (P[1:i]) 是 (S) 的子序列,且 (P[i]=S[cur]) 。
现在我们要匹配 (P[i+1]) ,如果能匹配上,那么 (S) 串在 (cur) 位置后一定存在一个位置 (k) 能匹配上,即 (exists k>cur , P[i+1]=S[k]) 。
但是 (S) 串后面可能有若干个合法的 (k) ,我们应该取哪一个呢?
我们应该取最靠前的那一个,即 (k>cur , forall jin (cur,k] , P[i+1] eq S[j]) 。
为什么这样的贪心是正确的?因为这个过程有决策包容性。即我们取最靠前的符合要求的 (k) ,不会使得答案变差。
后面黄色框表示如果选择 (k_2) , (P) 串后面可能的一种子序列匹配,在我们选择 (k_1) 的时候这种后面的匹配仍然是可达的,因此不会丢失答案。
接下来有两种实现,根据不同情况应选择不同实现方法:
- 记
nxt[i][c]
表示位置 (i) 之后第一个为 (c) 的字符,记录一个lst[c]
表示当前范围内 (c) 最后一次的出现位置,倒序扫描一遍即可。构建时空复杂度为 (mathcal{O}(n|Sigma|)),查询时间复杂度为 (mathcal{O}(|P|))。 - 开 (|Sigma|) 个
vector
,存储每一种字符的出现位置,查询的时候二分位置即可,构造时空复杂度为 (mathcal{O}(n)) ,查询时间复杂度为 (mathcal{O}(|P|log n))。
一般来说,对于字符集较小,查询量较大的题目,推荐使用第一种写法。对于字符集较大,或者空间较为紧张的题目,推荐使用第二种写法。
实现方法1:
int nxt[maxn][26];//假定为字符集为所有小写字符
int lst[26];
int n;
void build(char *S){
n=strlen(S+1);
for(int j=0;j<26;++j)lst[j]=n+1;
for(int i=n;i>=0;--i){
for(int j=0;j<26;++j)nxt[i][j]=lst[j];
lst[S[i]-'a']=i;
}
}
bool query(char *P){
int cur=0,np=strlen(P+1);
for(int i=1;i<=np;++i){
cur=nxt[cur][P[i]-'a'];
if(cur>n)return 0;
}
return 1;
}
实现方法2:
int n;
vector<int>ps[26];
void build(char *S){
n=strlen(S+1);
for(int i=1;i<=n;++i)ps[S[i]-'a'].push_back(i);
for(int j=0;j<26;++j)ps[j].push_back(n+1);//防止越界,便于处理
}
bool query(char *P){
int cur=0,np=strlen(P+1);
for(int i=1;i<=np;++i){
int nxt=*upper_bound(ps[P[i]-'a'].begin(),ps[P[i]-'a'].end(),cur);
if(nxt>n)return 0;
cur=nxt;
}
return 1;
}
能够正确写出双模HASH的可以跳过下面一段:
子算法2 HASH
可能有很多同学在初学字符串 HASH
的时候写的 HASH
是假的(错误率很高)(包括我自己)。
字符串 HASH
核心思想是把字符串看作一个 BAS
进制数,因为显然存不下,考虑取模,比较常用的 BAS
=(131,13331),常用的取模是unsigned long long
自然溢出。
第一个要注意的地方是模数要足够大。由生日悖论, (sqrt n) 个值域为 ([0,n)) 的数存在相同数的概率超过 (50\%) ,如果模数是 int
范围的,则长度为 (10^5) 左右的随机字符串已经很容易产生冲突。可参见 Hash Killer II 。
但是我们仅使用自然溢出也会出问题,因为有对着卡的方法,参见 Hash Killer I 。
因此我们通过双底数/双模数的方法处理,具体的,我们取两个不同的BAS
和Mod
,分别计算 HASH
,两个 HASH
都相同才认为是相同的的。
这种方法目前似乎没有很好的方法卡掉,具体可参见 Hash Killer Ⅲ。
处理完错误率的问题,下面来处理效率问题。
先是构建的过程,考虑定义式(可能有多种定义,仅举一例):
为了方便后面计算,还应记录前缀和,即:
暴力计算是 (mathcal{O}(nlog n)) 的,这个过程可以用秦九韶算法优化:
这样避免了快速幂,时间复杂度为 (mathcal{O}(n)) 。
接下来是查询过程。我们查询子串 (S[l:r]) ,则对应答案为:
暴力实现是 (mathcal{O}(log n)) 的,我们预处理出所有的 (BAS^k) ,这样复杂度降为 (mathcal{O}(1)) 。
现在的复杂度是线性的,接下来是一些常数优化和一些细节:
- 使用
unsigned long long
自然溢出,减少取模 - 如果是两个数相加/相减,且能保证都在 ([0,Mod)) 范围内,可以使用减法代替取模
- 底数和模数不能过大,应保证 (max{BAS,Mod} imes Mod< 2^{62}),否则在乘法过程中可能会超出
long long/unsigned long long
范围 - 推荐使用
unsigned long long
而不是long long
。尤其注意 自然溢出不能使用long long
,因为long long
的溢出是UB
!
有关本题的一个细节:
由于只有一个询问,去重应使用 sort+unique
或手写哈希表,map/unordered_map
常数巨大,通过此题比较困难。
代码实现
我的代码里采用的是第二种子序列自动机实现方法。
有关 HASH
,我的代码没有完全做到上面的优化,且第二个模数是 int
范围的,有一定的优化空间。
#include<bits/stdc++.h>
using namespace std;
#define maxn 1000005
#define maxm 2000005
#define inf 0x3f3f3f3f
#define LL long long
#define ull unsigned long long
#define db double
#define ldb long double
#define mod 1000000007
#define eps 1e-9
#define local
void file(string s){freopen((s+".in").c_str(),"r",stdin);freopen((s+".out").c_str(),"w",stdout);}
template <typename Tp> void read(Tp &x){
int fh=1;char c=getchar();x=0;
while(c>'9'||c<'0'){if(c=='-'){fh=-1;}c=getchar();}
while(c>='0'&&c<='9'){x=(x<<1)+(x<<3)+(c&15);c=getchar();}x*=fh;
}
int n,m;
char S[maxn],T[maxn];
vector<int>ps[26];
struct HS_node{
ull hs1,hs2;
HS_node operator +(HS_node y)const{
return (HS_node){hs1+y.hs1,(hs2+y.hs2)%mod};
}
HS_node operator -(HS_node y)const{
return (HS_node){hs1-y.hs1,(hs2-y.hs2+mod)%mod};
}
HS_node operator *(HS_node y)const{
return (HS_node){hs1*y.hs1,(hs2*y.hs2)%mod};
}
bool operator <(HS_node y)const{
return hs1==y.hs1?hs2<y.hs2:hs1<y.hs1;
}
bool operator ==(HS_node y)const{
return hs1==y.hs1&&hs2==y.hs2;
}
};
struct MY_Hash{
const ull Bas1=131,Bas2=13331;
HS_node pw[maxn],sh[maxn];
void build(const char *str){//构建hash
int nn=strlen(str+1);
pw[0]=(HS_node){1,1};
for(int i=1;i<=nn;++i)pw[i]=pw[i-1]*(HS_node){Bas1,Bas2};
for(int i=1;i<=nn;++i)sh[i]=sh[i-1]*(HS_node){Bas1,Bas2}+(HS_node){str[i],str[i]};
}
HS_node get_hash(int l,int r){
return sh[r]-sh[l-1]*pw[r-l+1];
}
}hh;
HS_node aa[9000005];
int ans;
signed main(){
#ifndef local
file("block");
#endif
read(n);
scanf("%s",S+1);
scanf("%s",T+1);
hh.build(T);
for(int i=1;i<=n;++i)ps[S[i]-'a'].push_back(i);//子序列自动机构建
for(int i=0;i<26;++i)ps[i].push_back(n+1);//防止超出边界,push一个终止符
for(int i=1;i<=n;++i){
int cur=0;
for(int j=i;j<=n;++j){
int nxt=*upper_bound(ps[T[j]-'a'].begin(),ps[T[j]-'a'].end(),cur);//子序列自动机的转移
if(nxt>n)break;
aa[++m]=hh.get_hash(i,j);
cur=nxt;
}
}
sort(aa+1,aa+m+1);
ans=unique(aa+1,aa+m+1)-aa-1;//去重
printf("%d
",ans);
fclose(stdin);
fclose(stdout);
return 0;
}