就是POJ2778的加强版,思路是一样的。
出现过给定的单词的单词数=总单词数-未出现过给定单词的单词数,前者等于26^1+26^2+26^3....26^l,后者等于Mat^1+Mat^2+..Mat^l,这里的Mat是根据能走的字符之间的路径数建立出来的邻接矩阵。两者求出来一减就可以了,对2^64取模可以忽略,直接用unsigned long long做,忽视溢出就等于取模了,注意减法最后结果可能小于0,要加2^64变成正数。
求a^1+..a^n用一次二分就可以,建议这种二分都写成非递归的,效率比较高而且不会有爆栈的隐患。
#include <stdio.h> #include <string.h> #define MAXN 31 typedef unsigned long long LL; LL dmat[MAXN][MAXN]; struct matrix{ LL mz[MAXN][MAXN]; int n; #define FOR(i) for(int i=0;i<n;i++) matrix(int nn,int type):n(nn){ if(type==0)FOR(i)FOR(j)mz[i][j]=0; else if(type==1)FOR(i)FOR(j)mz[i][j]=(i==j?1:0); else FOR(i)FOR(j)mz[i][j]=dmat[i][j]; } matrix operator *(const matrix& b)const{ matrix ans(n,0); FOR(i)FOR(j)if(mz[i][j]) FOR(k)ans.mz[i][k]+=mz[i][j]*b.mz[j][k]; return ans; } matrix operator +(const matrix& b)const{ matrix ans(n,0); FOR(i)FOR(j)ans.mz[i][j]=mz[i][j]+b.mz[i][j]; return ans; } //M^1+M^2+M^3...M^N matrix binPlusMul(int x){ matrix ans(n,0),tmp(n,2),mat(n,2); int bit=1; for(LL i=(x>>1);i;i>>=1,bit<<=1);bit>>=1; for(ans=mat;bit;bit>>=1){ ans=ans+tmp*ans; tmp=tmp*tmp; if(bit&x){ tmp=tmp*mat; ans=ans+tmp; } } return ans; } }; int next[MAXN][26],fail[MAXN],flag[MAXN],pos; int newnode(){ for(int i=0;i<26;i++)next[pos][i]=0; fail[pos]=flag[pos]=0; return pos++; } void insert(char *s){ int p=0; for(int i=0;s[i];i++){ int k=s[i]-'a',&x=next[p][k]; p=x?x:x=newnode(); } flag[p]=1; } int q[MAXN],front,rear; void makenext(){ q[front=rear=0]=0,rear++; while(front<rear){ int u=q[front++]; for(int i=0;i<26;i++){ int v=next[u][i]; if(flag[v])continue; if(v==0)next[u][i]=next[fail[u]][i]; else q[rear++]=v; if(u&&v){ fail[v]=next[fail[u]][i]; if(flag[fail[v]])flag[v]=1; } } } } LL cal(int x,int n){ int bit=1; LL ans=0,tmp=x; for(int i=(n>>1);i;i>>=1,bit<<=1);bit>>=1; for(ans=x;bit;bit>>=1){ ans=ans+tmp*ans; tmp=tmp*tmp; if(bit&n){ tmp=tmp*x; ans=ans+tmp; } } return ans; } int n; LL l; char s[10]; int main(){ //freopen("test.in","r",stdin); while(scanf("%d%d",&n,&l)!=EOF){ pos=0;newnode(); memset(dmat,0,sizeof dmat); for(int i=0;i<n;i++){ scanf("%s",s); insert(s); } makenext(); for(int u=0;u<pos;u++){ if(flag[u])continue; for(int i=0;i<26;i++){ int v=next[u][i]; if(flag[v])continue; dmat[u][v]++; } } matrix mt(pos,2); matrix mat=mt.binPlusMul(l); LL ans=cal(26,l); for(int i=0;i<pos;i++)ans-=mat.mz[0][i]; if(ans<0)ans+=((LL)1<<63)+((LL)1<<63); printf("%I64u\n",ans); } return 0; }