【题意】给定n个原串和m个禁忌串,要求用原串集合能拼出的不含禁忌串且长度为L的串的数量。(60%)n,m<=50,L<=100。(40%)原串长度为1或2,L<=10^18。
【算法】AC自动机+DP+矩阵快速幂
【题解】其实题意的数据范围不太清晰,反正开200个点就足够了。
因为要匹配禁忌串,所以对禁忌串集合建立AC自动机,标记禁忌串结尾节点,以及下传到所有能fail到的点(这些点访问到都相当于匹配了禁忌串)。
令f[i][j]表示匹配到节点i,长度为j的串的数量,先预处理a[i][j]表示节点 i 匹配第 j 个原串到达的节点编号,那么就有:
f [ a[i][j] ] [ L+size[j] ] += f [ i ] [ L ]
以上就是60%数据的做法,对于40%的数据使用矩阵快速幂。
假设原串长度均为1,那么DP的转移如下:
$$f[i][L]=sum_{j}f[j][L-1] , j ightarrow i$$
这很容易用一个长度为第一维大小(AC自动机节点数)的矩阵维护转移,第L个列向量就是f[i][L]。
如果原串长度有2,那么再记录L-1即可。
#include<cstdio> #include<cstring> #include<queue> #include<algorithm> #define ll long long using namespace std; const int maxn=5010,MOD=1e9+7; int n,m,a[maxn][110],ch[maxn][27],val[maxn],size[maxn],sz=0,fail[maxn]; ll L; char s[110][maxn],S[maxn]; queue<int>Q; void insert(char *s){ int n=strlen(s),u=0; for(int i=0;i<n;i++){ int c=s[i]-'a'; if(!ch[u][c])ch[u][c]=++sz; u=ch[u][c]; } val[u]++; } void AC_build(){ for(int c=0;c<26;c++)if(ch[0][c])Q.push(ch[0][c]); while(!Q.empty()){ int u=Q.front();Q.pop(); for(int c=0;c<26;c++)if(ch[u][c]){ fail[ch[u][c]]=ch[fail[u]][c]; Q.push(ch[u][c]); val[ch[u][c]]|=val[fail[ch[u][c]]];// } else ch[u][c]=ch[fail[u]][c]; } } int M(int x){return x>=MOD?x-MOD:x;} namespace Task1{ int f[maxn][110]; void solve(){ f[0][0]=1; for(int l=0;l<L;l++){// for(int i=0;i<=sz;i++)if(f[i][l]){ for(int j=1;j<=n;j++)if(~a[i][j]&&l+size[j]<=L){ f[a[i][j]][l+size[j]]=M(f[a[i][j]][l+size[j]]+f[i][l]); } } } int ans=0; for(int i=0;i<=sz;i++)if(f[i][L]&&!val[i])ans=M(ans+f[i][L]); printf("%d",ans); } } namespace Task2{ const int maxn=110; int N,A[maxn*2][maxn*2],ANS[maxn*2][maxn*2],c[maxn*2][maxn*2]; void mul(int a[maxn*2][maxn*2],int b[maxn*2][maxn*2]){ for(int i=0;i<=N;i++){ for(int j=0;j<=N;j++){ c[i][j]=0; for(int k=0;k<=N;k++)c[i][j]=M(c[i][j]+1ll*a[i][k]*b[k][j]%MOD); } } for(int i=0;i<=N;i++)for(int j=0;j<=N;j++)b[i][j]=c[i][j]; } void solve(){ N=sz*2+1; for(int i=0;i<=sz;i++){ for(int j=1;j<=n;j++)if(~a[i][j]){ if(size[j]==1)A[a[i][j]*2][i*2]++; else A[a[i][j]*2][i*2+1]++; } A[i*2+1][i*2]=1; } ANS[0][0]=1; while(L){ if(L&1)mul(A,ANS); mul(A,A); L>>=1; } int ans=0; for(int i=0;i<=sz;i++)if(!val[i])ans=M(ans+ANS[i*2][0]); printf("%d",ans); } } int main(){ scanf("%d%d%lld",&n,&m,&L); for(int i=1;i<=n;i++)scanf("%s",s[i]); for(int i=1;i<=m;i++){ scanf("%s",S); insert(S); } AC_build(); memset(a,-1,sizeof(a)); for(int k=1;k<=n;k++){ size[k]=strlen(s[k]); for(int i=0;i<=sz;i++){ int u=i; for(int j=0;j<size[k];j++)if(!val[u])u=ch[u][s[k][j]-'a'];else break; if(!val[u])a[i][k]=u; } } if(L<=100)Task1::solve();else Task2::solve(); return 0; }