题目描述
小Z是一个小有名气的钢琴家,最近C博士送给了小Z一架超级钢琴,小Z希望能够用这架钢琴创作出世界上最美妙的音乐。
这架超级钢琴可以弹奏出n个音符,编号为1至n。第i个音符的美妙度为Ai,其中Ai可正可负。
一个“超级和弦”由若干个编号连续的音符组成,包含的音符个数不少于L且不多于R。我们定义超级和弦的美妙度为其包含的所有音符的美妙度之和。两个超级和弦被认为是相同的,当且仅当这两个超级和弦所包含的音符集合是相同的。
小Z决定创作一首由k个超级和弦组成的乐曲,为了使得乐曲更加动听,小Z要求该乐曲由k个不同的超级和弦组成。我们定义一首乐曲的美妙度为其所包含的所有超级和弦的美妙度之和。小Z想知道他能够创作出来的乐曲美妙度最大值是多少。
输入输出格式
输入格式:
输入第一行包含四个正整数n, k, L, R。其中n为音符的个数,k为乐曲所包含的超级和弦个数,L和R分别是超级和弦所包含音符个数的下限和上限。
接下来n行,每行包含一个整数Ai,表示按编号从小到大每个音符的美妙度。
输出格式:
输出只有一个整数,表示乐曲美妙度的最大值。
输入输出样例
输入样例#1:
4 3 2 3
3
2
-6
8
输出样例#1:
11
说明
共有5种不同的超级和弦:
- 音符1 ~ 2,美妙度为3 + 2 = 5
- 音符2 ~ 3,美妙度为2 + (-6) = -4
- 音符3 ~ 4,美妙度为(-6) + 8 = 2
- 音符1 ~ 3,美妙度为3 + 2 + (-6) = -1
- 音符2 ~ 4,美妙度为2 + (-6) + 8 = 4
最优方案为:乐曲由和弦1,和弦3,和弦5组成,美妙度为5 + 2 + 4 = 11。
题解
堆+st表
我们定义一个三元组((s, l, r)) 表示以s为左端点 右端点在(l-r)这段区间内区间和最大
维护一个前缀和
贪心地想,既然(s)已经固定那么对于右端点在(l-r)这段区间内,我们要取前缀和最大的那个值,st搞一下就好了
定义(t)为(l-r)区间前缀和最大的位置
我们维护一个堆
每次取出价值最大
然后把((s, l, t-1) (s, t+1, r))放进去
注意特判一下(l=t, r=t)的情况
Code
#include<bits/stdc++.h>
#define LL long long
#define RG register
using namespace std;
inline int gi() {
int f = 1, s = 0;
char c = getchar();
while (c != '-' && (c < '0' || c > '9')) c = getchar();
if (c == '-') f = -1, c = getchar();
while (c >= '0' && c <= '9') s = s*10+c-'0', c = getchar();
return f == 1 ? s : -s;
}
const int N = 500010;
int mx[N][21], a[N], n, k, L, R;
LL sum[N];
void init() {
for (int i = 1; i <= n; i++)
mx[i][0] = i;
for (int i = 1; (1<<i) <= n; i++)
for (int j = 1; j + (1<<i) - 1 <= n; j++) {
int x = mx[j][i-1], y = mx[j+(1<<(i-1))][i-1];
mx[j][i] = sum[x] > sum[y] ? x : y;
}
return ;
}
inline int getmax(int l, int r) {
int k = log2(r-l+1), x = mx[l][k], y = mx[r-(1<<k)+1][k];
return sum[x] > sum[y] ? x : y;
}
struct node {
int s, l, r, t;
bool operator <(node z) const {
return (sum[t]-sum[s-1]) < (sum[z.t]-sum[z.s-1]);
}
};
priority_queue<node> q;
int main() {
n = gi(), k = gi(), L = gi(), R = gi();
for (int i = 1; i <= n; i++) {
a[i] = gi();
sum[i] = sum[i-1]+a[i];
}
init();
for (int i = 1; i+L-1 <= n; i++)
q.push((node) {i, i+L-1, min(n, i+R-1), getmax(i+L-1, min(n, i+R-1))});
LL ans = 0;
for (int i = 1; i <= k; i++) {
int l = q.top().l, r = q.top().r, s = q.top().s, t = q.top().t;
q.pop();
ans += sum[t]-sum[s-1];
if (t != l)
q.push((node) {s, l, t-1, getmax(l, t-1)});
if (t != r)
q.push((node) {s, t+1, r, getmax(t+1, r)});
}
printf("%lld
", ans);
return 0;
}