题意:首先输入三个整数:n,m,p;n代表总共有n个字母,m代表字符串的长度为m,p代表病毒字符串的个数;题目让你求的是不包含病毒的字符串长度为m的个数为多少。
思路:这个题让我意识到我自己是有多么的水啊,中南一大牛花了一天的时间自己想出来并且解决它,而我花费了一个星期的时间把别人的代码看懂了,啊!!!分析下其中的原因:最重要的原因是自己知识的漏洞太多了,开始的时候对于Trie图是一无所知,在请教了中南大牛之后,以为自己弄懂了,可是在DP的时候,我又在想为什么能够这样DP呢???归根结底还是对Trie图没理解好。首先我讲下现在的理解吧:1.目的:把各个长度的字符串都能够在树上反映出来,但是要把那些含有病毒的字符串去掉,现在最重要的就是如何去掉这些含有病毒的字符串了,那就要看Trie图的强大之处了;2.如何建立Trie图(如图二所示)呢?首先建立好字典树(如图一所示),(首先声明我说的是没有优化的Trie图)从根节点开始,每个白色的节点(安全节点)都需要n条边,就拿图二中的第二层b节点为例吧!它有指向a的边,也有指向b的边,可是没有指向c的边,怎么办呢??自己建啊!如何建呢?如下:假设b节点后有c节点,那么在AC自动机中,c节点的fail指针应该指向和b节点相邻的c节点,说明这个c(和b相邻的)节点与那个c(b节点的孩子)是等价的,他们的安全性是一样的,那么我们就把b节点直接指向与它相邻的c节点建立一条边(这个东西在后面的DP中很重要,开始的时候我是一直没想明白的,经过两三天的思考才弄懂的),其它的白色节点也是这样建立的,那么我们的Trie图就这样愉快的建好了,其实说到底就是在AC自动机上的那个图上多建立了一些边;3.如何DP呢??首先要把所有长度为(1--m)的字符串根据这个Trie图分类,怎么分类??假设有Trie图中有t个节点,那么我们可以根据这t个节点发出一条边到达不同的字符串分类,就是把他们分成t类,最后只要统计这t个类中长度为m的字符串即可!!状态转移方程为:dp[i][j]=dp[i-1][k]+dp[i][j],其中的k代表在第i步有一条指向j节点的边,其实这个j节点分为两类:一类是k节点有后继节点j;另一类是k节点其实没有后继节点,而是通过找等价节点找到的j节点,这一类是最难理解的,我是用了很长时间才把他理解透彻的啊!!!最后要注意的是:这个题目最后的结果是大数,而题目中也没有说明要取模,所以要用精度计算。废话我也不说了,看代码实现吧!!做完这个题之后建议做:poj 2778(听说要用矩阵乘法,我表示我不会)、hdu 2243
#include<iostream> #include<queue> #include<string.h> #include<algorithm> #include<stdlib.h> using namespace std; struct node{ int flag; int fail; int next[50]; void init() { flag=0; fail=0; memset(next,0,sizeof(next)); } }s[110]; int n,m,t,tot; int dp[51][110][30]; char str[51]; void ca()//初始化 { tot=0;//统计节点个数 s[0].init(); } int cmp(char a,char b) { if(a>b) return 0; else return 1; } int hash(char temp)//找到节点在字符集中的位置 { int f=0,r=n,mid; while(f<=r) { mid=(f+r)/2; if(temp>str[mid]) f=mid+1; else if(temp==str[mid]) return mid; else r=mid-1; } } void insert(char *str)//建立字典树 { int p=0,index; for(;*str!='\0';str++) { index=hash(*str); if(s[p].next[index]==0) { s[++tot].init(); s[p].next[index]=tot; } p=s[p].next[index]; } s[p].flag=1; } void AC_tree()//建立Trie图 { int p,cur,son,i; queue<int>Q; s[0].fail=0; Q.push(0); while(!Q.empty()) { p=Q.front(); Q.pop(); for(i=0;i<n;i++) { son=s[p].next[i]; if(son!=0) { if(p==0) s[son].fail=0; else { cur=s[p].fail; while(cur!=0&&s[cur].next[i]==0) cur=s[cur].fail; s[son].fail=s[cur].next[i]; } if(s[s[son].fail].flag==1) s[son].flag=1; Q.push(son); } else s[p].next[i]=s[s[p].fail].next[i]; } } } void add(int *a,int *b)//大数相加 { int i,c=0; for(i=0;i<30;i++) { a[i]=a[i]+b[i]+c; c=a[i]/10000; a[i]=a[i]%10000; } } void solve() { int i,j,k,p; memset(dp,0,sizeof(dp)); dp[0][0][0]=1; for(i=1;i<=m;i++) { for(j=0;j<=tot;j++) { if(s[j].flag)//去掉危险节点 continue; for(k=0;k<n;k++) { p=s[j].next[k]; if(s[p].flag)//去掉危险节点 continue; add(dp[i][p],dp[i-1][j]); } } } int ans[30]; memset(ans,0,sizeof(ans)); for(i=0;i<=tot;i++) if(s[i].flag==0) add(ans,dp[m][i]); for(i=29;i>=0;i--) if(ans[i]!=0) break; if(i<0) printf("0"); else { printf("%d",ans[i]); i--; for(;i>=0;i--) printf("%04d",ans[i]); } printf("\n"); } int main() { char haha[51]; while(scanf("%d%d%d",&n,&m,&t)!=EOF) { getchar(); ca(); gets(str); sort(str,str+strlen(str),cmp); while(t--) { gets(haha); insert(haha); } AC_tree(); solve(); } return 0; }