解决的问题:
(C_n^mpmod{p}),(p)不是质数。
对(p)分解质数,假设为(p_1^{c_1}*p_2^{c_2}*...*p_k^{c_k})。
对每个(p_i^{c_i})求出在模意义下的(C_n^m),设为(a_i),我们的问题就变为:
(egin{cases}xequiv a_1pmod{p_1^{c_1}}\ xequiv a_2pmod{p_2^{c_2}}\...\xequiv a_kpmod{p_k^{c_k}}end{cases})
求这个(x)即为答案,因为模数互质,显然可以用CRT求。
于是考虑求(C_n^mpmod{p_i^{c_i}})。
发现不能用逆元求(n!),因为不一定存在逆元,(x)存在逆元的条件为(gcd(x,p)=1)。
现在问题就是求(n!)在模(p_i^{c_i})的意义下的逆元,即求(n!)模(p_i^{c_i})的值,考虑既然模数只有一个质因子,我们将(n!)中的质因子(p_i)提出,剩下的数必定与(p_i^{c_i})互质,exgcd求逆元即可。
以下引用自
以(22!mod 3^2)为例:
按照(3^2)分段:((1*2*3*4*5*6*7*8*9)*(10*11*12*13*14*15*16*17*18)*(19*20*21*22))
将3提出后为((3^6*(1*2*3*4*5*6*7))*(1*2*4*5*7*8)*(10*11*13*14*16*17)*(19*20*22))
观察发现前(lfloorfrac{n}{p_i^{c_i}} floor)模意义下相同,求一组时候快速幂即可。((19*20*22))直接暴力算。((1*2*3*4*5*6*7))递归即可。
提出的(3)因子的数目可以直接算,为(sumlimits_{p^i<=n}lfloorfrac{n}{p^i} floor),证明见lyd蓝书。
于是求完了。
code:
#include<bits/stdc++.h>
using namespace std;
const int maxp=1000010;
typedef long long ll;
ll n,m,mod;
ll fac[maxp];
inline ll power(ll x,ll k,ll mod)
{
ll res=1%mod;
while(k)
{
if(k&1)res=res*x%mod;
x=x*x%mod;k>>=1;
}
return res;
}
ll calc(ll x,ll p,ll mod)
{
if(x<=1)return 1;
ll res=1;
if(x>=mod)res=power(fac[mod-1],x/mod,mod);
if(x%mod)res=res*fac[x%mod]%mod;
return res*calc(x/p,p,mod)%mod;
}
void exgcd(ll a,ll b,ll& x,ll& y)
{
if(!b){x=1,y=0;return;}
exgcd(b,a%b,x,y);
ll z=x;x=y,y=z-(a/b)*y;
}
inline ll inv(ll a,ll mod)
{
if(!a)return 0;
ll x,y;exgcd(a,mod,x,y);
x=(x%mod+mod)%mod;
return x;
}
inline ll C(ll n,ll m,ll p,ll mod)
{
if(m>n)return 0;
fac[0]=1;
for(ll i=1;i<mod;i++)fac[i]=fac[i-1]*((i%p)?i:1)%mod;
ll a=calc(n,p,mod),b=inv(calc(m,p,mod),mod),c=inv(calc(n-m,p,mod),mod);
ll res=a*b%mod*c%mod;
int cnt=0;
for(ll i=p;i<=n;i*=p)cnt+=n/i;
for(ll i=p;i<=m;i*=p)cnt-=m/i;
for(ll i=p;i<=n-m;i*=p)cnt-=(n-m)/i;
//cerr<<cnt<<endl;
return res*power(p,cnt,mod)%mod;
}
int main()
{
//freopen("test.in","r",stdin);
//freopen("test.out","w",stdout);
scanf("%lld%lld%lld",&n,&m,&mod);
ll tmp=mod,ans=0;
for(ll i=2;i*i<=tmp;i++)
{
if(tmp%i)continue;
ll now=1;
while(tmp%i==0)now*=i,tmp/=i;
ans=(ans+C(n,m,i,now)*(mod/now)%mod*inv(mod/now,now)%mod)%mod;
}
if(tmp>1)ans=(ans+C(n,m,tmp,tmp)*(mod/tmp)%mod*inv(mod/tmp,tmp)%mod)%mod;
printf("%lld",(ans%mod+mod)%mod);
return 0;
}