题目链接
https://www.nowcoder.com/acm/contest/212/F
题解
我们先考虑如果已知了数组 ({a_i}) 和 ({b_i}),如何判断其是否合法。
很显然我们可以使用网络流,具体建图如下:从源点 (s) 向每一个行对应的结点连边,容量为 (a_i);每一个行对应的结点向每一个列对应的结点连边,容量为 (1);每一个列对应的结点向汇点 (t) 连边,容量为 (b_i)。那么 ({a_i}) 与 ({b_i}) 是合法的当且仅当最大流等于 (sum b_i)(也等于 (sum a_i))。
而该网络的最大流长什么样呢?如果光从流的角度想不好入手。由于最大流等于最小割,因此我们转化一下,转化为考虑该网络的割。
现在,我们把整幅图的边分成三类,第一类为源点 (s) 向每一个行对应的结点连的边,第二类为每一个行对应的结点向每一个列对应的结点连的边,第三类为每一个列对应的结点向汇点 (t) 连的边。如果我们已知第一类边有 (i) 条被割掉,第三类边有 (j) 条被割掉,由于我们只能再割第二类边,因此显然剩下的 (n - i) 个行对应的结点与 (m - j) 个列对应的结点之间的边必须割掉,因此第二类边的割的贡献值为 ((n - i)(m - j))。由于第二类边的贡献是确定的,为了得到最小割,第一类边与第三类边的割的贡献也应尽量小,那么我们直接贪心地选取容量最小的边即可。
我们归纳一下上面的内容:我们将 ({a_i}) 与 ({b_i}) 排序,并记排序后 ({a_i}) 与 ({b_i}) 的前缀和为 (sa, sb),那么对于这样的网络图,最小割值为:$$minlimits_{0 leq i leq n, 0 leq j leq m}{sa_i + sb_j + (n - i)(m - j)}$$
这样,原问题就得到了转化,我们只需要求有多少个数组 ({a_i}) 满足 (sum a_i = sum b_i),且对于任意的 (i),都有:(forall j, sa_i + sb_j + (n - i)(m - j) geq sb_m)。
我们可以花 (O(nm)) 的时间通过上面的式子预处理得到每一个 (sa_i) 的最小值。这之后我们就可以dp了,设 (f_{i, j, k}) 表示已经考虑至 (a_i),且 (a_i) 的值不超过 (j),同时 (sa_i) 为 (k) 的方案数。转移枚举有多少个位置的值为 (a_i),乘上对应的组合数即可。注意在转移时要判断状态的合法性,即 (k) 必须不小于预处理出的 (sa_i) 的最小值。
时间复杂度为 (O(n^3m^2) = O(n^5))。
至此,和这道题有关的内容已经结束。不过上面提到的关于一类特殊的网络图的最大流求法引人思考。在这里,我们作简要的说明与归纳。
我们先再来归纳一下该类网络图的简单特征:
- 我们可以将该网络的所有结点划分为 (4) 个互不相交的集合,分别为源点 (s)、汇点 (t)、点集 (A) 与点集 (B)。这四个集合包含了该网络的所有结点。
- 源点 (s) 向点集 (A) 中的每一个结点有连边,源点 (s) 连向点集 (A) 中的第 (i) 个结点的边的容量为 (a_i)。
- 点集 (B) 中的每一个结点向汇点 (t) 有连边,点集 (B) 中的第 (i) 个结点连向汇点 (t) 的边的容量为 (b_i)。
- 点集 (A) 中的每一个结点向点集 (B) 中的每一个结点有连边,每条边的容量均为 (1)。
- 若点集 (A) 包含 (n) 个结点,点集 (B) 包含 (m) 个结点,不难发现,整个网络共包含 (n + m + 2) 个结点,(nm + n + m) 条边。
需要注意的是,上述性质中提到的连边均为单向的。
我们的任务,就是在已知 (n, m, {a_i}, {b_i}) 的情况下,快速求出该网络图从源点 (s) 到汇点 (t) 的最大流(最小割)。
首先,通过上面给出的最小割值的式子:(minlimits_{0 leq i leq n, 0 leq j leq m}{sa_i + sb_j + (n - i)(m - j)}),我们已经可以排序后在 (O(nm)) 的时间内解决该问题。但还有没有更快的呢?答案是有的。
对于一个 (i),我们要找到一个 (j),使得 (sa_i + sb_j + (n - i)(m - j)) 最小。而这个式子显然是可以使用斜率优化的。我们把式子拆开,变为:(sa_i +sb_j + nm - im - jn + ij)。当 (i) 确定时,(sa_i, nm, im) 显然都是定值,因此,我们希望找到一个 (j),使得 (sb_j - jn + ij) 最小。
设 (f_i = sb_j - jn + ij),那么有 (-ij + f_i = sb_j - jn),那么这显然是一个斜率为 (-i) 的一次函数,经过的点为 ((j, sb_j - jn))。由于从小到大依次枚举 (i) 时,斜率 (-i) 是单调的,因此我们可以先花 (O(m)) 的时间用所有的 (j) 构建出凸包之后,在 (O(n)) 的时间内使用单调栈(由于这里插入和弹出操作是分开的,因此可以不用单调队列)求出每一个 (f_i) 的最小值。这样,我们就可以排序后在 (O(n + m)) 的时间内解决该问题。
不过我的方法好像稍微有点复杂,虽然代码还是很好写......
代码
Wannafly挑战赛26-F. msc的棋盘 代码如下:
#include<bits/stdc++.h>
using namespace std;
template<typename T> inline bool checkMin(T& a, const T& b) {
return a > b ? a = b, true : false;
}
const int N = 55 + 10, mod = 1e9 + 7;
inline void add(int& x, int y) {
x += y;
if (x >= mod) {
x -= mod;
}
}
int n, m, a[N], b[N], f[N][N][N * N], binom[N][N], s[N][N][N * N];
void init(int n) {
binom[0][0] = 1;
for (register int i = 1; i <= n; ++i) {
for (register int j = 0; j <= i; ++j) {
binom[i][j] = (binom[i - 1][j] + (!j ? 0 : binom[i - 1][j - 1])) % mod;
}
}
}
int main() {
scanf("%d%d", &n, &m);
init(n);
for (register int i = 1; i <= m; ++i) {
scanf("%d", &b[i]);
}
sort(b + 1, b + 1 + m);
for (register int i = 1; i <= m; ++i) {
b[i] += b[i - 1];
}
for (register int i = 1; i <= n; ++i) {
int minres = n * m + 1;
for (register int j = 0; j <= m; ++j) {
checkMin(minres, b[j] + (n - i) * (m - j));
}
a[i] = b[m] - minres;
}
for (register int i = 0; i <= n; ++i) {
s[i][0][0] = f[i][0][0] = !a[i] ? binom[n][i] : 0;
for (register int j = 1; j <= m; ++j) {
s[i][j][0] = (s[i][j - 1][0] + f[i][j][0]) % mod;
}
}
for (register int i = 1; i <= n; ++i) {
for (register int j = 1; j <= m; ++j) {
for (register int k = a[i]; k <= b[m]; ++k) {
for (register int a = 0; a < i; ++a) {
if (k - j * (i - a) >= 0) {
add(f[i][j][k], 1ll * s[a][j - 1][k - j * (i - a)] * binom[n - a][i - a] % mod);
}
}
}
}
for (register int j = 1; j <= m; ++j) {
for (register int k = a[i]; k <= b[m]; ++k) {
s[i][j][k] = (s[i][j - 1][k] + f[i][j][k]) % mod;
}
}
}
printf("%d
", s[n][m][b[m]]);
return 0;
}
上面提到的一类特殊的网络流问题的代码如下(代码中,输入依次为 (n, m, {a_i}, {b_i}),输出为最大流的流量值):
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
template<typename T> inline bool checkMin(T& a, const T& b) {
return a > b ? a = b, true : false;
}
const int N = 1e6 + 10;
int n, m, q[N], t;
ll a[N], b[N];
// (i, b[i] - ni)
inline ll x_val(int i) {
return i;
}
inline ll y_val(int i) {
return b[i] - 1ll * n * i;
}
inline double slope(int i, int j) {
return 1.0 * (y_val(j) - y_val(i)) / (x_val(j) - x_val(i));
}
int main() {
scanf("%d%d", &n, &m);
for (register int i = 1; i <= n; ++i) {
scanf("%lld", &a[i]);
}
for (register int i = 1; i <= m; ++i) {
scanf("%lld", &b[i]);
}
sort(a + 1, a + 1 + n);
sort(b + 1, b + 1 + m);
for (register int i = 1; i <= n; ++i) {
a[i] += a[i - 1];
}
for (register int i = 1; i <= m; ++i) {
b[i] += b[i - 1];
}
q[0] = 0;
for (register int i = 0; i <= m; ++i) {
for (; t > 0 && slope(q[t - 1], q[t]) >= slope(q[t - 1], i); --t);
q[++t] = i;
}
ll ans = 1e18;
for (register int i = 0; i <= n; ++i) {
for (; t && slope(q[t - 1], q[t]) >= -i; --t);
int j = q[t];
checkMin(ans, a[i] + b[j] + 1ll * (n - i) * (m - j));
}
printf("%lld
", ans);
return 0;
}