题解
对询问串建立 $ ext{AC}$ 自动机,考虑建出 $ ext{fail}$ 树, $ ext{fail}$ 树上节点所代表的串是这个节点子树内每个点所代表的的串的后缀。所以我们可以把链分成两条,把正反串都放入 $ ext{AC}$ 自动机中,对于一条链 $(lca,u)$ ,对于不包含 $lca$ 的子串,我们可以用根到 $u$ 的答案减去根到包含 $lca$ 的子串的最上方的点的答案,那我们就可以记录一下询问串的结束节点, $ ext{dfs}$ 原树的时候也一起走 $ ext{AC}$ 自动机,进入的时候 $+1$ ,回溯的时候 $-1$ ,用树状数组维护区间和即可。然后如果子串包含了 $lca$ 的话发现这条路径上有效的点是 $2|s|$ 的,于是拉出来做 $ ext{kmp}$ 即可。
代码
#include <bits/stdc++.h> using namespace std;const int N=3e5+5; int n,m,id[N],sz[N],dp[N],fa[19][N],hd[N],V[N]; int nx[N],tt,ne[N],a[N],su[N],tr[N][26],fi[N]; char t[N],s[N],up[N],W[N];vector<int>e[N]; struct O{int i,u,v;};vector<O>p[N];queue<int>Q; void add(int u,int v,char c){ nx[++tt]=hd[u];V[hd[u]=tt]=v;W[tt]=c; } void dfs(int u,int fr){ dp[u]=dp[fa[0][u]=fr]+1; for (int i=1;fa[i-1][fa[i-1][u]];i++) fa[i][u]=fa[i-1][fa[i-1][u]]; for (int v,i=hd[u];i;i=nx[i]) if ((v=V[i])!=fr) up[v]=W[i],dfs(v,u); } int kmp(int n,int m){ if (n<m) return 0; ne[0]=ne[1]=0; for (int j,i=1;i<m;i++){ j=ne[i]; while(j && t[j]!=t[i]) j=ne[j]; if (t[j]==t[i]) ne[i+1]=j+1; else ne[i+1]=0; } int j=0,v=0; for (int i=0;i<n;i++){ while(j && s[i]!=t[j]) j=ne[j]; if (s[i]==t[j]) j++; if (j==m) v++; } return v; } int ins(int m){ int v=0; for (int i=0,j;i<m;i++){ j=t[i]-97; if (!tr[v][j]) tr[v][j]=++tt; v=tr[v][j]; } return v; } void build(){ for (int i=0;i<26;i++) if (tr[0][i]) Q.push(tr[0][i]); while(!Q.empty()){ int u=Q.front();Q.pop(); for (int v,i=0;i<26;i++){ v=tr[u][i]; if (v) fi[v]=tr[fi[u]][i],Q.push(v); else tr[u][i]=tr[fi[u]][i]; } } for (int i=1;i<=tt;i++) e[fi[i]].push_back(i); } void dfs(int u){ id[u]=++tt;sz[u]=1; int z=e[u].size(); for (int v,i=0;i<z;i++) v=e[u][i],dfs(v),sz[u]+=sz[v]; } int lca(int u,int v){ if (dp[u]<dp[v]) swap(u,v); for (int i=17;~i;i--) if (dp[fa[i][u]]>=dp[v]) u=fa[i][u]; if (u==v) return u; for (int i=17;~i;i--) if (fa[i][u]!=fa[i][v]) u=fa[i][u],v=fa[i][v]; return fa[0][u]; } void upd(int x,int v){ x=id[x]; for (;x<=tt;x+=x&-x) su[x]+=v; } int qry(int x){ int l=id[x]-1,r=id[x]+sz[x]-1,v=0; for (;r;r-=r&-r) v+=su[r]; for (;l;l-=l&-l) v-=su[l]; return v; } void dfs(int u,int fr,int k){ upd(k,1);int z=p[u].size(); for (int i=0;i<z;i++) a[p[u][i].i]+=p[u][i].v*qry(p[u][i].u); for (int v,i=hd[u];i;i=nx[i]) if ((v=V[i])!=fr) dfs(v,u,tr[k][W[i]-97]); upd(k,-1); } int Up(int u,int x){ if (x<0) return u; for (int i=17;~i;i--) if (x&(1<<i)) u=fa[i][u]; return u; } int main(){ cin>>n>>m; for (int u,v,i=1;i<n;i++) scanf("%d%d%s",&u,&v,t), add(u,v,t[0]),add(v,u,t[0]); dfs(1,0);tt=0; for (int i=1,u,v,len,p1,p2,w,u1,u2,z;i<=m;i++){ scanf("%d%d%s",&u,&v,t);z=lca(u,v); len=strlen(t);p1=ins(len); reverse(t,t+len);p2=ins(len);w=0; u1=u2=Up(u,dp[u]-dp[z]-len+1); for (int j=1;j<=dp[u1]-dp[z];j++) s[w++]=up[u2],u2=fa[0][u2]; if (dp[u]-dp[z]>=len) p[u1].push_back((O){i,p2,-1}), p[u].push_back((O){i,p2,1}); u1=u2=Up(v,dp[v]-dp[z]-len+1); w+=dp[u1]-dp[z]; for (int j=1;j<=dp[u1]-dp[z];j++) s[w-j]=up[u2],u2=fa[0][u2]; if (dp[v]-dp[z]>=len) p[u1].push_back((O){i,p1,-1}), p[v].push_back((O){i,p1,1}); reverse(t,t+len);a[i]+=kmp(w,len); } build();tt=0;dfs(0);dfs(1,0,0); for (int i=1;i<=m;i++) printf("%d ",a[i]); return 0; }