题:https://ac.nowcoder.com/acm/problem/13884
题意:给定n,m,k,代表n个连续的格子,要求用m种颜色去涂色,要求图上的色的种类恰好为k。(n,m<=1e9,k<=1e6)
分析:恰好的字眼,说明要转成,不多于或至少的答案去做,再用容斥去求;
假设用不超过k种颜色去涂格子(相邻不能同色)的方案为xk =k*(k-1)n-1,表示第一个格子有k种颜色可选,接下来的格子由于相邻不能同色,所以每个格子只有k-1的选择;
但当前只是不超过,若要正好为k种颜色,那么根据容斥原理,ans=xk-C(k,k-1)*xk-1+C(k,k-2)*xk-2+.....+(-1)k*x1 。这部分只要预处理出阶乘和阶乘逆元即可累加处理。
最后只要在m种颜色中取k种颜色,那么只要ans再乘上C(m,k)即可,由于m很大,但是k在合理范围内,所以利用k来进行求值即可。
C:
#include<bits/stdc++.h> using namespace std; #define pb push_back #define MP make_pair typedef long long ll; const int M=1e6+15; const int inf=0x3f3f3f3f; const ll INF=1e18; const int mod=1e9+7; ll fac[M],facinv[M]; ll ksm(ll x,ll y){ ll t=1; while(y){ if(y&1) t=(t*x)%mod; x=(x*x)%mod; y>>=1; } return t; } ll inv(ll x){ return ksm(x,mod-2)%mod; } ll C(ll x,ll y){ if(y>x||x<0||y<0) return 0; if(y==0||x==y) return 1; return fac[x]*facinv[y]%mod*facinv[x-y]%mod; } void init(){ fac[0]=1; for(int i=1;i<=M;i++) fac[i]=1ll*fac[i-1]*i%mod; facinv[M-1]=ksm(fac[M-1],mod-2); for(int i=M-2;i>=0;i--) facinv[i]=facinv[i+1]*(i+1)%mod; } int main(){ init(); int T; scanf("%d",&T); while(T--){ ll n,m,k; scanf("%lld%lld%lld",&n,&m,&k); ll ans=0; ///容斥 for(int dis=1,i=0;i<k;i++,dis=-dis){ ans=(ans+dis*C(k,k-i)*(k-i)%mod*ksm(k-i-1,n-1)%mod+mod)%mod; } ///乘上C(m,k); for(int i=0;i<k;i++) ans=ans*(m-i)%mod*inv(k-i)%mod; printf("%lld ",ans); } return 0; }
java:
import java.util.Scanner; public class Main { final static int M=(int) (1e6+6); final static int mod=(int) (1e9+7); static long fac[] = new long[M]; static long facinv[] = new long[M]; public static long ksm(long x,long y) { long t=1; while(y!=0) { if((y&1)==1) t=(t*x)%mod; x=(x*x)%mod; y>>=1; } return t; } public static long inv(long x) { return ksm(x,mod-2); } public static void init() { fac[0]=1; for(int i=1;i<M;i++) fac[i]=(long)fac[i-1]*i%mod; facinv[M-1]=ksm(fac[M-1],mod-2); for(int i=M-2;i>=0;i--) { facinv[i]=facinv[i+1]*(i+1)%mod; } } public static long C(long x,long y) { if(y>x||x<0||y<0) return 0; if(y==0||x==0) return 1; return fac[(int) x]*facinv[(int) (x-y)]%mod*facinv[(int) y]%mod; } public static void main(String[] args) { Scanner sc = new Scanner(System.in); init(); for(int T=sc.nextInt();T>0;T--) { long n=sc.nextLong(); long m=sc.nextLong(); long k=sc.nextLong(); long ans=0; for(int i=0,dis=1;i<k;i++,dis=-dis) { ans=(ans+dis*C(k,k-i)*(k-i)%mod*ksm(k-i-1,n-1)%mod+mod)%mod; } for(int i=0;i<k;i++) ans=ans*(m-i)%mod*inv(k-i)%mod; System.out.printf("%d",ans); System.out.println(); } } }