题目:https://www.lydsy.com/JudgeOnline/problem.php?id=2119
就是找差分序列上中间差 m 的相等的两段。
考虑枚举这样一段的长度 L 。可以把序列分成 ( frac{n}{L} ) 段;令 L , 2L , ... 这样的位置为关键点,那么每个关键点 i 求一下 LCP( i , i+L+m ) 和 LCS( i , i+L+m ) ,就能知道过这个关键点的左端点的合法范围。用 lst 记录上一个关键点算出的右端点来去重即可。
这样是 nlogn 的,且不会遗漏合法的解。
如果有两个关键点之间的合法左端点,满足该点左边是不合法的左端点,右边也是不合法的左端点,那么这个解就不会被算上。但不会有这样的情况。因为以这个点为左端点的段的长度是 L ,一定跨越了下一个关键点;从下一个关键点找 LCP 一定会覆盖这个左端点。
预处理 LCP 和 LCS 可以把差分序列正着和反着接在一起,中间填一个没出现的最小字符,即 0 ;但平时求 ht[ ] 的时候没有判断 if(rk[i]==1)continue ; ,因为到时候会有 s[ 0+0 ] != s[ i+0 ] ,但此时会有 s[ 0 ] = 0 , s[ i ] = 0 ( rk[ i ] = 1 ) ,所以会求错。需要注意。
#include<cstdio> #include<cstring> #include<algorithm> #define ll long long using namespace std; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return fx?ret:-ret; } int Mn(int a,int b){return a<b?a:b;} int Mx(int a,int b){return a>b?a:b;} const int N=1e5+5,K=20; int n,m,s[N],sa[N],rk[N],tp[N],tx[N],ht[N][K],bin[K],lg[N]; ll ans; void Rsort(int n,int nm) { for(int i=1;i<=nm;i++)tx[i]=0; for(int i=1;i<=n;i++)tx[rk[i]]++; for(int i=2;i<=nm;i++)tx[i]+=tx[i-1]; for(int i=n;i;i--)sa[tx[rk[tp[i]]]--]=tp[i]; } void get_sa(int n) { int nm=n+1; for(int i=1;i<=n;i++)tp[i]=i,rk[i]=s[i]+1;//+1 for s[i]=0!!! Rsort(n,nm); for(int k=1;;k<<=1) { int tot=0; for(int i=n-k+1;i<=n;i++)tp[++tot]=i; for(int i=1;i<=n;i++) if(sa[i]>k)tp[++tot]=sa[i]-k; Rsort(n,nm);memcpy(tp,rk,sizeof rk); nm=1; rk[sa[1]]=1; for(int i=2;i<=n;i++) { int u=sa[i]+k,v=sa[i-1]+k;if(u>n)u=0;if(v>n)v=0; rk[sa[i]]=(tp[sa[i]]==tp[sa[i-1]]&&tp[u]==tp[v])?nm:++nm; } if(nm==n)break; } } void get_ht(int n) { for(int i=2;i<=n;i++)lg[i]=lg[i>>1]+1; bin[0]=1;for(int i=1;i<=lg[n];i++)bin[i]=bin[i-1]<<1; s[0]=N;///////for rk[i]==1 s[0+0]==s[i+0] for(int i=1,j,k=0;i<=n;i++)//k=0!! { for((k?k--:0),j=sa[rk[i]-1];i+k<=n&&j+k<=n&&s[i+k]==s[j+k];k++); ht[rk[i]][0]=k; } for(int t=1;t<=lg[n];t++) for(int i=1;i+bin[t]-1<=n;i++) ht[i][t]=Mn(ht[i][t-1],ht[i+bin[t-1]][t-1]); } int qry_ht(int l,int r,bool fx) { if(l==r)return fx?sa[l]-n:n-sa[l]+1; if(l>r)swap(l,r); int d=lg[r-l]; return Mn(ht[l+1][d],ht[r-bin[d]+1][d]);//l+1 } int main() { n=rdn();m=rdn();for(int i=1;i<=n;i++)s[i]=rdn(); n--;for(int i=1;i<=n;i++)s[i]=tp[i]=s[i+1]-s[i]; sort(tp+1,tp+n+1);int tmp=unique(tp+1,tp+n+1)-tp-1; for(int i=1;i<=n;i++)s[i]=lower_bound(tp+1,tp+tmp+1,s[i])-tp; s[n+1]=0; for(int i=n+2,j=n;j;i++,j--)s[i]=s[j]; int len=n*2+1; get_sa(len);get_ht(len); for(int L=1,lst=0;L<=n;L++,lst=0) //for(int i=L;i+L+m<=n;i+=L) for(int i=1;i+L+m<=n;i+=L) { int d=i+L+m; int l2=qry_ht(rk[i],rk[d],0); int l1=qry_ht(rk[len-i+1],rk[len-d+1],1); int st=Mx(lst+1,i-l1+1); int en=i+l2-L; if(en<st)continue; lst=en; ans+=en-st+1; } printf("%lld ",ans); return 0; }