Solution [HNOI2008]GT考试
题目大意:给定一段长为(m)的数(S),求有多少个长为(n)的数不包含子串(S)
( ext{KMP})、计数、矩阵乘法
分析:
首先由于允许前导(0),一共有(10^n)个串。反着来,我们考虑有多少个串包含子串(S)
我们记(f(n,s))表示长为(n),后缀最长能匹配(S)长为(s)的前缀的串个数
考虑(f(n,s))会对哪些位置产生贡献,我们枚举第(n+1)个位置为(c)
如果(S[s+1]=c),那么(f(n,s))的值应当被累加到(f(n+1,s+1))上
如果(S[s+1] eq c),那么我们应当用( ext{KMP})算法不断跳( ext{fail}),找到转移位置。为了便于转移,以及优化运行时间,用类似( ext{AC})自动机补全( ext{Trie})树的方法建出转移图
设补全后的转移数组为(ch),两者可以统一
(f(n,s))会对(f(n+1,ch[s][c]))产生贡献,其中(cin[0,9])
先考虑计数,一个比较(naive)的想法是求(sum_n f(n,m)),这样会有重复计数
也就是说有可能同一个串包含子串(S)多次
不妨规定第一次包含子串(S)时计数,那么已经包含子串(S)之后,后面的所有位置都可以任取了。对于任意(s=m)的(f(n,s)),没必要将它的贡献累计到后面。
暴力算法:
求出(f(n,s)quad sin[0,m]),令(ans=ans*10+f(n,m)),对于所有(f(n,s) quad sin[0,m))进行转移,计算它对于位置(n+1)的贡献
这样是(O(n))的
可以用矩阵乘法优化
假设我们有长为(m+1)的数组(f),表示(f(n,s)quad sin[0,m]),由上分析,我们可以用(f[m])表示答案(从(0)开始),枚举(sin[0,m),cin[0,9]),把转移矩阵第(s)行第(ch[s][c])列(+1)
最后把转移矩阵第(m)行第(m)列置为(10)(第一次包含子串(S),后面有(k)位任取,答案要乘(10^k)),快速幂转移即可
#include <cstdio>
#include <cstring>
using namespace std;
const int maxm = 32;
int n,m,mod,ans,ch[maxm][10],fail[maxm];
inline int mul(int a,int b){return (1ll * a * b) % mod;}
inline int add(int a,int b){return (a + b) % mod;}
inline int sub(int a,int b){return (((a - b) % mod) + mod) % mod;}
inline int qpow(int a,int b){
int res = 1,base = a;
while(b){
if(b & 1)res = mul(res,base);
base = mul(base,base);
b >>= 1;
}
return res;
}
struct matrix{
int f[maxm][maxm];
int x,y;
void clear(){
memset(f,0,sizeof(f));
x = y = 0;
}
matrix operator * (const matrix &rhs)const{
matrix res;res.clear();
res.x = x,res.y = rhs.y;
for(int i = 0;i < x;i++)
for(int k = 0;k < y;k++)
for(int j = 0;j < rhs.y;j++)
res.f[i][j] = add(res.f[i][j],mul(f[i][k],rhs.f[k][j]));
return res;
}
}w,org;
inline matrix qpow(matrix base,int b){
matrix res;res.clear();
res.x = res.y = base.x;
for(int i = 0;i < res.x;i++)res.f[i][i] = 1;
while(b){
if(b & 1)res = res * base;
base = base * base;
b >>= 1;
}
return res;
}
inline int idx(char c){return c - '0';}
char str[maxm];
int main(){
scanf("%d %d %d",&n,&m,&mod);
scanf("%s",str + 1);
for(int u = 0;u < m;u++)
ch[u][idx(str[u + 1])] = u + 1;
for(int u = 1;u <= m;u++)
for(int c = 0;c < 10;c++)
if(ch[u][c])fail[ch[u][c]] = ch[fail[u]][c];
else ch[u][c] = ch[fail[u]][c];
w.x = w.y = m + 1;
for(int s = 0;s < m;s++)
for(int c = 0;c < 10;c++)
w.f[s][ch[s][c]]++;
w.f[m][m] = 10;
org.x = 1,org.y = m + 1;
org.f[0][0] = 1;
org = org * qpow(w,n);
ans = qpow(10,n);
ans = sub(ans,org.f[0][m]);
printf("%d
",ans);
return 0;
}