Description
给定字符串 (S),你可以执行一种操作,将 (S) 拼接到当前字符串后,可以重叠,但重叠的位置必须对位相等,要求最终字符串长度不超过 (m),现在请你计算最终得到字符串的所有可能长度数量。(|S|le 5 imes 10^5,mle 10^{18})
Solution
首先重叠的部分一定是 (s) 的一个 (border),因此每次操作就等于是给原串增加 (|s|-) 原串的一个 (border) 的长度。先 (kmp) 求出所有的 (|s|-border) 长度,设为 ({x_i}),那么问题转化为求解 (sum_{i=1}^{sum}a_ix_i) 在 ([0,m-|s|]) 范围内能取到多少个不同的值。
容易看出,这是同余最短路的形式,用同余最短路解决这样的问题,就是将 (x) 从小到大排序,以 (x_1) 作为模数,对 (0sim x_1-1)分别建点,对 (forall iin [2,sum],jin [0,x_1-1]) ,从 (j) 向 ((x_i+j)mod x_1) 建一条边权为 (x_i) 的边,然后跑最短路,(dis_i) 就表示用 (x_2dots x_{sum}) 能表示出的最小的 (mod x_1=i) 的数,那么(le dis_i) 的所有 (mod x_1=i) 的数都能被表示出来,就可以求出答案了。
暴力做同余最短路,复杂度是 (mathcal O(n^2log n)),考虑优化。
首先引入字符串的一个重要性质:对于任意一个字符串,它的 (border) 长度一定会构成 (mathcal O(log n))个等差数列。
由于博主太菜,在此就不给出证明了。
现在我们假设 ({x_i}) 是一个首项为 (x) ,公差为 (d) ,长度为 (len) 的等差数列,考虑求出其答案。
按同余最短路的方式建边,最终一定会连出 (gcd(d,x)) 个环,每个环互相独立。对于任意一个环,环上 (dis) 最小的点一定不会被其他点松弛,因此以它作为起点,依次考虑环上每个点,会松弛到它的点与它的距离应当不超过 (len-1),因此如果点 (j) 比点 (i) 距离接下来的点更近,且 (dis_j>dis_i+w(i,j)),那么 (i) 的松弛一定劣于 (j),可以弹出 (j),因此可以使用单调队列维护,至此我们就完成了 (mathcal O(n)) 求出这种情况下的答案。
回到原问题,现在我们能够完成每个等差数列内部的转移了,还有考虑前后两个等差数列的组合,设前一个的首项为 (u),后一个的首项为 (v),现在要从 (mod u) 的答案转移到 (mod v) 的答案上,首先可以将 (dis_x) 转移到新的 (dis_{dis_xmod v}) 上,但同时我们还要考虑长为 (u) 的转移,我们可以将其看作首项为 (0),公差为 (u),长度为 (1) 的等差数列,用上面的方法完成即可。
总复杂度 (mathcal O(n log n))
Code
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=2e6+10;
int n,fail[N];
char s[N];
int stk[N],top;
inline void getborder(){
fail[1]=0;
for(int i=2;i<=n;++i){
for(int j=fail[i-1];;j=fail[j]){
if(s[j+1]==s[i]&&j+1!=i){fail[i]=j+1;break;}
if(j==0){fail[i]=0;break;}
}
}
int now=fail[n];top=0;
while(now){
stk[++top]=n-now;
now=fail[now];
}stk[++top]=n;
}
ll m,dp[N],tmp[N],q[N][2];
int l,r;
inline void trans(int mod,int x,int del,int mxtrans){
int g=__gcd(del,mod);
for(int T=0;T<g;++T){
ll mn=m+1,pos=-1;int cnt=0;
for(int j=T;;){
if(j==T) if((++cnt)>1) break;
if(dp[j]<mn) mn=dp[j],pos=j;
j+=del;while(j>=mod) j-=mod;
}
cnt=0;
if(pos==-1) continue;
if(mxtrans==1){
for(int i=1,j=(pos+del)%mod,last=pos;j!=pos;++i){
dp[j]=min(dp[j],dp[last]+del+x);
last=j;j+=del;while(j>=mod) j-=mod;
}
continue;
}
l=1;r=0;
q[++r][0]=0;q[r][1]=mn;
for(int i=1,j=(pos+del)%mod;j!=pos;++i){
while(l<=r&&i-q[l][0]>mxtrans) ++l;
dp[j]=min(dp[j],q[l][1]+1ll*(i-q[l][0])*del+x);
while(l<=r&&q[r][1]+1ll*(i-q[r][0])*del>dp[j]) --r;
q[++r][0]=i;q[r][1]=dp[j];
j+=del;while(j>=mod) j-=mod;
}
}
}
int main(){
int tt;scanf("%d",&tt);
while(tt--){
scanf("%d%lld",&n,&m);
scanf("%s",s+1);
getborder();
for(int j=0;j<stk[1];++j) dp[j]=m+1;
dp[n%stk[1]]=n;
int res=1;
ll ans=0;
for(int i=1,j=0;i<top;i=j+1){
int del=stk[i+1]-stk[i];j=i+1;
while(j<=top&&stk[j]-stk[j-1]==del) ++j;--j;
if(j>i) trans(stk[i],stk[i],del,j-i),res=i;
if(j<top){
for(int k=0;k<stk[j+1];++k) tmp[k]=m+1;
for(int k=0;k<stk[i];++k){
ll x=dp[k]%stk[j+1];
tmp[x]=min(tmp[x],dp[k]);
}
memcpy(dp,tmp,sizeof(ll)*(stk[j+1]));
trans(stk[j+1],0,stk[i],1);res=j+1;
}
}
for(int i=0;i<stk[res];++i)
if(dp[i]<=m) ans+=(m-dp[i])/stk[res]+1;
printf("%lld
",ans);
}
return 0;
}