题目描述
小张最近在忙毕设,所以一直在读论文。一篇论文是由许多单词组成但小张发现一个单词会在论文中出现很多次,他想知道每个单词分别在论文中出现了多少次。
输入输出格式
输入格式:第一行一个整数N,表示有N个单词。接下来N行每行一个单词,每个单词都由小写字母(a-z)组成。(N≤200)
输出格式:输出N个整数,第i行的数表示第i个单词在文章中出现了多少次。
输入输出样例
3
a
aa
aaa
6
3
1
说明
数据范围
30%的数据, 单词总长度不超过10^3
100%的数据,单词总长度不超过10^6
Solution:
本题其实不难,数据也很水,可以各种操作将其A穿。
最开始是刷AC自动机的题目进来的,所以开始就往AC自动机上面想,初想法是直接将串构建trie树,对于每个节点,建树时用num[i]统计每个节点有多少个字符串经过,记录一下每个字符串的尾节点,然后在bfs求失配边时,考虑到fail的性质:fail[u]是u的最长后缀所在节点(即以fail[u]结尾的字符串一定在u结尾的字符串中出现),于是u对fail[u]的贡献为num[u],那么在广搜时直接对当前节点暴力向上扫fail数组并更新num值,结果发现这样就A了($360ms$)。
后面意识到一个严重的问题,当字典树为一条链形如:aaaaaaa…这种情况时,每个节点的fail指向的就是其上一个节点,那么我开始的更新num的方法就卡成了$O(n^2)$的了。
于是,想到fail实际上是一个树形结构,我们可以用一个栈,在广搜时每次压入当前节点,最后弹栈时因为栈顶节点深度一定大于等于栈内节点深度,所以可以类树形dp一样从下往上更新,每次只要更新一次,更新的时间复杂度就是$O(n)$的了,然后AC($128ms$)。
然后发现题解内的方法五花八门,有人用hashA了,于是又去水hash,卡空间不慌我们直接将所有串连起来,中间加个$#$隔开就好了,但是发现本题居然卡单模数$998244353$(20分)、$19260817$(0分)、$1004535809$(20分),而不想搞多模数,果断unsigned long long自然溢出就好了,最后也A了($656ms$)。
代码:
AC自动机版:
#include<bits/stdc++.h> #define il inline #define ll long long #define For(i,a,b) for(int (i)=(a);(i)<=(b);(i)++) #define Bor(i,a,b) for(int (i)=(b);(i)>=(a);(i)--) using namespace std; const int N=1000005; int n,trie[N][26],end[205],cnt,fail[N],num[N]; char s[N]; int Q[N],tot; il void insert(char *s,int id){ int len=strlen(s)-1,p=0,x; For(i,0,len){ x=s[i]-'a'; if(!trie[p][x]) trie[p][x]=++cnt; p=trie[p][x]; num[p]++; } end[id]=p; } il void bfs(){ queue<int>q; For(i,0,25) if(trie[0][i]) fail[trie[0][i]]=0,q.push(trie[0][i]); while(!q.empty()){ int u=q.front();q.pop(); Q[++tot]=u; For(i,0,25){ int v=trie[u][i]; if(v) fail[v]=trie[fail[u]][i],q.push(v); else trie[u][i]=trie[fail[u]][i]; } } Bor(i,1,tot) num[fail[Q[i]]]+=num[Q[i]]; } int main(){ scanf("%d",&n); For(i,1,n) scanf("%s",s),insert(s,i); bfs(); For(i,1,n) printf("%d ",num[end[i]]); return 0; }
hash版:
#include<bits/stdc++.h> #define il inline #define ll long long #define For(i,a,b) for(int (i)=(a);(i)<=(b);(i)++) #define Bor(i,a,b) for(int (i)=(b);(i)>=(a);(i)--) using namespace std; const int P=131,mod=998244353; unsigned ll Hash[1000205],hash[205],sum[1000205]; int n,cnt,len[205],ans[205]; char s[1000005]; int main(){ scanf("%d",&n);sum[0]=1; For(i,1,n){ scanf("%s",s+1); len[i]=strlen(s+1); For(j,1,len[i]) sum[++cnt]=sum[cnt-1]*P, Hash[cnt]=Hash[cnt-1]*P+s[j], hash[i]=hash[i]*P+s[j]; sum[++cnt]=sum[cnt-1]*P, Hash[cnt]=Hash[cnt-1]*P+'#'; } For(i,1,n) For(j,1,cnt-len[i]) if(Hash[j+len[i]-1]-Hash[j-1]*sum[len[i]]==hash[i]) ans[i]++; For(i,1,n) printf("%d ",ans[i]); return 0; }