HDU 6333 莫队分块 + 逆元打表求组合数
题解
在n个苹果中取最多m个苹果,问有多少中取法,即求(sum_{i=0}^m C_n^i),但是样例个数和n的范围最高到1e5, 一个一个求铁定会超时。
假设(S(n, m) = sum_{i=0}^m C_n^i = C_n^0 + C_n^1+C_n^2+...+C_n^m),则可以得到
((1) S(n, m+1) = C_n^0 + C_n1+C_n2+...+C_n^m + C_n^{m+1}=S(n, m) + C_n^{m+1})
((2) 2 * S(n, m) = C_n^0 + C_n^0 +C_n^1+C_n^1+C_n^2+C_n^2+...+C_n^m+C_n^m =C_n^0+(C_n^0+C_n^1)+(C_n^1+C_n^2)+(C_n^2+C_n^3) + ... +(C_n^{m-1}+C_n^{m})+C_n^m=C_{n+1}^0+C_{n+1}^1 +C_{n+1}^2 +C_{n+1}^3+...+C_{n+1}^m+C_n^m(杨辉三角)=S(n+1, m))
即(S(n+1,m)=2*S(n,m)-C_n^m)
由公式(1)(2)可得
((3) S(n,m-1)=S(n,m)-C_n^m)
((4) S(n-1,m)=(S(n,m)+C_{n-1}^m) / 2)
也就是说,我们可以通过(S(n,m))和组合数(C_n^m)很容易的求出(S(n, m+1), S(n+1, m), S(n,m-1),S(n-1,m)),于是我们就可以离线后,用莫队分块算法求解,将(n)分为(sqrt(maxn))块,并且每一块内(m)以升序排列,莫队的具体细节参考百度。
还有一个问题是如何求(C_n^m),考虑组合数的求解公式(C_n^m=frac{n!}{m!(n-m)!}),最后结果对(MOD=1e9+7)取余,那么(C_n^m \% MOD= fact(n)*inv(m)*inv(n-m)\%MOD),其中(fact(x))表示x的阶乘,(inv(x))表示(x)阶乘的逆元。我们可以用(O(n))的算法打表求出(frac[])和(inv[])那么就可以在(O(1))的复杂度内求出(C_n^m)
代码
#include <bits/stdc++.h>
using namespace std;
#define rep(i, a, b) for (int i = (a); i <= (b); ++i)
#define red(i, a, b) for (int i = (a); i >= (b); --i)
#define clr( x , y ) memset(x,y,sizeof(x))
typedef long long ll;
const int maxn = 1e5 + 5;
const int MOD = 1e9 + 7;
ll l, r, m;
ll sum, ans[maxn], f[maxn + 1], inv[maxn + 1];
struct nodd
{
ll l , r , n , k;
} b[maxn];
// 拓展欧几里得
ll exgcd(ll m, ll n, ll &x, ll &y){
ll x1, y1, x0, y0;
x0 = 1; y0 = 0;
x1 = 0; y1 = 1;
ll r = (m % n + n) % n;
ll q = (m - r) / n;
x = 0; y = 1;
while(r){
x = x0 - q * x1; y = y0 - q * y1; x0 = x1 ;y0 = y1;
x1 = x; y1 = y;
m = n; n = r; r = m % n;
q = (m - r) / n;
}
return n;
}
void cal(){//阶乘及其逆元打表
f[0]=1;
for(int i=1;i<=maxn;i++){
f[i]=f[i-1]*i%MOD;
}
ll x,y;
exgcd(f[maxn],MOD,x,y);//先求出f[N]的逆元,再循环求出f[1~N-1]的逆元
inv[maxn]=(x%MOD+MOD)%MOD;
for(int i=maxn-1;i>=0;i--){
inv[i]=inv[i+1]*(i+1)%MOD;
}
}
// C(x, y);
ll c(ll x, ll y)
{
return (f[x] * inv[y] % MOD * inv[x - y]) % MOD;
}
//同一块内按r的升序进行排序
bool cmp(nodd a,nodd b) { return a.k == b.k ? a.r < b.r : a.k < b.k; }
void init()
{
cal();
/*
for (int i = 0; i < 100; ++i)
{
printf("%lld, %lld
", f[i], inv[i]);
}
*/
scanf("%lld", &m);
rep(i,1,m) scanf("%lld%lld",&b[i].l,&b[i].r);
l = sqrt(maxn);
//nodd的n属性记录其输入顺序, k属性记录其所在块数
rep(i,1,m) b[i].n = i , b[i].k = b[i].l / l;
//同一块内按r的升序进行排序
sort( b + 1 , b + m + 1 , cmp );
}
// S(n,m)=S(n,m-1)+C(n,m)
void incm(ll n, ll m)
{
//printf("c(%lld,%lld)=%lld
", c(n,m));
sum = (sum + c(n, m)) % MOD;
}
// S(n,m)=S(n,m+1)-C(n,m+1)
void decm(ll n, ll m)
{
sum = (sum - c(n, m + 1) + MOD) % MOD;
}
// S(n,m)=(S(n+1,m)+C(n,m) / 2)
void decn(ll n, ll m)
{
sum = ((sum + c(n, m)) * inv[2]) % MOD;
}
// S(n,m)=(S(n-1,m) * 2 - C(n-1, m))
void incn(ll n, ll m)
{
sum = (sum * 2 - c(n - 1, m) + MOD) % MOD;
}
void work()
{
rep(i,1,m) {
sum = 0;
l = b[i].l; r = b[i].r;
//printf("%lld, %lld
", l, r);
rep(j, 0, r) incm(l, j);// printf("%lld
", sum);
ans[ b[i].n ] = sum;
while ( i < m && b[i].k == b[i+1].k ) {
i++;
while ( l < b[i].l ) incn(++l, r);
while ( l > b[i].l ) decn(--l, r);
while ( r < b[i].r ) incm(l, ++r);
ans[ b[i].n ] = sum;
}
}
rep(i,1,m) {
printf("%lld
", ans[i]);
}
}
int main()
{
init();
work();
return 0;
}