柱爷搞子串
题目连接:
(http://acm.uestc.edu.cn/#/problem/show/1485)
Description
柱爷有一个字符串S,对于其中的每一个不同子串(S^ast),柱爷都能O(1)
的得到这些字符串的所有匹配位置
即能知道所有的[L,R]
区间使得 (S[L,R]=S^ast),然后柱爷会把这些[L,R]区间的每个位置做上标记,如果最后这些标记位置形成了K个连通块,那么它对答案的贡献就是1
柱爷早就知道了答案,但他现在想问你知道吗?
Input
输入第一行一个字符串S,只有小写字母,保证(|S|≤10^5)
接下来一行一个K
,保证(1≤K≤|S|)
Output
输出一行表示答案
Sample Input
abaab
2
Sample Output
3
Hint
对于ababa的字串aba它对所有字符都打上了标记,所以只有1个联通块
题意
对于每一个不同字串,对他的每一个出现位置的每一个字符都打上标记,形成k个联通块答案加1。
题解:
换个角度,如果我们知道了sam后缀自动机里面每个节点的right集,对于每一个right集,我们就可以知道一个字串长度区间使这个集合形成k个联通块,再和这个节点的接受长度区间取个交就是这个节点的所有可行解。
串长度有1e5,对于全是a数据,他的right集和的个数达到1e10,显然暴力出right集是不行的。
那么我们可以用set的维护right集,用平衡树维护right的差分数组,维护出来后对于差分数组我们只要求第size-k的大小就可以算得贡献。最后用启发式合并一下,虽然总体是nloglog,但在后缀树上的启发式跑得很快。
如果是随机数据,那么直接用set爆出right集应该也是没有问题的。
代码
//#include <bits/stdc++.h>
#include <stdio.h>
#include <iostream>
#include <string.h>
#include <math.h>
#include <stdlib.h>
#include <limits.h>
#include <algorithm>
#include <queue>
#include <vector>
#include <set>
#include <map>
#include <stack>
#include <bitset>
#include <string>
#include <time.h>
using namespace std;
long double esp=1e-11;
//#pragma comment(linker, "/STACK:1024000000,1024000000")
#define fi first
#define se second
#define all(a) (a).begin(),(a).end()
#define cle(a) while(!a.empty())a.pop()
#define mem(p,c) memset(p,c,sizeof(p))
#define mp(A, B) make_pair(A, B)
#define pb push_back
#define lson l , m , rt << 1
#define rson m + 1 , r , rt << 1 | 1
typedef long long int LL;
const long double PI = acos((long double)-1);
const LL INF=0x3f3f3f3fll;
const int MOD =1000000007ll;
const int maxn=100100;
int Cnt,rot[maxn<<1]; //可重复,s表示节点的值的个数
struct Treap{ //rand()可能问题
int l,r;
int fix,key,size,s;
}tr[maxn*50];
inline void updata(int a){tr[a].size=tr[a].s+tr[tr[a].l].size+tr[tr[a].r].size;}
inline int newtreap(int _val){++Cnt;tr[Cnt].l=tr[Cnt].r=0;tr[Cnt].fix=rand();tr[Cnt].key=_val;tr[Cnt].size=tr[Cnt].s=1;return Cnt;}
void init(){rot[0]=Cnt=0;tr[0].size=tr[0].s=tr[0].l=tr[0].r=0;}
int Merge(int A,int B){//合并操作
if(!A)return B;
if(!B)return A;
if(tr[A].fix<tr[B].fix){
tr[A].r=Merge(tr[A].r,B);
updata(A);
return A;
}else{
tr[B].l=Merge(A,tr[B].l);
updata(B);
return B;
}
}
pair<int,int> Split(int x,int k){//拆分操作前k
if(!x)return mp(0,0);
pair<int,int> y;
if(tr[tr[x].l].size>=k){
y=Split(tr[x].l,k);
tr[x].l=y.second;
updata(x);
y.second=x;
}else{
y=Split(tr[x].r,k-tr[tr[x].l].size-tr[x].s);
tr[x].r=y.first;
updata(x);
y.first=x;
}
return y;
}
int Findkth(int t,int k){//查找第K小
while(t)
{
if(tr[tr[t].l].size<k&&tr[tr[t].l].size+tr[t].s>=k)
return tr[t].key;
else if(tr[tr[t].l].size+tr[t].s<k)
{
k-=tr[tr[t].l].size+tr[t].s;
t=tr[t].r;
}
else
t=tr[t].l;
}
return 0;
}
int Getkth(int x,int v){//询问一个数v是第几大,最大可能
if(!x)return 0;
return v<tr[x].key?Getkth(tr[x].l,v):Getkth(tr[x].r,v)+tr[tr[x].l].size+tr[x].s;
}
void Insert(int v,int t){//插入操作
if(!rot[t]){rot[t]=newtreap(v);return;}
pair<int,int> x=Split(rot[t],Getkth(rot[t],v-1));
pair<int,int> y=Split(x.second,Getkth(x.second,v));
if(y.first==0)y.first=newtreap(v);
else tr[y.first].s++,tr[y.first].size++;
rot[t]=Merge(Merge(x.first,y.first),y.second);
}
void Delete(int k,int t){//删除操作 1个k
pair<int,int> x=Split(rot[t],Getkth(rot[t],k-1));
pair<int,int> y=Split(x.second,Getkth(x.second,k));
if(tr[y.first].s<=1)y.first=0;
else tr[y.first].s--,tr[y.first].size--;
rot[t]=Merge(Merge(x.first,y.first),y.second);
}
void pr(int t)
{
if(tr[t].l)pr(tr[t].l);
printf("%d*%d ",tr[t].s,tr[t].key);
if(tr[t].r)pr(tr[t].r);
}
#define Len 101000
#define Alp 26
int pa[Len<<1],son[Len<<1][Alp],Right[Len<<1];
int Max[Len<<1],cnt,root,last; //0为null节点,null只能到null
inline int Newnode(int _Max){++cnt;memset(son[cnt],0,sizeof(son[cnt]));Max[cnt]=_Max;return cnt;}
inline void pre(){cnt=Max[0]=0;root=last=Newnode(0);}
inline void SAM(int alp,int t) //注意T=26,s[x]-'a',多串时每次last=root
{
int np=son[last][alp],u=last,v,nv;
if(np&&Max[np]==Max[last]+1){last=np;return;}//已有状态,对所有父状态更新
else np=Newnode(Max[last]+1);
Right[np]=t;
while(u&&!son[u][alp])son[u][alp]=np,u=pa[u];
if(!u)pa[np]=root;
else
{
v=son[u][alp];
if(Max[v]==Max[u]+1)pa[np]=v;
else
{
nv=Newnode(Max[u]+1); Right[nv]=0;
memcpy(son[nv],son[v],sizeof(son[v]));
pa[nv]=pa[v],pa[v]=pa[np]=nv;
while(u&&son[u][alp]==v)son[u][alp]=nv,u=pa[u];
}
}
last=np;
}
char s[maxn];
int k;
LL ans=0;
vector<int>node[maxn<<1];
set<int>q[maxn<<1];
set<int>::iterator iter,it;
int pos[maxn<<1],cc=0;
void ins(int val,int t)
{
iter=q[pos[t]].lower_bound(val);
if(iter!=q[pos[t]].end())
{
it=iter;
it--;
Insert((*iter)-val,t);
Insert(val-(*it),t); //printf("ins %d %d
",(*iter)-val,val-(*it));
Delete((*iter)-(*it),t); //printf("del %d
",(*iter)-(*it));
}
else
{
iter--;
Insert(val-(*iter),t); //printf("ins %d
",val-(*iter));
}
q[pos[t]].insert(val);
}
void dfs(int u)
{ //if(u>10000)printf("%d %d
",u,Cnt);
int ma=0,mapos=0;
for(int x:node[u])
{
dfs(x);
if(tr[rot[x]].size>ma)ma=tr[rot[x]].size,mapos=x;
} //printf("
node%d %d
",u,mapos);
rot[u]=rot[mapos];
if(mapos)pos[u]=pos[mapos];
else pos[u]=cc++,q[pos[u]].clear(),q[pos[u]].insert(-maxn*2);
if(Right[u])ins(Right[u],u);
for(int x:node[u])
if(x!=mapos)
{
for(int val:q[pos[x]])
if(val!=-maxn*2)
ins(val,u);
q[pos[x]].clear();
}
if(tr[rot[u]].size>=k)
{
int v2=Findkth(rot[u],tr[rot[u]].size-k+1);
int sz=Getkth(rot[u],v2-1);
if(sz==tr[rot[u]].size-k)
{
pair<int,int> x=Split(rot[u],Getkth(rot[u],v2-1));
int v1=Findkth(x.first,tr[x.first].size);
rot[u]=Merge(x.first,x.second);
ans+=max(0,min(v2-1,Max[u])-max(v1,Max[pa[u]]+1)+1);
// printf("ans sz=%d %d %d %d %d get %d
",tr[rot[u]].size,v1,v2,Max[pa[u]]+1,Max[u],max(0,min(v2-1,Max[u])-max(v1,Max[pa[u]]+1)+1));
}
}
// printf("%d %d %d
",u,pa[u],Right[u]);
// for(int val:q[pos[u]])printf("%d ",val);puts("");
// pr(rot[u]);puts("");
//答案
}
void test()
{
ans=0;
cc=0;
pos[0]=0;
q[0].clear();
int n=strlen(s);
pre();
for(int x=0;x<n;x++)
SAM(s[x]-'a',x+1);
// for(int x=0;x<=cnt;x++)
// {
// printf("%d %d %d son=
",x,pa[x],Max[x]);
// for(int y=0;y<26;y++)
// if(son[x][y])
// printf("%c %d ",y+'a',son[x][y]);
// putchar('
');
// }
for(int x=0;x<=cnt;x++)
node[x].clear();
init();
for(int x=1;x<=cnt;x++)
if(pa[x])
node[pa[x]].pb(x);
dfs(1);
}
int main()
{
//freopen("in.txt", "r", stdin);
//freopen("avl.in", "r", stdin);
//freopen("avl.out", "w", stdout);
//::iterator iter; %I64d
//for(int x=1;x<=n;x++)
//for(int y=1;y<=n;y++)
//scanf("%d",&a);
//printf("%d
",ans);
// k=1;
// for(int x=0;x<=100000;x++)s[x]='a';
// s[100001]=0;
scanf("%s%d",s,&k);
test();
printf("%lld
",ans);
return 0;
}