题意
给出(n)个字符串(s_1,s_2,dots,s_n),定义(f(s,t))为字符串(s)和字符串(t)的最长公共前后缀(字符串(s)的前缀,字符串(t)的后缀)。
让你计算
[sum_{i=1}^{n} sum_{j=1}^{n} f(s_i,s_j)^2~(mod~998244353)
]
分析
将所有字符串(s_i)插入AC自动机中,构建(fail)树,因为(fail)指针的含义是(fail[v])结点所代表的前缀串是(v)结点代表的前缀串能匹配到的最长后缀。所以字符串(s_i)的前缀能匹配到的后缀都在其所在的(fail)链上,根据(fail)指针反向建图,从根节点开始(dfs),(dfs)过程中对于每个(s_i)更新它的最长前缀,并统计答案,回溯时再撤销更新。
Code
#include<algorithm>
#include<iostream>
#include<cstring>
#include<iomanip>
#include<sstream>
#include<cstdio>
#include<string>
#include<vector>
#include<bitset>
#include<queue>
#include<cmath>
#include<stack>
#include<set>
#include<map>
#define rep(i,x,n) for(int i=x;i<=n;i++)
#define per(i,n,x) for(int i=n;i>=x;i--)
#define sz(a) int(a.size())
#define rson mid+1,r,p<<1|1
#define pii pair<int,int>
#define lson l,mid,p<<1
#define ll long long
#define pb push_back
#define mp make_pair
#define se second
#define fi first
using namespace std;
const double eps=1e-8;
const int mod=998244353;
const int N=1e6+10;
const int inf=1e9;
int n;
char s[N];
struct ACtree{
int son[N][26],fail[N],len[N],end[N],c[N],tot;
vector<int>q[N],g[N];
stack<pii>st;
ll res,ans;
int newnode(){
for(int i=0;i<26;i++) son[tot][i]=0;
end[tot++]=0;
return tot-1;
}
void init(){
ans=res=tot=0;
newnode();
}
void ins(char s[],int x){
int rt=0,m=strlen(s);
for(int i=0;i<m;i++){
if(!son[rt][s[i]-'a']) son[rt][s[i]-'a']=newnode();
rt=son[rt][s[i]-'a'];
len[rt]=i+1;
q[rt].pb(x);
}
end[rt]++;
}
void gao(){
queue<int>q;
for(int i=0;i<26;i++) if(son[0][i]) fail[son[0][i]]=0,q.push(son[0][i]);
while(!q.empty()){
int u=q.front();q.pop();
for(int i=0;i<26;i++){
if(son[u][i]){
fail[son[u][i]]=son[fail[u]][i];
q.push(son[u][i]);
}else son[u][i]=son[fail[u]][i];
}
}
for(int i=1;i<tot;i++) g[fail[i]].pb(i);
}
void dfs(int u){
for(int x:q[u]){
st.push(mp(x,c[x]));
res=(res-c[x]+mod)%mod;
c[x]=1ll*len[u]*len[u]%mod;
res=(res+c[x])%mod;
}
ans=(ans+1ll*res*end[u]%mod)%mod;
for(int x:g[u]){
dfs(x);
}
for(int x:q[u]){
pii tmp=st.top();
st.pop();
res=(res-c[tmp.fi]+mod)%mod;
c[tmp.fi]=tmp.se;
res=(res+c[tmp.fi])%mod;
}
}
ll qy(){
dfs(0);
return ans;
}
}AC;
int main(){
//ios::sync_with_stdio(false);
//freopen("in","r",stdin);
AC.init();
scanf("%d",&n);
rep(i,1,n){
scanf("%s",s);
AC.ins(s,i);
}
AC.gao();
printf("%lld
",AC.qy());
return 0;
}