中国剩余定理(CRT)可以说是必学的一个东西啦,主要是用来求线性同余方程组的算法。
x≡a1(mod m1) x≡a2(mod m2) ... x≡an(mod mn) 在这n个同余方程下解出x
普通的CRT只能求mi两两互质的情况,如若mi不两两互质就得用到扩展CRT。(其实感觉CRT跟扩展CRT原理差别挺大的)
CRT:https://www.cnblogs.com/zwfymqz/p/8425019.html
扩展CRT:https://www.cnblogs.com/zwfymqz/p/8425019.html
然后直接上模板:
中国剩余定理
#include<iostream> #include<cstdio> using namespace std; long long n,m[100],a[100],d; long long exgcd(long long a,long long b,long long &x,long long &y) { if (b==0) { x=1; y=0; return a; } else { long long tmp=exgcd(b,a%b,y,x); y-=x*(a/b); return tmp; } } long long work() { long long res=0,x,y; long long lcm=1; for (int i=1;i<=n;i++) lcm=lcm*m[i]; for (int i=1;i<=n;i++) { long long M=lcm/m[i]; exgcd(M,m[i],x,y); x=(x%m[i]+m[i])%m[i]; res=(res+(long long)(a[i]*x*M))%lcm; } return res; } int main() { scanf("%d",&n); for (int i=1;i<=n;i++) { scanf("%lld%lld",&m[i],&a[i]); } cout<<work(); return 0; }
扩展中国剩余定理
#include<iostream> #include<cstdio> using namespace std; const int N=1000000+10; typedef long long LL; int n; LL m[N],a[N]; LL exgcd(LL a,LL b,LL &x,LL &y) { if (b==0) { x=1; y=0; return a; } else { LL tmp=exgcd(b,a%b,y,x); y-=x*(a/b); return tmp; } } long long work() { LL lcm=m[1],X=a[1],t,x,y; for (int i=2;i<=n;i++) { LL b=(a[i]-X%m[i]+m[i])%m[i]; LL d=exgcd(lcm,m[i],x,y); //解这个方程出来t的特解x t=(b/d)*x if (b%d) return -1; t=(b/d)*x%m[i]; X=(X+t*lcm); //那么X(k)=X(k-1)+tm lcm=lcm*m[i]/d; X=(X%lcm+lcm)%lcm; } return X; } int main() { scanf("%d",&n); for (int i=1;i<=n;i++) scanf("%lld%lld",&m[i],&a[i]); cout<<work()<<endl; return 0; }
中国剩余定理的题目当然就是解同余方程组啦,这就要求你能看出来是同余方程组的模型。当然CRT可能不会单独考会结合其他知识点一起考,毕竟只有CRT也是很干瘪的啦。
题目练习:
POJ-1066
CRT裸题,容易看出是n个同余方程。
#include<iostream> #include<cstdio> using namespace std; const long long m[4]={23,28,33}; int n,a[4],d; int exgcd(int a,int b,int &x,int &y) { if (b==0) { x=1; y=0; return a; } else { int tmp=exgcd(b,a%b,y,x); y-=x*(a/b); return tmp; } } long long work() { long long res=0; int x,y; long long lcm=m[0]*m[1]*m[2]; for (int i=0;i<3;i++) { int M=lcm/m[i]; exgcd(M,m[i],x,y); x=(x%m[i]+m[i])%m[i]; res=(res+(long long)(a[i]*x*M))%lcm; } return res; } int main() { int T=0; while (scanf("%d%d%d%d",&a[0],&a[1],&a[2],&d)==4 && d!=-1) { long long tmp=work(); tmp=(tmp-d+21252)%21252; if (tmp==0) tmp=21252; printf("Case %d: the next triple peak occurs in %d days. ",++T,tmp); } return 0; }
洛谷P4777
扩展CRT测模板裸题,但是注意这题乘法有可能会超出long long。要不用__int128,要不用快速乘。
#include<bits/stdc++.h> #define LL __int128 using namespace std; const int N=1e6+10; int n; LL m[N],a[N]; LL exgcd(LL a,LL b,LL &x,LL &y) { if (b==0) { x=1; y=0; return a; } else { LL tmp=exgcd(b,a%b,y,x); y-=x*(a/b); return tmp; } } long long work() { LL lcm=m[1],X=a[1],t,x,y; for (int i=2;i<=n;i++) { LL b=(a[i]-X%m[i]+m[i])%m[i]; LL d=exgcd(lcm,m[i],x,y); //解这个方程出来t的特解x t=(b/d)*x if (b%d) return -1; t=(b/d)*x%m[i]; X=(X+t*lcm); //那么X(k)=X(k-1)+tm lcm=lcm*m[i]/d; X=(X%lcm+lcm)%lcm; } return X; } int main() { scanf("%d",&n); for (int i=1;i<=n;i++) scanf("%lld%lld",&m[i],&a[i]); cout<<work()<<endl; return 0; }
HDU-1573
这题主要是要想明白,其实线性同余方程组是有无数个解的,然后CRT求出来的就是最小解。
这里直接给出结论,CRT和拓展CRT解出来的方程通解都是 x+k*lcm(k€Z)。
所以ans=(N-x)/lcm+1。这里有一个坑点就是因为答案算的是正整数,所以最小解等于0的话就ans-1.
#include<bits/stdc++.h> using namespace std; const int N=20+10; typedef long long LL; LL n,l,m[N],a[N]; LL exgcd(LL a,LL b,LL &x,LL &y) { if (b==0) { x=1; y=0; return a; } else { LL tmp=exgcd(b,a%b,y,x); y-=x*(a/b); return tmp; } } void exCRT() { LL lcm=m[1],X=a[1],t=0,x=0,y=0; for (int i=2;i<=n;i++) { LL b=(a[i]-X%m[i]+m[i])%m[i]; LL d=exgcd(lcm,m[i],x,y); //解这个方程出来t的特解x t=(b/d)*x if (b%d) { puts("0"); return; } //无解 t=(b/d)*x%m[i]; X=(X+t*lcm); //那么X(k)=X(k-1)+tm lcm=lcm*m[i]/d; X=(X%lcm+lcm)%lcm; } if (X>l) { puts("0"); return; } int ans=(l-X)/lcm+1; if (X==0) ans--; printf("%lld ",ans); } int main() { int T; cin>>T; while (T--) { memset(m,0,sizeof(m)); memset(a,0,sizeof(a)); scanf("%lld%lld",&l,&n); for (int i=1;i<=n;i++) scanf("%lld",&m[i]); for (int i=1;i<=n;i++) scanf("%lld",&a[i]); exCRT(); } return 0; }
HDU-1951
这道题是真的好,把许多知识点结合起来考了,一定要做一做想明白。
#include<iostream> #include<cstdio> #include<algorithm> using namespace std; const int MOD=999911659; typedef long long LL; int n,q,cnt=0; LL a[4],m[4],ys[100000],jc[100000],inv[100000]; void get_ys() { cnt=0; for (int i=1;i*i<=n;i++) if (n%i==0) { ys[++cnt]=i; if (i*i!=n) ys[++cnt]=n/i; } sort(ys+1,ys+cnt+1); } LL power(LL x,LL p,LL Mod) { LL res=1; while (p) { if (p&1) res=(res*x)%Mod; p>>=1; x=(x*x)%Mod; } return res; } LL exgcd(LL a,LL b,LL &x,LL &y) { if (b==0) { x=1; y=0; return a; } else { LL tmp=exgcd(b,a%b,y,x); y-=x*(a/b); return tmp; } } LL C(LL a,LL b,LL Mod) { if(a<b) return 0; if(a==b || !b) return 1; return (jc[a]*inv[b]*inv[a-b])%Mod; } LL Lucas(LL a,LL b,LL Mod) { if (!a || !b) return 1; return C(a%Mod,b%Mod,Mod)*Lucas(a/Mod,b/Mod,Mod); } LL slove(LL Mod) { LL t=1; for (int i=1;i<=Mod;i++) { t=(t*i)%Mod; jc[i]=t; inv[i]=power(t,Mod-2,Mod); } LL res=0; for (int i=1;i<=cnt;i++) { res=(res+Lucas(n,ys[i],Mod))%Mod; } return res; } int main() { scanf("%d%d",&n,&q); if (q%MOD==0) { cout<<0; return 0; } get_ys(); a[0]=slove(2); m[0]=2; a[1]=slove(3); m[1]=3; a[2]=slove(4679); m[2]=4679; a[3]=slove(35617); m[3]=35617; LL ans=0,lcm=1,x,y,M; for (int i=0;i<=3;i++) lcm=lcm*m[i]; for (int i=0;i<=3;i++) { M=lcm/m[i]; exgcd(M,m[i],x,y); x=(x%m[i]+m[i])%m[i]; ans=(ans+(LL)(a[i]*x%lcm*M))%lcm; } cout<<power(q,ans,MOD)<<endl; return 0; }