鸣谢https://blog.csdn.net/Baoli1008/article/details/4441936,基本都是抄的代码
#pragma GCC optimize(2) #include <bits/stdc++.h> #define ll long long using namespace std; const int MAXN = 4e6+10; int idx(char c) { return c-'0'; } struct node{ ll num; char ch; node(){ num=0; }; }; vector<int>G[MAXN]; node N[MAXN]; int cn; void init() { for(int i=0;i<cn;i++) { G[i].clear(); N[i].num=0; } cn=1; } void input(char s[]) { int cur=0; int len=strlen(s); N[cur].num++; for(int i=0;i<=len;i++) { int nxt=0; int m=G[cur].size(); for(int j=0;j<m;j++) { int tmp=G[cur][j]; if(N[tmp].ch==s[i]) { nxt=tmp; break; } } if(!nxt) { N[cn].ch=s[i]; G[cur].push_back(cn); nxt=cn; cn++; } N[nxt].num++; cur=nxt; } } ll solve(int n) { ll res=0; ll tmp=0; if(n) res+=N[n].num*(N[n].num-1); int m=G[n].size(); for(int i=0;i<m;i++) { int t=G[n][i]; res+=solve(t); tmp+=(N[n].num-N[t].num)*N[t].num; } res+=tmp/2; return res; } int main() { int T; int kase=0; while(cin>>T&&T) { init(); for(int i=0;i<T;i++) { char s[1005]; scanf("%s",s); input(s); } printf("Case %d: %lld ",++kase,solve(0)); } return 0; }