题目描述
有(n)个结点,第(i)个结点的权值为(i)。
你需要对它们进行一些操作并维护一些信息,因此,你需要对它们建立一棵二叉搜索树。在整个操作过程中,第(i)个点需要被操作 (x_i) 次,每次你需要从根结点一路走到第 (i) 个点,耗时为经过的结点数。最小化你的总耗时。
输入格式
第一行一个整数n,第二行n个整数x1~xn。
输出格式
一行一个整数表示答案。
样例
样例输入
5
8 2 1 4 3
样例输出
35
数据范围与提示
对于10%的数据,(n leq 10)。
对于40%的数据,(n leq 300)。
对于70%的数据,(n leq 2000)。
对于100%的数据,(n leq 5000),(1 leq x_i leq 10^9) 。
提示:二叉搜索树或者是一棵空树,或者是具有下列性质的二叉树:若它的左子树不空,则左子树上所有结点的值均小于它的根结点的值;若它的右子树不空,则右子树上所有结点的值均大于它的根结点的值;它的左、右子树也分别为二叉搜索树。
简单解释
不要被题目给吓到,这题并不用BST。
因为区间内的节点的权值是单增且连续的。
所以我们考虑从区间内单独拎某一个节点出来作为根,他左边的点一定都会是他的左子树,右边的一定会是他的右子树。
从这里直接想到区间DP可能有点难,所以我们可以先从记忆化DFS的方向来考虑。
我们枚举根 (k),把整个区间分成左右两半,相当于是把 (k) 的左边和右边的节点的深度都增加了1,对于答案来说,答案会增加 (sum[k - 1] - sum[l - 1] + sum[r] - sum[k + 1] + x[k]),最后的 (x[k]) 是根自己。((sum)为 (x_i)前缀和)
化简一下就是 (sum[r] - sum[l - 1]),是不是很清新。
然后我们就可以愉快地将一个大问题分成两个子问题,再继续愉快地DFS,看起来没毛病对不对。
然而问题是在DFS的过程中我们并不能知道前面的断点依次是多少,深度也就无从得知,自然回溯的时候会出问题。如果记录一下的话就成了纯粹的爆搜。
大概是这样解释的,具体细节咱也解释不太清 (毕竟考场上我连爆搜都没想到)
所以该怎么办呢,考虑一下:区间、断点,一定会有神犇 (因为为同机房就有一个) 能联想到区间DP的四边形不等式优化...于是这题就可以用区间DP来做了。
定义一下 (f[l][r]) 为 (l) 到 (r) 的区间的最小答案。
(g[l][r]) 为 区间 ([l,r]) 的最优断点,不理解请自行学习四边形不等式优化。(只是这个人太菜不会讲而已)
于是结合上面关于DFS的思考,有了转移方程: (f[l][r] = min(f[l][r], f[l][k - 1] + f[k + 1][r] + sum[r] - sum[l-1]))
其中 (k) 为我们枚举的断点,根据四边形不等式可以判断最优点一定在 (g[l][r-1]) 到 (g[l+1][r]) 之间。
然后我们就可以愉快地DP了。
然后当你愉快地打出区间DP的板子,发现T飞了。
再然后,这是为什么呢。因为我们一般的区间DP都是第一维枚举长度,第二维枚举左端点 (l),这样一般是没问题的,复杂度也是稳稳的 (n^2),但他就是T了。
这涉及到二维数组的存储和随机访问的效率问题,感性理解一下就是二维数组在内存里是一行一行来存储的,如果我们先枚举长度再枚举端点,那么每相邻两次的 (l) 跟 (l) 、(r) 跟 (r) 是不连续的,所以在枚举时就会在一行行内存里跳来跳去,用屁股想也知道这样会慢。
但是一般只要算法的瓶颈复杂度足够了,这样小常数不算什么。然而,有一种生物叫做毒瘤出题人...
所以就有了下面代码中的写法,(l) 倒序枚举, (r) 正序枚举,既可以保证正确性 (正确性应该不需要我证吧...),又可以保证在访问内存时 (l) 固定,(r) 也是连续的,这样就只会在一行当中访问,自然会快一些。
当然别的区间DP也是可以这么写的。
#include <bits/stdc++.h>
using namespace std;
const int maxn = 5005;
char buf[1 << 20], *p1 = buf, *p2 = buf;
char getc() {
if(p1 == p2) {
p1 = buf, p2 = buf + fread(buf, 1, 1 << 20, stdin);
if(p1 == p2) return EOF;
}
return *p1++;
}
int read() {
int s = 0, w = 1;
char c = getc();
while(c < '0' || c > '9') {if(c == '-') w = -1; c = getc();}
while(c >= '0' && c <= '9') s = s * 10 + c - '0', c = getc();
return s * w;
}
int n;
long long sum[maxn], f[maxn][maxn];
int g[maxn][maxn];
int main() {
n = read();
for(int i = 1; i <= n; i++) f[i][i] = read(), sum[i] = sum[i - 1] + f[i][i], g[i][i] = i;
for(int l = n - 1; l >= 1; l--) {
for(int r = l + 1; r <= n; r++) {
f[l][r] = LLONG_MAX; //注意是LL,INT不够大
for(int k = g[l][r - 1]; k <= g[l + 1][r]; k++) {
if(f[l][r] > f[l][k - 1] + f[k + 1][r] + sum[r] - sum[l - 1]) {
f[l][r] = f[l][k - 1] + f[k + 1][r] + sum[r] - sum[l - 1];
g[l][r] = k;
}
}
}
}
printf("%lld
", f[1][n]);
return 0;
}