Description
Alice想要得到一个长度为n的序列,序列中的数都是不超过m的正整数,而且这n个数的和是p的倍数。Alice还希望,这n个数中,至少有一个数是质数。Alice想知道,有多少个序列满足她的要求。
Input
一行三个数,n,m,p。
1<=n<=10^9,1<=m<=2×10^7,1<=p<=100
Output
一行一个数,满足Alice的要求的序列数量,答案对20170408取模。
Sample Input
3 5 3
Sample Output
33
分析:
至少有一个是质数=所有-没有质数
记f[i][j]为只考虑i个数,前i个数的和在模p意义下为j的方案数
f[i+1][k]+=f[i][j]*num ((j+x)%p=k,符合这个条件的数的个数是num)
考虑矩阵加速
观察矩阵的特点:
% | f[i][0] | f[i][1] | f[i][2] | f[i][3] |
---|---|---|---|---|
f[i-1][0] | a | b | c | d |
f[i-1][1] | d | a | b | c |
f[i-1][2] | c | d | a | b |
f[i-1][3] | b | c | d | a |
p=4 | ||||
a:%p=0的数的个数 | ||||
b:%p=1的数的个数 | ||||
c:%p=2的数的个数 | ||||
d:%p=3的数的个数 |
这种矩阵称作循环矩阵,
循环矩阵的乘积还是循环矩阵,所以做矩阵乘法时候只需算第一行,
然后按循环矩阵性质填出其他行即可
看了一下网上的程序
f和矩阵的初始化竟然O(mm)即可,震惊(ΩДΩ)
不太理解为什么矩阵的初始化要写成:
m.m[0][(-i%p+p)%p]++;
tip
注意矩阵的下标是0~p-1
1不是素数
最后的答案:
ans=(ans+f[i]*an.m[0][i]%mod)%mod;
开ll
最后的答案:(f1-f2+mod)%mod //+mod
这里写代码片
#include<cstdio>
#include<cstring>
#include<iostream>
#define ll long long
using namespace std;
const ll mod=20170408;
int n,mm,p;
int tot=0,f[101];
int sshu[20000010];
bool no[20000010];
struct node{
ll m[101][101];
node operator *(const node &a) const
{
node ans;
for (int j=0;j<p;j++) //只计算第一行
{
ans.m[0][j]=0;
for (int k=0;k<p;k++)
ans.m[0][j]=(ans.m[0][j]+m[0][k]*a.m[k][j]%mod)%mod;
}
for (int i=1;i<p;i++)
for (int j=0;j<p;j++)
{
int t=j-1;
if (t==-1) t=p-1;
ans.m[i][j]=ans.m[i-1][t];
}
return ans;
}
void clear()
{
memset(m,0,sizeof(m));
}
node KSM(ll pp)
{
pp--;
node tt=(* this);
node a=(* this);
while (pp)
{
if (pp&1)
tt=tt*a;
a=a*a;
pp>>=1;
}
return tt;
}
};
node m;
void cl() //求素数
{
memset(no,0,sizeof(no));
no[1]=1; ///
for (int i=2;i<=mm;i++)
{
if (!no[i])
sshu[++tot]=i;
for (int j=1;i*sshu[j]<=mm&&j<=tot;j++)
{
no[i*sshu[j]]=1;
if (i%sshu[j]==0) break;
}
}
}
ll solve1()
{
for (int i=1;i<=mm;i++) f[i%p]++; //
for (int i=1;i<=mm;i++) m.m[0][(-i%p+p)%p]++; //
for (int i=1;i<p;i++)
for (int j=0;j<p;j++)
{
int t=j-1;
if (t==-1) t=p-1;
m.m[i][j]=m.m[i-1][t];
}
node an=m.KSM(n-1);
ll ans=0;
for (int i=0;i<p;i++) ans=(ans+(ll)f[i]*an.m[0][i]%mod)%mod; //
return ans;
}
ll solve2()
{
memset(f,0,sizeof(f));
for (int i=1;i<=mm;i++) if (no[i]) f[i%p]++;
m.clear();
for (int i=1;i<=mm;i++) if (no[i]) m.m[0][(-i%p+p)%p]++;
for (int i=1;i<p;i++)
for (int j=0;j<p;j++)
{
int t=j-1;
if (t==-1) t=p-1;
m.m[i][j]=m.m[i-1][t];
}
node an=m.KSM(n-1);
ll ans=0;
for (int i=0;i<p;i++) ans=(ans+(ll)f[i]*an.m[0][i]%mod)%mod; //
return ans;
}
int main()
{
scanf("%d%d%d",&n,&mm,&p);
cl();
ll f1=solve1();
ll f2=solve2();
printf("%lld",(ll)(f1-f2+mod)%mod); //+mod
return 0;
}