题解(by zjvarphi)
对母串中的每个分界位置,统计有多少个匹配串恰好在此处结束、有多少个匹配串恰好从此处开始;拼接串 \(s_i+s_j\) 的每一次出现,都恰好在 \(s_i\) 与 \(s_j\) 的分界处被统计一次。
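记母串为 \(t\),长度为 \(len\)。设 \(f_i\) 表示从位置 \(i+1\) 开始能匹配上的匹配串个数(即是 \(t[i+1..len]\) 前缀的匹配串个数),\(g_i\) 表示以位置 \(i\) 结尾能匹配上的匹配串个数(即是 \(t[1..i]\) 后缀的匹配串个数),那么答案为 \(\sum_{i=1}^{len-1} f_i \times g_i\)。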
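如何求出 \(f_i\) 和 \(g_i\)?用字符串哈希二分:求出 \(t[i+1..len]\) 能与某个匹配串的前缀匹配的最长长度,以及 \(t[1..i]\) 能与某个匹配串的后缀匹配的最长长度。“是某个匹配串的前缀(后缀)”这一性质对长度单调(更短的前缀或后缀同样满足),所以可以二分。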
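预处理时,把所有匹配串插入一棵 \(trie\) 树,把它们的反串插入另一棵 \(trie\) 树。只在每个串的结尾节点计数,再从根向下累加,使每个节点的值等于根到该节点路径上结尾的匹配串个数,即把每个前缀(后缀)记录上它的子前缀(子后缀)。最后把每个非根节点所代表串的哈希值映射到这个值,存入 unordered_map(记得用 unordered_map;值为 0 的节点也要存,否则二分的单调性会被破坏)。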
Code:
#include<bits/stdc++.h>
// Shorthand type for loop counters and small locals. It used to expand to
// `register signed`, but the `register` storage class specifier is ill-formed
// since C++17, so it now expands to plain `signed` (i.e. int).
#define ri signed
// Pre-increment shorthand: p(x) expands to ++x. Being a function-like macro,
// it leaves an identifier `p` that is not followed by '(' untouched.
#define p(i) ++i
using namespace std;
namespace IO{
// buf/p1/p2 back the buffered gc() macro below; OPUT is print()'s digit
// scratch buffer. read() goes through getchar(), not gc(), so it can be mixed
// with scanf()/getchar() on the same stream.
char buf[1<<21],*p1=buf,*p2=buf,OPUT[100];
// Buffered character reader: refills buf from stdin when it is exhausted and
// yields -1 at end of input. fread returns the number of complete *elements*
// read, so the element size must be 1 for the result to be a byte count.
#define gc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?(-1):*p1++)

// Reads one integer from stdin (via getchar) into x. Characters before the
// first digit are skipped, and a '-' among them makes the result negative.
// If input ends before any digit is seen, x is set to 0 (no infinite loop).
template<typename T>inline void read(T &x) {
    bool neg=false;
    int ch=std::getchar(); // int, not char, so EOF stays distinguishable
    while (ch!=EOF&&!std::isdigit(ch)) {
        if (ch=='-') neg=true;
        ch=std::getchar();
    }
    x=0;
    // Accumulate with the final sign so reading the minimum value of a signed
    // type does not overflow on the way.
    while (ch!=EOF&&std::isdigit(ch)) {
        const int d=ch-'0';
        x=neg?x*10-d:x*10+d;
        ch=std::getchar();
    }
}

// Writes x in decimal followed by a newline to stdout.
template<typename T>inline void print(T x) {
    int cnt=0;
    if (x<0) {
        std::putchar('-');
        // Peel digits off the negative value itself: negating first would
        // overflow for the minimum value of a signed type.
        do {OPUT[++cnt]=static_cast<char>('0'-x%10);x/=10;} while (x!=0);
    } else {
        do {OPUT[++cnt]=static_cast<char>('0'+x%10);x/=10;} while (x!=0);
    }
    while (cnt>0) std::putchar(OPUT[cnt--]);
    std::putchar('\n');
}
}
using IO::read;using IO::print;
namespace nanfeng{
#define FI FILE *IN
#define FO FILE *OUT
template<typename T>inline T cmax(T x,T y) {return x>y?x:y;}
template<typename T>inline T cmin(T x,T y) {return x>y?y:x;}
typedef unsigned long long ull;
typedef long long ll;
// Base of every polynomial string hash below. Hashes wrap modulo 2^64
// (unsigned overflow), so unrelated strings can in principle collide.
static const int P=131;

// Hash digit of a character: 'a' -> 1 ... 'z' -> 26. All strings are assumed
// to consist of lowercase Latin letters only.
inline ull char_code(char c) {return static_cast<ull>(c-'a'+1);}

// Trie over the pattern strings. Every node stores the forward polynomial hash
// of the string it spells, so the node set can be exported as a
// hash -> count table. With `rev` set, patterns are inserted last character
// first: each node then spells a suffix of a pattern.
struct Trie{
    std::vector<std::array<int,26>> son; // child links; 0 = absent (node 0 is the root)
    std::vector<int> par;                // parent of each node (always par[v] < v)
    std::vector<int> endCnt;             // patterns ending exactly at the node
    std::vector<ull> key;                // forward hash of the string the node spells
    std::vector<ull> pw;                 // P^(depth of node); used when rev is set
    bool rev;

    explicit Trie(bool reversed):son(1),par(1,0),endCnt(1,0),key(1,0),pw(1,1),rev(reversed) {}

    // Inserts one pattern (assumed to be lowercase letters only).
    void insert(const std::string& str) {
        const int n=static_cast<int>(str.size());
        int cur=0;
        for (int k=0;k<n;++k) {
            const char ch=str[rev?n-1-k:k];
            const int c=ch-'a';
            int nxt=son[cur][c];
            if (!nxt) {
                nxt=static_cast<int>(son.size());
                const ull code=char_code(ch);
                // The forward trie appends the character on the right; the
                // reversed trie prepends it to the suffix spelled so far.
                const ull h=rev?key[cur]+code*pw[cur]:key[cur]*P+code;
                const ull step=pw[cur]*P;
                son.push_back(std::array<int,26>{});
                par.push_back(cur);
                endCnt.push_back(0);
                key.push_back(h);
                pw.push_back(step);
                son[cur][c]=nxt;
            }
            cur=nxt;
        }
        // Count the pattern only at its last node, not on every node of the
        // path, so endCnt means "patterns equal to this node's string".
        ++endCnt[cur];
    }

    // Maps the hash of every non-root node's string to the number of patterns
    // ending on the root-to-node path. For the forward trie that is the number
    // of patterns that are a prefix of the string; for the reversed trie, the
    // number that are a suffix of it. Nodes whose count is 0 are exported too:
    // "hash present" must mean exactly "string is in the trie" for the binary
    // search in longest_match to be valid.
    std::unordered_map<ull,int> export_counts() const {
        std::vector<int> acc(endCnt.size(),0);
        std::unordered_map<ull,int> table;
        table.reserve(son.size());
        // Nodes are created after their parent, so one pass in index order
        // always sees a parent's final total before its children.
        for (std::size_t v=1;v<son.size();++v) {
            acc[v]=acc[par[v]]+endCnt[v];
            table[key[v]]=acc[v];
        }
        return table;
    }
};

// Binary-searches the largest len in [1, lim] for which keyOf(len) is present
// in table and returns the count stored under that key, or 0 if none is.
// Valid because presence is monotone in len: a shorter prefix (forward trie)
// or suffix (reversed trie) of a trie string is itself a trie string.
template<typename KeyOf>
inline int longest_match(const std::unordered_map<ull,int>& table,int lim,KeyOf keyOf) {
    int lo=1,hi=lim,found=0;
    while (lo<=hi) {
        const int mid=lo+(hi-lo)/2;
        const auto it=table.find(keyOf(mid));
        if (it==table.end()) {
            hi=mid-1;
        } else {
            found=it->second; // hits only move lo upward: the last hit is the longest
            lo=mid+1;
        }
    }
    return found;
}

// Returns the sum, over all ordered pairs (i, j) of patterns (i == j allowed),
// of the number of occurrences of patterns[i] + patterns[j] in text. Each
// occurrence is counted at the cut between its two parts, so the answer is
// the sum over cuts c of (#patterns ending right before c) * (#patterns
// starting at c). All strings are assumed to be lowercase Latin letters.
inline ll count_concat_occurrences(const std::string& text,const std::vector<std::string>& patterns) {
    Trie fwd(false),bwd(true);
    for (const std::string& pat:patterns) {
        fwd.insert(pat);
        bwd.insert(pat);
    }
    const std::unordered_map<ull,int> startTable=fwd.export_counts();
    const std::unordered_map<ull,int> endTable=bwd.export_counts();

    const int m=static_cast<int>(text.size());
    std::vector<ull> h(m+1,0),pw(m+1,1);
    for (int i=0;i<m;++i) {
        h[i+1]=h[i]*P+char_code(text[i]);
        pw[i+1]=pw[i]*P;
    }
    // Forward hash of text[l, r).
    auto sub=[&](int l,int r) {return h[r]-h[l]*pw[r-l];};

    ll ans=0;
    for (int cut=1;cut<m;++cut) {
        // Patterns that are suffixes of text[0, cut).
        const int endHere=longest_match(endTable,cut,[&](int len) {return sub(cut-len,cut);});
        if (!endHere) continue;
        // Patterns that are prefixes of text[cut, m).
        const int startHere=longest_match(startTable,m-cut,[&](int len) {return sub(cut,cut+len);});
        ans+=static_cast<ll>(endHere)*startHere;
    }
    return ans;
}

// Reads the next whitespace-delimited token from stdin (no length limit);
// returns an empty string at end of input.
inline std::string read_token() {
    std::string tok;
    int c=std::getchar();
    while (c!=EOF&&std::isspace(c)) c=std::getchar();
    while (c!=EOF&&!std::isspace(c)) {
        tok+=static_cast<char>(c);
        c=std::getchar();
    }
    return tok;
}

// Input (whitespace separated): the text, the pattern count n, then n patterns.
// Output: count_concat_occurrences(text, patterns) followed by a newline.
inline int main() {
    //FI=freopen("nanfeng.in","r",stdin);
    //FO=freopen("nanfeng.out","w",stdout);
    const std::string text=read_token();
    int n=0;
    if (std::scanf("%d",&n)!=1||n<0) n=0;
    std::vector<std::string> patterns;
    for (int i=0;i<n;++i) patterns.push_back(read_token());
    std::printf("%lld\n",count_concat_occurrences(text,patterns));
    return 0;
}
}
// Program entry point: all work is done by nanfeng::main, whose exit status
// is passed straight through.
int main() {
    const int status=nanfeng::main();
    return status;
}