我是用 广义SAM+线段树合并做的,好像大家都没仔细讲,于是我来讲一下。
首先先建出 SAM,可以找出 SAM 中每一个点是在第几个字符串中出现的,在一棵动态开点的线段树上记录一下。然后把 SAM dfs 一遍,每一个点在每一个字符串的贡献加上他的 (son) 的权值。
这里有一个问题:就是 SAM 中可能有些点代表的字符串是一样的,解决方法是把一样的点都交给最上面的处理。
dfs 完之后理论上可以得到每一个点每一个字符串的贡献,就可以做了。
但是他有很多细节。
首先在建 SAM 的时候就把每一个点的贡献预处理好,代码如下。
for(int i=1;i<=n;i++)
{
scanf("%s",str); int len=strlen(str);
int las=1;
for(int j=0;j<len;j++) las=insert(str[j]-'a',las),update(rt[las],1,n,i);;
pos[i]=las;
}
然后因为我懒得离线(不太会),用了一个在线做法,线段树合并之后的两个点都是有用的,所以以下的合并做法会锅,他会在后面更新的时候覆盖掉前面还要用的数据。
int merge(int u,int v)
{
if(!u||!v) return u|v;
t[u].v+=t[v].v;
t[u].ls=merge(t[u].ls,t[v].ls);
t[u].rs=merge(t[u].rs,t[v].rs);
return u;
}
每次合并的时候要新开一个点,然后合并完了把原来的并回去,代码如下。
int merge(int u,int v)
{
if(!u&&!v) return 0;
if(!v) return u;
if(!u) {int x=++_cnt; t[x]=t[v]; return x;}
int x=++_cnt;
t[x].v=t[u].v+t[v].v;
t[x].ls=merge(t[u].ls,t[v].ls);
t[x].rs=merge(t[u].rs,t[v].rs);
return x;
}
因为我很菜,所以这个做法还有点卡空间,我用了 253MB 的空间,距离 MLE 还有 3MB 。
于是经过一些不怎么简单的实现,(其实是我太菜了,调了一个晚上),我们得到了一个时间 (O(|s|log |s|+q)),空间 (O(|s|log |s|)) 的做法。
完整代码如下:(并不怎么好看)
#define N 400005
int len[N],fa[N],ch[N][26],rt[N],pos[N],f[N];
char str[N];
int cnt=1,n,Q;
int insert(int c,int las)
{
int p=las,np=++cnt;
len[np]=len[p]+1;
for(;p&&!ch[p][c];p=fa[p]) ch[p][c]=np;
if(!p) fa[np]=1;
else
{
int q=ch[p][c];
if(len[p]+1==len[q]) fa[np]=q;
else
{
int nq=++cnt; len[nq]=len[p]+1;
memcpy(ch[nq],ch[q],26*4);
fa[nq]=fa[q]; fa[q]=fa[np]=nq;
for(;ch[p][c]==q;p=fa[p]) ch[p][c]=nq;
}
}
return np;
}
struct Node{int ls,rs,v;};
struct Query{int l,r,id;};
int ans[N];
Node t[N*40];
vector<int> G[N];
vector<Query> q[N];
int _cnt;
void update(int &u,int l,int r,int p)
{
if(!u) u=++_cnt;
if(l==r) {t[u].v++; return ;}
int mid=(l+r)/2;
if(p<=mid) update(t[u].ls,l,mid,p);
else update(t[u].rs,mid+1,r,p);
t[u].v=t[t[u].ls].v+t[t[u].rs].v;
}
int query(int u,int l,int r,int L,int R)
{
if(!u) return 0;
if(L<=l&&r<=R) return t[u].v;
int mid=(l+r)/2,ans=0;
if(L<=mid) ans+=query(t[u].ls,l,mid,L,R);
if(R>mid) ans+=query(t[u].rs,mid+1,r,L,R);
return ans;
}
int merge(int u,int v)
{
if(!u&&!v) return 0;
if(!v) return u;
if(!u)
{
int x=++_cnt;
t[x]=t[v];
return x;
}
int x=++_cnt;
t[x].v=t[u].v+t[v].v;
t[x].ls=merge(t[u].ls,t[v].ls);
t[x].rs=merge(t[u].rs,t[v].rs);
return x;
}
void dfs(int u)
{
for(int v:G[u])
{
if(len[u]==len[v]) f[v]=f[u];
else f[v]=v;
dfs(v); rt[u]=merge(rt[u],rt[v]);
}
}
signed main()
{
cin>>n>>Q;
for(int i=1;i<=n;i++)
{
scanf("%s",str); int len=strlen(str);
int las=1;
for(int j=0;j<len;j++) las=insert(str[j]-'a',las),update(rt[las],1,n,i);;
pos[i]=las;
}
for(int i=2;i<=cnt;i++) G[fa[i]].pb(i);
f[1]=1;
dfs(1);
while(Q--)
{
int u=read(),v=read(),w=read();
printf("%d
",query(rt[f[pos[w]]],1,n,u,v));
}
return 0;
}