最近做到好几道关于AC自动机与状态压缩dp的结合的题,这里总结一下。
题目一般会给出m个字符串,m不超过10,然后求长度为len并且包含特定给出的字符串集合的字符串个数。
以HDU 4758为例:
把题意抽象为:给出两个字符串,且只包含两种字符 'R'、'D',现在求满足下列条件的字符串个数:
1、字符串必须包含上述两个字符串。
2、字符串长度为(m+n),其中包含n个'D',m个'R'。
如果不用AC自动机来做,这道题还真没法做了,因为不管怎样都找不到正确的dp状态转移方程。
而如果引入AC自动机,把在AC自动机上的结点当做dp的一个维度的状态,那么问题就可解了。
dp[c][zt][i][j]:c表示当前状态的字符串对应于AC自动机上的结点,zt表示给定字符串取舍情况的压缩状态,i表示'D'的个数,j表示'R'的个数。
那么dp[c][zt][i][j]表示当前状态字符串的个数。
循环到dp[c][zt][i][j]时,其实dp[c][zt][i][j]已经被计算出来了,然后遍历trie树中c的所有子节点,计算它们的dp值。
最外层循环应该是字符串长度的循环,循环次数是题目要求的字符串长度,第二层循环是trie树中的所有节点,第三层是字符串取舍状态,最后是遍历c节点的所有子节点(说是子节点,其实是对c节点的下一个字符进行遍历,需要使用fail指针)。
c节点并不代表某个具体的字符串,它是指所有能到达c节点的字符串,dp的值就是保存这些字符串中满足条件的字符串个数。
AC自动机的作用就是增加一个状态维度,使dp过程有足够的信息来转移状态。
#include<cstdio> #include<cstring> #include<queue> using namespace std; const int mod = 1000000007; int ch[202][2],End[202],cur,fail[202],last[202]; void get_fail() { int now,tmpFail,Next; queue<int> q; for(int j=0;j<2;j++) { if(ch[0][j]) { q.push(ch[0][j]); fail[ch[0][j]] = 0; last[ch[0][j]] = 0; } } while(!q.empty()) { now = q.front();q.pop(); for(int j=0;j<2;j++) { if(!ch[now][j]) continue; Next = ch[now][j]; q.push(Next); tmpFail = fail[now]; while(tmpFail&&!ch[tmpFail][j]) tmpFail = fail[tmpFail]; fail[Next] = ch[tmpFail][j]; last[Next] = End[fail[Next]] ? fail[Next]:last[fail[Next]]; } } } int dp[202][4][102][102];//dp[c][zt][i][j] int main() { int T,m,n; char str0[3][104]; scanf("%d",&T); while(T--) { cur=1; scanf("%d%d",&m,&n); n++;m++; memset(End,0,sizeof(End)); memset(ch,0,sizeof(ch)); memset(last,0,sizeof(last)); for(int i=1;i<=2;i++) { scanf("%s",str0[i]); int len = strlen(str0[i]); int now = 0; for(int j=0;j<len;j++) { if(str0[i][j]=='R') str0[i][j]=1; else str0[i][j]=0; if(ch[now][str0[i][j]]==0) ch[now][str0[i][j]] = cur++; now = ch[now][str0[i][j]]; } End[now] = i; } get_fail(); memset(dp,0,sizeof(dp)); dp[0][0][0][0]=1; for(int i=0;i<n;i++) //要特别注意这里内外循环顺序,必须把i、j循环放在外面 for(int j=0;j<m;j++) { for(int c=0;c<cur;c++) { for(int zt=0;zt<=3;zt++){ if(dp[c][zt][i][j]) for(int k=0;k<2;k++) { if(k==0&&i==n-1) continue; else if(k==1&&j==m-1) continue; int now=c; while(now&&!ch[now][k]) now = fail[now]; now = ch[now][k]; int t=0; if(End[now]) t = t|(1<<(End[now]-1)); int tmp = now; while(last[tmp]) { t = t|(1<<(End[last[tmp]]-1)); tmp = last[tmp]; } if(k==0) { dp[now][zt|t][i+1][j] += dp[c][zt][i][j]; if(dp[now][zt|t][i+1][j]>=mod) dp[now][zt|t][i+1][j]-=mod; } else if(k==1) { dp[now][zt|t][i][j+1] += dp[c][zt][i][j]; if(dp[now][zt|t][i][j+1]>=mod) dp[now][zt|t][i][j+1]-=mod; } } } } } long long ans=0; for(int i=0;i<cur;i++) { ans+=dp[i][3][n-1][m-1]; if(ans>=mod) ans-=mod; } printf("%I64d ",ans); } }
注意循环的内外顺序,一般情况下,字符串长度的循环都是放在外层,也就是说,一定要先计算出长度为i的所有字符串状态,才能计算长度为i+1的所有字符串状态。
类似的 HDU 2825 :给 m 个单词构成的集合,求至少包含 k 个单词且长度为n的字符串个数。
#include<iostream> #include<algorithm> #include<cstring> #include<cstdio> #include<queue> using namespace std; const int mod=20090717; int ch[11*11][26],End[11*11],cur,fail[11*11],last[11*11]; char str0[12][12]; void get_fail() { int now,tmpFail,Next; queue<int> q; for(int j=0;j<26;j++) { if(ch[0][j]) { q.push(ch[0][j]); fail[ch[0][j]] = 0; last[ch[0][j]] = 0; } } while(!q.empty()) { now = q.front();q.pop(); for(int j=0;j<26;j++) { if(!ch[now][j]) continue; Next = ch[now][j]; q.push(Next); tmpFail = fail[now]; while(tmpFail&&!ch[tmpFail][j]) tmpFail = fail[tmpFail]; fail[Next] = ch[tmpFail][j]; last[Next] = End[fail[Next]] ? fail[Next]:last[fail[Next]]; } } } int dp[27][11*11][1055]; int main() { int sum[1055]; for(int I=0;I<(1<<10);I++) { sum[I]=0; int tmp=I; while(tmp) { if(tmp&1) sum[I]++; tmp>>=1; } } int n,m,k; while(scanf("%d%d%d",&n,&m,&k)!=EOF&&(n||m||k)) { cur=1; int len[13]; memset(End,0,sizeof(End)); memset(ch,0,sizeof(ch)); memset(last,0,sizeof(last)); for(int i=1;i<=m;i++) { scanf("%s",str0[i]); len[i] = strlen(str0[i]); int now = 0; for(int j=0;j<len[i];j++) { str0[i][j]-='a'; if(ch[now][str0[i][j]]==0) ch[now][str0[i][j]] = cur++; now = ch[now][str0[i][j]]; str0[i][j]+='a'; } End[now] = i; } get_fail(); memset(dp,0,sizeof(dp)); dp[0][0][0]=1; int pre=0,zt=0; int ans=0; for(int i=0;i<n;i++) { for(int j=0;j<cur;j++) { for(int zt=0;zt<(1<<m);zt++) { if(dp[i][j][zt]) { for(int c=0;c<26;c++) { int now = j; while(now&&!ch[now][c]) now = fail[now]; now = ch[now][c]; int t=0; if(End[now]) t = t|(1<<(End[now]-1)); int tmp = now; while(last[tmp]) { t = t|(1<<(End[last[tmp]]-1)); tmp = last[tmp]; } dp[i+1][now][zt|t] += dp[i][j][zt]; if(dp[i+1][now][zt|t]>=mod) dp[i+1][now][zt|t]-=mod; } } } } } for(int I=0;I<(1<<m);I++) { if(sum[I]>=k) { for(int j=0;j<cur;j++){ ans+=dp[n][j][I]; if(ans>=mod) ans-=mod; } } } printf("%d ",ans); } }
HDU 4057:给出一些模式串,每个串有一定的价值,现在构造一个长度为M的串,问最大的价值为多少,每个模式串最多统计一次。
#include<cstdio> #include<cstring> #include<queue> using namespace std; int ch[11*102][4],End[11*102],cur,fail[11*102],last[11*102]; int w[11]; char str[102],str0[11][102]; void get_fail() { int now,tmpFail,Next; queue<int> q; //用bfs生成fail //初始化队列 for(int j=0; j<4; j++) { if(ch[0][j]) { q.push(ch[0][j]); fail[ch[0][j]] = 0; last[ch[0][j]] = 0; } } while(!q.empty()) { //从队列中拿出now //此时now中的fail、last已经算好了 //下面计算的是ch[now][j]中的fail、last。 now = q.front(); q.pop(); for(int j=0; j<4; j++) { if(!ch[now][j]) continue; Next = ch[now][j]; q.push(Next); tmpFail = fail[now]; while(tmpFail&&!ch[tmpFail][j]) tmpFail = fail[tmpFail]; fail[Next] = ch[tmpFail][j]; last[Next] = End[fail[Next]] ? fail[Next]:last[fail[Next]]; } } } int dp[1029][11*102][2]; bool vis[1029][11*102][2]; int n,l,now,ans; queue<int> quezt; queue<int> quenow; queue<int> quelen; void bfs (int zt,int now0,int len) { //printf("%d %d %d %d ",zt,now0,len,dp[zt][now0][len%2]); //printf("%d ",quezt.size()); if(len==l) ans=max(ans,dp[zt][now0][l%2]); if(len==l+1) return; for(int i=0; i<4; i++) { int now=now0,temp=0; while(now&&!ch[now][i]) now = fail[now]; now = ch[now][i]; int newzt = zt; if(End[now]) { if(((1<<(End[now]-1))|newzt)!=newzt) temp+=w[End[now]]; newzt = (1<<(End[now]-1))|newzt; } int tmp = now; while(last[tmp]) { if(End[last[tmp]]) { if(((1<<(End[last[tmp]]-1))|newzt)!=newzt) temp+=w[End[last[tmp]]]; newzt = (1<<(End[last[tmp]]-1))|newzt; } tmp = last[tmp]; } if(newzt!=zt) { //printf("%d ",temp); if(!vis[newzt][now][(len+1)%2]) dp[newzt][now][(len+1)%2]=dp[zt][now0][len%2]+temp; else dp[newzt][now][(len+1)%2]=max(dp[zt][now0][len%2]+temp,dp[newzt][now][(len+1)%2]); } else{ if(!vis[zt][now][(len+1)%2]) dp[zt][now][(len+1)%2]=dp[zt][now0][len%2]; else dp[zt][now][(len+1)%2]=max(dp[zt][now0][len%2],dp[zt][now][(len+1)%2]); } //dfs(newzt,now,len+1); if(!vis[newzt][now][(len+1)%2]) { quezt.push(newzt); quenow.push(now); quelen.push(len+1); vis[newzt][now][(len+1)%2]=true; } } //if(len==l) ans=max(ans,dp[zt][now0][l%2]); } int main() { while(scanf("%d%d",&n,&l)!=EOF) { memset(dp,-1,sizeof(dp)); memset(ch,0,sizeof(ch)); memset(End,0,sizeof(End)); memset(last,0,sizeof(last)); cur = 1; int len; for(int i=1; i<=n; i++) { scanf("%s%d",str0[i],&w[i]); //puts(str0[i]); len = strlen(str0[i]); now = 0; for(int j=0; j<len; j++) { if(str0[i][j]=='A') str0[i][j]=0; if(str0[i][j]=='T') str0[i][j]=1; if(str0[i][j]=='G') str0[i][j]=2; if(str0[i][j]=='C') str0[i][j]=3; if(ch[now][str0[i][j]]==0) ch[now][str0[i][j]] = cur++; now = ch[now][str0[i][j]]; if(str0[i][j]==0) str0[i][j]='A'; if(str0[i][j]==1) str0[i][j]='T'; if(str0[i][j]==2) str0[i][j]='G'; if(str0[i][j]==3) str0[i][j]='C'; } End[now] = i; } //printf("%d ",cur); get_fail(); //printf("%d ",cur); dp[0][0][0]=0; quezt.push(0); quenow.push(0); quelen.push(0); memset(vis,false,sizeof(vis)); vis[0][0][0]=true; ans=-1; int pre=0; while(!quezt.empty()) { //if(quelen.front()!=pre) { // for(int i=0;i<1029;i++) // for(int j=0;j<11*102;j++) dp[i][j][pre%2]=0; // pre=quelen.front(); //} bfs(quezt.front(),quenow.front(),quelen.front()); vis[quezt.front()][quenow.front()][quelen.front()%2]=false; quezt.pop();quenow.pop();quelen.pop(); } if(ans==-1) puts("No Rabbit after 2012!"); else printf("%d ",ans); } }