题目链接:HDU-5129
题目大意为给一堆字符串,问由任意两个字符串的前缀子串(注意断句)能组成多少种不同的字符串。
思路是先用总方案数减去重复的方案数。
考虑对于一个字符串S,如图,假设S1,S2,S3,S4,S5,S6均为前缀。
换言之,对于这种字符串,我们计算了三次。
发现,重复的方案数,等于中间如图有颜色的方块的数量。所以我们要做的也就是计数像图中有颜色的小方块的数量。
我们可以通过遍历像S6一样的字符串的数量,来计算重复的方案数。S6满足以下条件:
- 存在一个前缀S4为S6的后缀,且S4为S6的最长的“是前缀的”后缀(保证是小方块间没有隔板)。
- S6-S4为某个前缀的后缀(在图中为S3的后缀)。
我们发现,对于某一个S6,其对应的S4时一定的。
我们用sum[S]表示以S为后缀的前缀字符串的数量。
这样,当我们遍历S6,对于每一个S6,我们找到其对应的S4后,只要减去( sum[S6-S4] - 1 )即可。
所以我们要处理的,就是以下两个问题:
- 如何找到S6对应的S4
- 如何求sum[S]
这时候,我们发现S6与S4的关系,和AC自动机的性质很像。S6在失配时跳到的位置就是S4。于是第一个问题解决了。
对于第二个问题,我们同样可以通过AC自动机找到。具体的参见代码。
1 #include<cstring> 2 #include<cstdio> 3 #include<queue> 4 using namespace std; 5 6 typedef long long LL; 7 const LL MAXN=300010; 8 const LL SIGMA_SIZE=26; 9 struct Trie 10 { 11 LL ch[MAXN][SIGMA_SIZE]; 12 LL fa[MAXN]; 13 LL sz; //节点总数 14 Trie() { sz=1; fa[0]=-1; memset(ch[0],0,sizeof(ch[0])); } 15 LL idx(char c) { return c-'a'; } //节点c的编号 16 void clear() { sz=1; fa[0]=-1; memset(ch[0],0,sizeof(ch[0])); } 17 18 //在Trie中插入字符串s 19 void insert(char *s) 20 { 21 LL u=0,n=strlen(s); 22 for(LL i=0;i<n;i++) 23 { 24 LL c=idx(s[i]); 25 if(!ch[u][c]) //节点不存在 26 { 27 memset(ch[sz],0,sizeof(ch[sz])); 28 fa[sz]=u; 29 ch[u][c]=sz++; //新建节点 30 } 31 u=ch[u][c]; //往下走 32 } 33 } 34 35 //AC自动机部分 36 LL f[MAXN]; 37 LL deg[MAXN]; 38 LL sum[MAXN]; //sum[i]表示以Si为后缀的前缀的数量 39 void getFail() 40 { 41 queue<LL> q; 42 f[0]=0; 43 //初始化队列 44 for(LL c=0;c<SIGMA_SIZE;c++) 45 { 46 LL u=ch[0][c]; 47 if(u) { f[u]=0; q.push(u); } 48 } 49 //按BFS顺序计算失配函数 50 while(!q.empty()) 51 { 52 LL r=q.front(); q.pop(); 53 for(LL c=0;c<SIGMA_SIZE;c++) 54 { 55 LL u=ch[r][c]; 56 if(!u) { ch[r][c]=ch[f[r]][c]; continue; } 57 q.push(u); 58 LL v=f[r]; 59 f[u]=ch[v][c]; 60 } 61 } 62 for(LL i=1;i<sz;i++) { deg[i]=0; sum[i]=1; } 63 for(LL i=1;i<sz;i++) deg[f[i]]++; 64 queue<LL> Q; 65 for(LL i=1;i<sz;i++) if(!deg[i]) Q.push(i); 66 while(!Q.empty()) 67 { 68 LL u=Q.front(); Q.pop(); 69 sum[f[u]]+=sum[u]; 70 deg[f[u]]--; 71 if(!deg[f[u]]) Q.push(f[u]); 72 } 73 } 74 75 void solve() 76 { 77 LL tot=0; 78 for(LL i=1;i<sz;i++) if(f[i]) 79 { 80 LL j=f[i]; 81 LL p=i; 82 while(j) 83 { 84 p=fa[p]; 85 j=fa[j]; 86 } 87 tot+=sum[p]-1; 88 } 89 printf("%lld ",1LL*(sz-1)*(sz-1)-tot); 90 } 91 }; 92 Trie T; 93 int main() 94 { 95 #ifdef LOCAL 96 freopen("in.txt","r",stdin); 97 #endif 98 LL n; 99 while(scanf("%lld",&n) && n) 100 { 101 scanf("%lld",&n); 102 for(LL i=1;i<=n;i++) 103 { 104 char s[50]; 105 scanf("%s",s); 106 T.insert(s); 107 } 108 T.getFail(); 109 T.solve(); 110 T.clear(); 111 } 112 return 0; 113 }