BZOJ5304: [Haoi2018]字串覆盖
https://lydsy.com/JudgeOnline/problem.php?id=5304
分析:
- 设(L=r-l+1)。
- 建出(sam),倍增+线段树合并求出每个询问对应原串的(right)集合。
- 可以知道
- 如果(L>50),则每次在线段树上二分找到第一个(1),最多找(frac{n}{L})次。
- 否则就比较麻烦了,我们对于每个位置,维护(F[L][i])表示长度为(L),最后一个字符是(i)的子串向后找能找到谁,然后倍增求这个就完事了。
- 由于空间可能开不下,我一开始的做法是对每个状态开(vector)来减少存储状态,不过还是会(mle),改成将询问离线,那么(L)这一维就可以一起处理了。
- 其中求(F)我使用了字符串哈希和(map)。
- 如果哈希对(unsigned long long)自然溢出并且(base=998244353)会只有(40)分,别问我是怎么知道的。
代码:
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <vector>
#include <tr1/unordered_map>
using namespace std;
using namespace std::tr1;
#define N 200050
#define M 4000050
#define base 131
#define db(x) cerr<<#x<<" = "<<x<<endl
typedef long long ll;
typedef unsigned long long ull;
char ss[N],tt[N];
int n,K,ch[N][26],fa[N],lst=1,cnt=1,len[N];
int ls[M],rs[M],tot,siz[M],f[20][N],ke[N],ro[N],root[N];
int tl[N],tq[N],Lg[N];
ull h[N],mi[N];
vector<int>F[N>>1],G[N>>1];
vector<ll>H[N>>1];
int pl[N];
ll ans[N];
ull gh(int l,int r) {return h[r]-h[l-1]*mi[r-l+1];}
unordered_map<ull,int>mp;
void update(int l,int r,int x,int &p) {
if(!p) p=++tot;
siz[p]++;
if(l==r) return ;
int mid=(l+r)>>1;
if(x<=mid) update(l,mid,x,ls[p]);
else update(mid+1,r,x,rs[p]);
}
int merge(int x,int y) {
if(!x||!y) return x+y;
int p=++tot;
ls[p]=merge(ls[x],ls[y]);
rs[p]=merge(rs[x],rs[y]);
siz[p]=siz[ls[p]]+siz[rs[p]];
return p;
}
void insert(int x) {
int p=lst,np=++cnt,q,nq;
lst=np; len[np]=len[p]+1;
for(;p&&!ch[p][x];p=fa[p]) ch[p][x]=np;
if(!p) fa[np]=1;
else {
q=ch[p][x];
if(len[q]==len[p]+1) fa[np]=q;
else {
nq=++cnt;
len[nq]=len[p]+1;
fa[nq]=fa[q];
fa[q]=fa[np]=nq;
memcpy(ch[nq],ch[q],sizeof(ch[q]));
for(;p&&ch[p][x]==q;p=fa[p]) ch[p][x]=nq;
}
}
}
int Ql(int l,int r,int p) {
if(l==r) return l;
int mid=(l+r)>>1;
if(siz[ls[p]]) return Ql(l,mid,ls[p]);
else return Ql(mid+1,r,rs[p]);
}
int query(int l,int r,int x,int y,int p) {
if(!p||!siz[p]) return -1;
if(x<=l&&y>=r) {
return Ql(l,r,p);
}
int mid=(l+r)>>1;
if(y<=mid) return query(l,mid,x,y,ls[p]);
else if(x>mid) return query(mid+1,r,x,y,rs[p]);
else {
int q=query(l,mid,x,y,ls[p]);
if(q!=-1) return q;
return query(mid+1,r,x,y,rs[p]);
}
}
ll solve(int l,int r,int p,int L) {
ll c1=0,c2=0;
if(L<=50) {
int x=query(1,n,l,r,root[p]);
if(x==-1) return 0;
int i;
for(i=18;i>=0;i--) {
int lim=F[x].size();
if(i>=lim) continue;
if(F[x][i]&&F[x][i]<=r) {
c1+=G[x][i];
c2+=H[x][i];
x=F[x][i];
}
}
c1++; c2+=x;
return c1*K-c2+(L-1)*c1;
}
while(l<=r) {
int x=query(1,n,l,r,root[p]);
if(x==-1) break;
c1++; c2+=x;
l=x+L;
}
return c1*K-c2+(L-1)*c1;
}
struct A {
int x,y,l,r,L,id;
bool operator < (const A &u) const {return L<u.L;}
}qq[N];
int main() {
scanf("%d%d%s%s",&n,&K,ss+1,tt+1);
int i,j,p;
for(i=1;i<=n;i++) insert(ss[i]-'a'),update(1,n,i,root[lst]);
for(i=1;i<=cnt;i++) f[0][i]=fa[i];
for(i=1;(1<<i)<=cnt;i++) for(j=1;j<=cnt;j++) f[i][j]=f[i-1][f[i-1][j]];
for(i=1;i<=cnt;i++) ke[len[i]]++;
for(i=1;i<=cnt;i++) ke[i]+=ke[i-1];
for(i=cnt;i;i--) ro[ke[len[i]]--]=i;
for(i=cnt;i>1;i--) {
p=ro[i]; root[fa[p]]=merge(root[fa[p]],root[p]);
}
p=1; int now=0;
for(i=1;i<=n;i++) {
int x=tt[i]-'a';
if(ch[p][x]) {
now++; p=ch[p][x];
}else {
for(;p&&!ch[p][x];p=fa[p]) ;
if(!p) now=0,p=1;
else now=len[p]+1,p=ch[p][x];
}
tl[i]=now; tq[i]=p;
}
int L;
for(mi[0]=i=1;i<=n;i++) mi[i]=mi[i-1]*base,h[i]=h[i-1]*base+ss[i];
for(Lg[0]=-1,i=1;i<=n;i++) Lg[i]=Lg[i>>1]+1;
int cas;scanf("%d",&cas);
int x,y,l,r;
for(i=1;i<=cas;i++) scanf("%d%d%d%d",&qq[i].x,&qq[i].y,&qq[i].l,&qq[i].r),qq[i].id=i,qq[i].L=qq[i].r-qq[i].l+1;
sort(qq+1,qq+cas+1);
int lf=1;
for(L=1;L<=50;L++) {
if(qq[lf].L!=L) continue;
for(i=n;i>=L;i--) {
ull tmp=gh(i-L+1,i);
if(mp[tmp]) {
int x=mp[tmp];
int sz=Lg[pl[x]+1];
F[i].resize(sz+1);
G[i].resize(sz+1);
H[i].resize(sz+1);
F[i][0]=x;
G[i][0]=1;
H[i][0]=i;
for(j=1;j<=sz;j++) {
int t=F[i][j-1];
F[i][j]=F[t][j-1];
G[i][j]=G[i][j-1]+G[t][j-1];
H[i][j]=H[i][j-1]+H[t][j-1];
}
pl[i]=pl[x];
}else {
pl[i]=-1;
}
pl[i]++;
if(i+L-1<=n) {
mp[gh(i,i+L-1)]=i+L-1;
}
}
mp.clear();
memset(pl,0,sizeof(pl));
for(;lf<=n&&qq[lf].L==L;lf++) {
x=qq[lf].x;
y=qq[lf].y;
l=qq[lf].l;
r=qq[lf].r;
int id=qq[lf].id;
if(tl[r]<L) {ans[id]=0; continue;}
p=tq[r];
for(j=18;j>=0;j--) {
if(f[j][p]&&len[f[j][p]]>=L) p=f[j][p];
}
ans[id]=solve(x+L-1,y,p,L);
}
for(i=1;i<=n;i++) F[i].clear(),G[i].clear(),H[i].clear();
}
for(j=1;j<=cas;j++) if(qq[j].L>50) {
x=qq[j].x;
y=qq[j].y;
l=qq[j].l;
r=qq[j].r;
int id=qq[j].id;
L=r-l+1;
if(tl[r]<L) {
ans[id]=0; continue;
}
p=tq[r];
for(i=18;i>=0;i--) {
if(f[i][p]&&len[f[i][p]]>=L) p=f[i][p];
}
ans[id]=solve(x+L-1,y,p,L);
}
for(i=1;i<=cas;i++) printf("%lld
",ans[i]);
}