题目大意:给出m个疾病基因片段(m<=10),每个片段不超过10个字符。求长度为n的不包含任何一个疾病基因片段的DNA序列共有多少种?(n<=2000000000)
分析:本题需要对m个疾病基因片段构建一个AC自动机,自动机中的每个节点表示一个状态。其中AC自动机中的叶子节点表示的是病毒,所以是非法状态。同时,如果某个节点到根的字符串的后缀是一个病毒,那么该节点也是非法状态。剔除掉所有的非法状态,那么剩下的节点都表示合法状态了。然后用节点的nxt指针表示状态之间转化关系。若nxt[i]==0,则nxt[i]指针指向当前节点fail指针的nxt[i],如果仍然为0,则nxt[i]指向根节点。这样处理以后,每个指针都不会指向0。这样,该自动机可以看做是一个合法状态的转换图,节点表示各种合法状态,边表示添加一个字符将转换为另一个状态。于是我们可以得到一个矩阵。该矩阵实际上表示该状态图的邻接矩阵。对该矩阵自乘n次。最后结果矩阵的第1行各元素之和表示从空状态添加n个字符能够得到的所有合法状态的数量。
矩阵的思想非常巧妙。
#include<iostream> #include<cstdio> #include<cstring> using namespace std; #define MAXN 102 #define MAXL 12 #define MAXC 4 #define MOD 100000 struct node { int fail,nxt[6],flag; }trie[MAXN]; int head,tail,myq[MAXN],root=1,tot=1; char word[MAXL]; int degree; int a[MAXN][MAXN],b[MAXN][MAXN],c[MAXN][MAXN],(*ans)[MAXN]; void multi(int (*a)[MAXN],int (*b)[MAXN],int (*c)[MAXN]) { for(int i=1;i<=degree;i++) { for(int j=1;j<=degree;j++) c[i][j]=0; } for(int i=1;i<=degree;i++) { for(int j=1;j<=degree;j++) { for(int k=1;k<=degree;k++) { c[i][j]+=(long long)a[i][k]*b[k][j]%MOD; c[i][j]%=MOD; } } } } void power(int (*t1)[MAXN],int h) { for(int i=1;i<=degree;i++) for(int j=1;j<=degree;j++) b[i][j]=0; for(int i=1;i<=degree;i++) b[i][i]=1; int (*t2)[MAXN],(*t3)[MAXN]; t2=b,t3=c; while(h) { if(h&1) {multi(t1,t2,t3); swap(t2,t3); } h>>=1; multi(t1,t1,t3); swap(t1,t3); } if(t2!=a) { memcpy(a,t2,sizeof a); } } int inline getid(char C) { if(C=='A')return 0; else if(C=='T')return 1; else if(C=='C')return 2; else return 3; } void insert(int r,char *s) { int len=strlen(s); for(int i=0;i<len;i++) { int val=getid(s[i]); if(trie[r].nxt[val]==0) trie[r].nxt[val]=++tot; r=trie[r].nxt[val]; } trie[r].flag=1;//1表示结束节点 } void build(int r) { trie[r].fail=r; myq[tail++]=r; int ch; while(head<tail) { r=myq[head++]; for(int i=0;i<MAXC;i++) { ch=trie[r].nxt[i]; if(ch)myq[tail++]=ch; if(r==root) { if(ch) trie[ch].fail=root; else trie[r].nxt[i]=root; } else { if(ch) {trie[ch].fail=trie[trie[r].fail].nxt[i]; trie[ch].flag|=trie[trie[ch].fail].flag; } else trie[r].nxt[i]=trie[trie[r].fail].nxt[i]; } ch=trie[r].nxt[i]; if(trie[ch].flag!=1) a[r][ch]++; } } } int main() { int n,m; scanf("%d%d",&m,&n); for(int i=0;i<m;i++) { scanf("%s",word); insert(root,word); } build(root); degree=tot; power(a,n); int ans=0; for(int i=1;i<=degree;i++) {ans+=a[1][i]; ans%=MOD; } printf("%d ",ans); }