首先日常%rainboy巨佬,在24分钟的时候就过掉了这题,比赛结束后又交了几发我看不懂但比标程快了几十倍的代码,然而比赛的时候我对着那个\((n^2k)\)的\(dp\)研究了半天也找不出什么优化方法,果然还是太菜了TAT。
题意:
给一长度为\(n\)的数组,你要把该数组分割成\(k\)段,每段的\(cost\)定义为该段中所有数最后一次出现的下标减去第一次出现的下标之和,求最小的\(cost\)之和。
分析:
容易想到设\(dp[i][j]\)表示将前\(i\)个数分成\(j\)段的答案,每次在\(i\)后面补上一段\(len\)便可转移到\(dp[i+len][j+1]\),然而这样的转移是\((n^2k)\)的。现在我们只考虑某一段\([len]\)的答案,假设为\(cost\),我们在这一段数后面添一个数\(x\),那么新的这一段\([len+1]\)的答案就变成了\(cost+pos[x]-lastpos[x]\),由此可以发现,每一段的\(cost\)计算实际上就是计算该段中相同的数两两相邻坐标之差的和,例如\([2, 3, 2, 3, 2]\),计算方法为:\((3-1)+(5-3)+(4-2)=(5-1)+(4-2)=6\)。
知道了这个性质有什么用呢,我们考虑以第二维的分割段数\(j\)作为阶段,假设我们已经求出了所有阶段\(j\)的\(dp[i][j]\),现在我们想要求阶段\(j+1\)的\(dp\)值,我们枚举第一维\(i\),考虑当添加一个数\(a_i\),此时我们考虑\(a_i\)的上一个位置\(last[a_i]\),所有\(dp[last[a_i]][j]\)之前的\(dp\)值都应该加上\(a_i\)带来的贡献,也就是\(i-last[a_i]\),因为我们当前阶段\(j+1\)的\(+1\)这一段加在了\(last[a_i]\)后面,那么它的贡献就应该加在\(last[a_i]\)前面,然后当前的\(dp[i][j+1]\)就应该是\(i\)之前的所有更新过的\(dp\)值取\(min\),然后这两个操作用一个线段树就可以维护出来了,当然还要注意当\(i<j\)即数组长度小于分割段数等边界情况和一些细节。
代码:
#include <bits/stdc++.h>
using namespace std;
#define ls i << 1
#define rs i << 1 | 1
#define int long long
const int N = 35010, K = 105, inf = 1e9;
int n, k, a[N], s[N], lst[N], dp[N][K];
struct seg {
int l, r, mi, tag;
inline int mid() {
return l + r >> 1;
}
}node[N << 2];
void push(int i) {
node[i].mi = min(node[ls].mi, node[rs].mi);
}
void pull(int i) {
if (node[i].tag) {
node[ls].mi += node[i].tag;
node[rs].mi += node[i].tag;
node[ls].tag += node[i].tag;
node[rs].tag += node[i].tag;
node[i].tag = 0;
}
}
void build(int l, int r, int i) {
node[i] = {l, r};
node[i].tag = 0;
if (l == r) {
node[i].mi = s[l];
return;
}
int mid = l + r >> 1;
build(l, mid, ls);
build(mid + 1, r, rs);
push(i);
}
void modify(int l, int r, int v, int i) {
if (l <= node[i].l && r >= node[i].r) {
node[i].mi += v;
node[i].tag += v;
return;
}
pull(i);
int mid = node[i].mid();
if (l <= mid) modify(l, r, v, ls);
if (r > mid) modify(l, r, v, rs);
push(i);
}
int query(int l, int r, int i) {
if (l <= node[i].l && r >= node[i].r) {
return node[i].mi;
}
pull(i);
int mid = node[i].mid(), res = inf;
if (l <= mid) res = min(res, query(l, r, ls));
if (r > mid) res = min(res, query(l, r, rs));
return res;
}
main() {
scanf("%lld%lld",&n,&k);
map<int, int> pos;
// 求每个位置的last
for (int i = 1; i <= n; i++) {
scanf("%lld",&a[i]);
lst[i] = pos[a[i]];
pos[a[i]] = i;
}
// 初始化只有一段的情况
for (int i = 1; i <= n; i++) {
if (lst[i]) dp[i][1] = dp[i-1][1] + i - lst[i];
else dp[i][1] = dp[i-1][1];
}
// 枚举阶段
for (int nowk = 2; nowk <= k; nowk++) {
for (int i = 1; i <= n; i++) s[i] = dp[i][nowk - 1];
build(1, n, 1);
// 边界情况
for (int i = nowk - 1; i < nowk; i++) dp[i][nowk] = inf;
// 更新转移
for (int i = 1; i <= n; i++) {
if (lst[i] > 1 && nowk - 1 < lst[i])
modify(nowk - 1, lst[i] - 1, i - lst[i], 1);
dp[i][nowk] = query(nowk - 1, i, 1);
}
}
printf("%lld\n",dp[n][k]);
return 0;
}