题意:你有n块木头,每块木头有一个高h和宽w,你可以把高度相同的木头合并成一块木头。你可以选择一些木头消去它们的一部分,浪费的部分是 消去部分的高度 * 木头的宽度,问把n块木头变成恰好m块木头至少要浪费多少木料?
思路:把木头从高到第排序,设dp[i][j]为前i块木头合并成了j块木头的最小花费。因为从大到小排序,所以合并后最后一块木头的高度一定是合并前的第i块木头的高度。那么,容易得出dp转移方程:dp[i][j] = min(dp[k][j - 1] + cal(k, i)),其中cal(k, i)为把第k + 1块木头到第i块木头的高度变成一样的花费。直接转移O(n * n * m),需要优化。
1:分治优化:设op[i][j]为向dp[i][j]转移的状态中最优值中最小的k,若op[i][j] <= op[i + 1][j], 那么便可以进行分治优化dp。对于此题,dp[x][j] + cal(x, i)和dp[y][j] + cal(y, j)(x < y)cal(x, i)和cal(y, i)有重合部分,所以有op[i][j] <= op[i + 1][j], 通过分治的过程可以缩小转移的范围,复杂度O(n * logn * m)。
代码:
#include <bits/stdc++.h> #define LL long long #define pll pair<LL, LL> #define INF 1e18 using namespace std; const int maxn = 5010; const int maxm = 2010; pll a[maxn]; int n, m; LL f[maxm][maxn], w[maxn], sum[maxn]; LL cal(LL l, LL r, LL h) { return sum[r] - sum[l] - h * (w[r] - w[l]); } void solve(int x, int l, int r, int opl, int opr) { if(l > r) return; int mid = (l + r) >> 1; pll ans = make_pair(INF, INF); for (int i = opl; i < mid && i <= opr; i++) { ans = min(ans, make_pair(f[x - 1][i] + cal(i, mid, a[mid].first), (LL)i)); } f[x][mid] = ans.first; LL opt = ans.second; solve(x, l, mid - 1, opl, opt); solve(x, mid + 1, r, opt, opr); } int main() { scanf("%d%d", &n, &m); for (int i = 1; i <= n; i++) { scanf("%d%d", &a[i].second, &a[i].first); } sort(a + 1, a + 1 + n); reverse(a + 1, a + 1 + n); for (int i = 1; i <= n; i++) { w[i] = w[i - 1] + a[i].second; sum[i] = sum[i - 1] + a[i].second * a[i].first; } for (int i = 1; i <= n; i++) f[1][i] = cal(0, i, a[i].first); for (int i = 2; i <= m; i++) { solve(i, 1, n, 0, n); } printf("%lld ", f[m][n]); }
思路2:斜率优化,把cal(k, i)式子列出来,用单调队列维护下凸包。场上没注意到斜率乘积会爆long long,非常可惜QAQ
#include <bits/stdc++.h> #define LL long long #define pll pair<LL, LL> using namespace std; const int maxn = 5010; const int maxm = 2010; pll a[maxn]; int n, m; LL f[maxn][maxn], w[maxn], sum[maxn]; int q[maxn][maxm], l[maxm], r[maxm]; LL cal(LL x, LL y) { return f[x][y] - sum[x]; } void update(int x, int y) { LL h = -a[x].first; while(l[y] < r[y]) { int p1 = q[y][l[y]], p2 = q[y][l[y] + 1]; __int128 t = (__int128)cal(p2, y) - cal(p1, y); __int128 t1 = (__int128)h * (w[p2] - w[p1]); if(t <= t1) { l[y]++; continue; } else { break; } } int k = q[y][l[y]]; f[x][y + 1] = f[k][y] + sum[x] - sum[k] + h * (w[x] - w[k]); while(l[y] < r[y]) { int p1 = q[y][r[y] - 1], p2 = q[y][r[y]]; __int128 t = (__int128)(cal(p2, y) - cal(p1, y)) * (w[x] - w[p2]); __int128 t1 = (__int128)(cal(x, y) - cal(p2, y)) * (w[p2] - w[p1]); if(t >= t1) { r[y]--; continue; } else { break; } } q[y][++r[y]] = x; } int main() { scanf("%d%d", &n, &m); for (int i = 1; i <= n; i++) { scanf("%d%d", &a[i].second, &a[i].first); } sort(a + 1, a + 1 + n); reverse(a + 1, a + 1 + n); for (int i = 1; i <= m; i++) { l[i] = 1, r[i] = 1; q[i][1] = 0; } for (int i = 1; i <= n; i++) { w[i] = w[i - 1] + a[i].second; sum[i] = sum[i - 1] + a[i].second * a[i].first; } for (int i = 1; i <= n; i++) { f[i][1] = sum[i] - a[i].first * w[i]; for (int j = 2; j <= m; j++) { update(i, j - 1); } } printf("%lld ", f[n][m]); }