某次钟神讲课后来补了这道题。
Description
简化题意:给定一个 (n),要求选出 (k) 个组合数 (C_a^b),是他们的和最大。其中 (a,b) 必须满足 (0 le a le b le n) 。
Solution
先考虑如何找到最大的 (k) 个组合数。
我们知道组合数的递推式是 (C_n^m = C_{n-1}^{m-1} + C_{n-1}^{m}),并且这个式子就是杨辉三角的形式。
尝试把前几行写出来。
1
1 1
1 2 1
1 3 3 1
1 4 6 4 1
1 5 10 10 5 1
1 6 15 20 15 6 1
观察一下发现对于第 (i) 行,第 (frac{i}{2}) 个数是最大的;对于每一列,越靠下越大。并且最大的是 (C_n^{frac{n}{2}})
那么考虑用广搜的思想,先把最大的 (C_n^{frac{n}{2}}) 加进去,然后不断向周围三个方向扩展,取出前 (k) 大即可。
扩展的时候注意判断是否在界内。
注意判断该位置是否入队过,这里使用 map 标记判断。
但是我们忽视一个问题:我们并不能比较 (C_{a}^{b}) 的大小。因为我们求不出来,取模的话就会使得大小无法比较。如何解决?
根据钟神的思路想到高中的一个知识点:
[log (x imes y) = log x + log y
]
[log (frac{x}{y}) = log x - log y
]
那么我们的 (C_{a}^{b}) 是不是也可以表示了?
设 (Log_i = sum_{j=1}^{i} log j),有:
[log C_a^b = Log_a - Log_b - Log_{a-b}
]
众所周知, (f(x) = log x) 是单调递增函数,所以加入优队时比较 (log C_a^b) 的大小就好了。
Code
/*
Work by: Suzt_ilymics
Problem: P4370 [Code+#4]组合数问题2
Knowledge: 优先队列,log部分知识
Time: O(能过)
*/
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
#include<cmath>
#include<map>
#define LL long long
#define orz cout<<"lkp AK IOI!"<<endl
using namespace std;
const int MAXN = 1e6+5;
const int INF = 1e9+7;
const int mod = 1e9+7;
int dx[] = {0, 0, -1, 0};
int dy[] = {0, -1, 0, 1};
struct node {
int n, m; double val;
bool operator < (const node &b) const { return val < b.val; }
};
int n, k; LL ans = 0;
double Log[MAXN]; // 开 double 防止精度问题
int fac[MAXN], inv[MAXN];
priority_queue<node> q;
map<int, bool> Map[MAXN];
int read(){
int s = 0, f = 0;
char ch = getchar();
while(!isdigit(ch)) f |= (ch == '-'), ch = getchar();
while(isdigit(ch)) s = (s << 1) + (s << 3) + ch - '0' , ch = getchar();
return f ? -s : s;
}
bool Check(int x, int y) { return x < 0 || y < 0 || y > x; }
LL calc(int n, int m) { return 1ll * fac[n] * inv[m] % mod * inv[n - m] % mod; }
void Init(int limit) {
fac[0] = 1, inv[0] = 1;
fac[1] = 1, inv[1] = 1;
for(int i = 1; i <= limit; ++i) Log[i] = Log[i - 1] + log(i); //预处理 Log
for(int i = 2; i <= limit; ++i) fac[i] = 1ll * fac[i - 1] * i % mod, inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod; // 预处理阶乘
for(int i = 2; i <= limit; ++i) inv[i] = 1ll * inv[i - 1] * inv[i] % mod; // 预处理阶乘的逆元
}
void Solve() {
q.push((node){n, n/2, Log[n] - Log[n/2] - Log[n - n/2]}); // 把最大的点加进去
Map[n][n/2] = true;
for(int i = 1; i <= k; ++i) {
node u = q.top(); q.pop();
// cout<<u.n<<" "<<u.m<<"
";
// cout<<calc(u.n, u.m)<<"
";
ans = (ans + calc(u.n, u.m)) % mod;
for(int j = 1; j <= 3; ++j) { // 枚举三个方向
int dn = u.n + dx[j], dm = u.m + dy[j];
if(Check(dn, dm) || Map[dn][dm]) continue; // 判断是否出界及是否标记过
q.push((node){dn, dm, Log[dn] - Log[dm] - Log[dn - dm]});
Map[dn][dm] = true;
}
}
}
int main()
{
Init(1000000);
n = read(), k = read();
Solve();
printf("%lld", ans);
return 0;
}