用后缀树统计出出现了x次的本质不同的子串的个数,最后再乘以x,得到一个多项式。
这个多项式常数项为0,但是一次项不为0。
于是把整个多项式除以一次项,通过多项式求ln和多项式求exp求出它的幂。
最后再把除掉的项乘回来即可,时间复杂度$O(nlog n)$。
#include<cstdio> #include<cstring> typedef long long ll; const int N=262144,K=17,inf=~0U>>2,S=27,M=200010,P=1005060097,G=5; char s[M]; int n,m,x,i,j,k,C; int a[N+10],b[N+10],tmp[N],tmp2[N],g[K+1],ng[K+1],inv[N+10],inv2; int text[M],root,last,pos,need,remain,acnode,ace,aclen,size[M]; inline int min(int a,int b){return a<b?a:b;} struct node{int st,en,lk,son[S];inline int len(){return min(en,pos+1)-st;}}tree[M]; inline int new_node(int st,int en=inf){return tree[++last].st=st,tree[last].en=en,last;} inline int acedge(){return text[ace];} inline void addedge(int node){ if(need)tree[need].lk=node; need=node; } inline bool down(int node){ if(aclen>=tree[node].len())return ace+=tree[node].len(),aclen-=tree[node].len(),acnode=node,1; return 0; } inline void init(){ need=last=remain=ace=aclen=0; root=acnode=new_node(pos=-1,-1); } inline void extend(int c){ text[++pos]=c;need=0;remain++; while(remain){ if(!aclen)ace=pos; if(!tree[acnode].son[acedge()])tree[acnode].son[acedge()]=new_node(pos),addedge(acnode); else{ int nxt=tree[acnode].son[acedge()]; if(down(nxt))continue; if(text[tree[nxt].st+aclen]==c){aclen++;addedge(acnode);break;} int split=new_node(tree[nxt].st,tree[nxt].st+aclen); tree[acnode].son[acedge()]=split; tree[split].son[c]=new_node(pos); tree[nxt].st+=aclen; tree[split].son[text[tree[nxt].st]]=nxt; addedge(split); } remain--; if(acnode==root&&aclen)aclen--,ace=pos-remain+1; else acnode=tree[acnode].lk?tree[acnode].lk:root; } } void dfs(int x,int sum){ sum+=tree[x].len(); if(tree[x].en==inf&&pos-sum+1<=n)size[x]=1; for(int i=0;i<S;i++)if(tree[x].son[i]){ int j=tree[x].son[i]; dfs(j,sum),size[x]+=size[j]; } if(size[x])a[size[x]]=(a[size[x]]+tree[x].len())%P; } inline int pow(int a,int b){int t=1;for(;b;b>>=1,a=1LL*a*a%P)if(b&1)t=1LL*t*a%P;return t;} inline void NTT(int*a,int n,int t){ for(int i=1,j=0;i<n-1;i++){ for(int s=n;j^=s>>=1,~j&s;); if(i<j){int k=a[i];a[i]=a[j];a[j]=k;} } for(int d=0;(1<<d)<n;d++){ int m=1<<d,m2=m<<1,_w=t==1?g[d]:ng[d]; for(int i=0;i<n;i+=m2)for(int w=1,j=0;j<m;j++){ int&A=a[i+j+m],&B=a[i+j],t=1LL*w*A%P; A=B-t;if(A<0)A+=P; B=B+t;if(B>=P)B-=P; w=1LL*w*_w%P; } } if(t==-1)for(int i=0,j=inv[n];i<n;i++)a[i]=1LL*a[i]*j%P; } void getinv(int*a,int*b,int n){ if(n==1){b[0]=pow(a[0],P-2);return;} getinv(a,b,n>>1); int k=n<<1,i; for(i=0;i<n;i++)tmp[i]=a[i]; for(i=n;i<k;i++)tmp[i]=b[i]=0; NTT(tmp,k,1),NTT(b,k,1); for(i=0;i<k;i++){ b[i]=(ll)b[i]*(2-(ll)tmp[i]*b[i]%P)%P; if(b[i]<0)b[i]+=P; } NTT(b,k,-1); for(i=n;i<k;i++)b[i]=0; } inline void getln(int*a,int*b,int n){ getinv(a,tmp2,n); int k=n<<1,i; for(i=0;i<n-1;i++)b[i]=(ll)a[i+1]*(i+1)%P; for(i=n-1;i<k;i++)b[i]=0; NTT(b,k,1),NTT(tmp2,k,1); for(i=0;i<k;i++)b[i]=(ll)b[i]*tmp2[i]%P; NTT(b,k,-1); for(i=n-1;i;i--)b[i]=(ll)b[i-1]*inv[i]%P;b[0]=0; } void getexp(int*a,int*b,int n){ if(n==1){b[0]=1;return;} getexp(a,b,n>>1); getln(b,tmp,n); int k=n<<1,i; for(i=0;i<n;i++){tmp[i]=a[i]-tmp[i];if(tmp[i]<0)tmp[i]+=P;} if((++tmp[0])==P)tmp[0]=0; for(i=n;i<k;i++)tmp[i]=b[i]=0; NTT(tmp,k,1),NTT(b,k,1); for(i=0;i<k;i++)b[i]=(ll)b[i]*tmp[i]%P; NTT(b,k,-1); for(i=n;i<k;i++)b[i]=0; } int main(){ scanf("%d%d%s",&m,&x,s+1); if(m>x)return puts("0"),0; n=std::strlen(s+1); for(i=1;i<=n;extend(s[i++]-'a'));extend(26); pos--,dfs(root,0); for(i=0;i<=x;i++)a[i]=1LL*a[i]*i%P; C=a[1],j=pow(C,P-2); for(i=0;i<x;i++)a[i]=1LL*a[i+1]*j%P; for(i=x;i<k;i++)a[i]=0; for(g[K]=pow(G,(P-1)/N),ng[K]=pow(g[K],P-2),i=K-1;~i;i--)g[i]=(ll)g[i+1]*g[i+1]%P,ng[i]=(ll)ng[i+1]*ng[i+1]%P; for(inv[1]=1,i=2;i<=N;i++)inv[i]=(ll)(P-inv[P%i])*(P/i)%P;inv2=inv[2]; for(k=1;k<=x;k<<=1); getln(a,b,k); for(i=0;i<k;i++)b[i]=1LL*b[i]*m%P; getexp(b,a,k); return printf("%d",1LL*a[x-m]*pow(C,m)%P),0; }