[BZOJ 3277]字符串(后缀自动机)
题面
给定n个字符串,询问每个字符串有多少子串(不包括空串)是所有n个字符串中至少k个字符串的子串?
分析
首先,我们把所有字符串建成一个广义SAM.(实际上,只需要插入完每个字符串之后吧last设回根节点)
然后对于每个字符串,在自动机上跑,对于跑到的每个节点,沿着parent树往上跳.这相当于枚举每个前缀的所有后缀,其实就是所有子串。这样就可以标记出每个节点代表的这些子串在多少个不同的字符串里出现过,记为(cov(x)).同时为了不重复枚举,我们需要标记每个节点是否已经被这个字符串的其他子串访问过了,这样只需跳到最上面的没有被这个字符串访问过的节点.可以证明复杂度是(O(n sqrt n))的
容易发现节点x产生的不同子串有(sum(x)=len(x)-len(link(x))).再对sum求前缀和,DFS一遍parent树.
那么最终答案就是(sum_x [cov(x) geq k] imes sum(x))
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#define maxn 200000
#define maxc 26
using namespace std;
typedef long long ll;
int n,K;
char in[maxn+5];
string s[maxn+5];
struct SAM{
#define link(x) (t[x].link)
#define len(x) (t[x].len)
struct node{
int ch[maxc];
int len;
int link;
int cov;//这个点代表的子串集合被多少个原串包含
ll sum;//沿link走到根的路径上有多少个子串至少出现K次
}t[maxn*2+5];
const int root=1;
int ptr=1;
int last=root;
void extend(char ch){
int c=ch-'a';
int p=last,cur=last=++ptr;
len(cur)=len(p)+1;
while(p&&t[p].ch[c]==0){
t[p].ch[c]=cur;
p=link(p);
}
if(p==0) link(cur)=root;
else{
int q=t[p].ch[c];
if(len(p)+1==len(q)) link(cur)=q;
else{
int clo=++ptr;
t[clo]=t[q];
link(q)=link(cur)=clo;
len(clo)=len(p)+1;
while(p&&t[p].ch[c]==q){
t[p].ch[c]=clo;
p=t[p].link;
}
}
}
last=cur;
}
void insert(string &s){
last=root;
for(int i=0;i<s.length();i++) extend(s[i]);
}
void get_cov(){
static int last[maxn+5];
for(int i=1;i<=n;i++){
int x=root;
for(int j=0;j<s[i].length();j++){
x=t[x].ch[s[i][j]-'a'];
for(int y=x;y!=root&&last[y]!=i;y=link(y)){//暴力跳parent树标记,当last[y]=i时就停止,可以证明是O(nsqrt(n))的
last[y]=i;
t[y].cov++;
}
}
}
}
bool vis[maxn+5];
void dfs(int x){
if(x==root||vis[x]) return;
vis[x]=1;
dfs(link(x));
t[x].sum+=t[link(x)].sum;
}
void get_ans(){
for(int i=1;i<=ptr;i++){
t[i].sum=(t[i].cov>=K)*(len(i)-len(link(i)));
//x的节点表示的本质不同的子串有len(x)-len(link(x))个
}
for(int i=1;i<=ptr;i++) dfs(i);
for(int i=1;i<=n;i++){
ll ans=0;
int x=root;
for(int j=0;j<s[i].length();j++){
x=t[x].ch[s[i][j]-'a'];
ans+=t[x].sum;
}
printf("%lld ",ans);
}
}
}T;
int main(){
scanf("%d %d",&n,&K);
for(int i=1;i<=n;i++){
scanf("%s",in);
s[i]=string(in);
T.insert(s[i]);
}
T.get_cov();
T.get_ans();
// printf("%lld
",T.get_ans());
}