Description
将一个长度为n的序列分为k段
使得总价值最大一段区间的价值表示为区间内不同数字的个数
(nleq 35000,kleq 50,1leq a_ileq n)
Solution
定义 (dp[i][j]) 表示前 i 个里面分 j 段的最大收益
一个显然的 dp 方程是 (dp[i][j]=max limits_{1leq p<i} dp[p][j-1]+w(p+1,i))。复杂度 (O(n^2k)),GG。
考虑优化此方程,因为是取 max,容易想到放在线段树上实现。
同时定义 (pre[a[i]]) 表示当前 (a[i]) 这个元素上一次出现的位置是哪里,如果没有出现则是 0 。
难点在于 (w) 数组如何动态快速的求出来,我们外层循环一个 (j) 表示分的段数,发现如果当前扫到 i 这个位置那么 a[i] 的贡献实际上是让 ([pre[a[i]],i]) 这段区间整体加一。可以这么理解,就是当前扫到 i,那么对于所有到 i 截至的区间 ([p,i]),a[i] 这个元素对这些区间有贡献的部分是左端点(in [pre[a[i]],i]) 里的这一段。线段树区间加就好了。也就是说,当前扫到了 i ,那么线段树的叶子节点 p 表示的就是 (w[p,i]) 的值,这也是我们用线段树的意义所在。这样就可以 (O(nlogn)) 求出 w 数组了。同时 dp 数组实时更新即可。
还有一点要注意的是方程是 (dp[p][j-1]+w(p+1,i)) ,也就是说能用来更新答案的是 节点 p 的 dp 值和 p+1 的累加值,有点麻烦,干脆把所有的 dp 值都往左挪一个就行了,也就是叶子节点 p 表示的实际上是 p+1 的值。感觉有点绕。。。
Code
#include<cstdio>
#include<cctype>
#include<cstring>
#define K 55
#define N 35005
#define min(A,B) ((A)<(B)?(A):(B))
#define max(A,B) ((A)>(B)?(A):(B))
#define swap(A,B) ((A)^=(B)^=(A)^=(B))
int n,k;
int f[N];
int val[N];
int pre[N];
int mx[N<<2];
int lazy[N<<2];
int getint(){
int x=0,f=0;char ch=getchar();
while(!isdigit(ch)) f|=ch=='-',ch=getchar();
while(isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
return f?-x:x;
}
void build(int cur,int l,int r){
if(l==r){
mx[cur]=f[l-1];
return;
}
int mid=l+r>>1;
build(cur<<1,l,mid);
build(cur<<1|1,mid+1,r);
mx[cur]=max(mx[cur<<1],mx[cur<<1|1]);
}
void pushdown(int cur){
if(!lazy[cur]) return;
lazy[cur<<1]+=lazy[cur];
lazy[cur<<1|1]+=lazy[cur];
mx[cur<<1]+=lazy[cur];
mx[cur<<1|1]+=lazy[cur];
lazy[cur]=0;
}
void modify(int cur,int l,int r,int ql,int qr){
if(!ql or !qr or ql>qr) return;
if(ql<=l and r<=qr){
mx[cur]++;
lazy[cur]++;
return;
}
pushdown(cur);
int mid=l+r>>1;
if(ql<=mid)
modify(cur<<1,l,mid,ql,qr);
if(mid<qr)
modify(cur<<1|1,mid+1,r,ql,qr);
mx[cur]=max(mx[cur<<1],mx[cur<<1|1]);
}
int query(int cur,int l,int r,int ql,int qr){
if(ql<=l and r<=qr)
return mx[cur];
pushdown(cur);
int mid=l+r>>1,ans=0;
if(ql<=mid){
int p=query(cur<<1,l,mid,ql,qr);
ans=max(ans,p);
}
if(mid<qr){
int p=query(cur<<1|1,mid+1,r,ql,qr);
ans=max(ans,p);
}
return ans;
}
signed main(){
n=getint(),k=getint();
for(int i=1;i<=n;i++)
val[i]=getint();
for(int j=1;j<=k;j++){
memset(mx,0,sizeof mx);
memset(pre,0,sizeof pre);
memset(lazy,0,sizeof lazy);
build(1,1,n);
for(int i=1;i<=n;i++){
modify(1,1,n,pre[val[i]]+1,i);
pre[val[i]]=i;
//if(i<j) continue;
f[i]=query(1,1,n,1,i);
//printf("j=%d,i=%d,f=%d
",j,i,f[i]);
}
}
printf("%d
",f[n]);
return 0;
}