题目链接:戳我
题意:将一个长度为n的序列分为k段,使得总价值最大。一段区间的价值表示为区间内不同数字的个数
(n<=35000,k<=50)
开始想的转移方程是这个样子的——(dp[i][j])表示前i个,分成j组,最大收益
然后转移方程为(dp[i][j]=max(dp[i-1][j]+cur,dp[i-1][j-1]+1)),其中cur表示这个数是否在当前组中出现过,判断可以用set来搞。
但是——不对!!!!
原因是同一种最大收益可能有不同的分组方式,而不同的分组方式显然具有后效性,不能DPqwqwq
所以我们更改DP方程——
(dp[i][j]=max(dp[k][j-1]+calc(k+1,j)))
(dp[i][j])表示前i个数分成j份。
但是这个样子的话复杂度是(O(n^2k))的,显然。。。。很凉凉。
于是我们考虑一个神奇的做法——
从j开始遍历(即把j放成外层循环),那么。。。我们遍历i的时候就是从头开始依次加入数,并计算了。而每次加入一个数对于calc的计算来说,就是对它上次出现的位置到现在这个位置都+1.但是注意如果本身就是第一个出现的,那么pre也要改一改,改成自己的。(不过我的代码直接整体前移了一位。)
代码如下:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define MAXN 350010
using namespace std;
int n,k;
int a[MAXN],dp[MAXN][55],pre[MAXN],f[MAXN],pos[MAXN];
struct Node{int x,l,r,sum,tag;}t[MAXN<<2];
inline int ls(int x){return x<<1;}
inline int rs(int x){return x<<1|1;}
inline void push_up(int x){t[x].sum=max(t[ls(x)].sum,t[rs(x)].sum);}
inline void solve(int x,int k)
{
t[x].sum+=k;
t[x].tag+=k;
}
inline void build(int x,int l,int r)
{
t[x].l=l,t[x].r=r;t[x].tag=0;
if(l==r) {t[x].sum=f[l];return;}
int mid=(l+r)>>1;
build(ls(x),l,mid);
build(rs(x),mid+1,r);
push_up(x);
}
inline void push_down(int x)
{
if(t[x].tag)
{
solve(ls(x),t[x].tag);
solve(rs(x),t[x].tag);
t[x].tag=0;
}
}
inline void update(int x,int ll,int rr)
{
int l=t[x].l,r=t[x].r;
if(ll<=l&&r<=rr) {solve(x,1);return;}
int mid=(l+r)>>1;
push_down(x);
if(ll<=mid) update(ls(x),ll,rr);
if(mid<rr) update(rs(x),ll,rr);
push_up(x);
}
inline int query(int x,int ll,int rr)
{
int l=t[x].l,r=t[x].r;
if(ll<=l&&r<=rr) return t[x].sum;
int mid=(l+r)>>1,cur_ans=0;
push_down(x);
if(ll<=mid) cur_ans=max(cur_ans,query(ls(x),ll,rr));
if(mid<rr) cur_ans=max(cur_ans,query(rs(x),ll,rr));
return cur_ans;
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("ce.in","r",stdin);
#endif
scanf("%d%d",&n,&k);
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
pre[i]=pos[a[i]];
pos[a[i]]=i;
}
//for(int i=1;i<=n;i++) printf("lst[%d]=%d
",i,pre[i]);
for(int j=1;j<=k;j++)
{
build(1,0,n-1);
for(int i=1;i<=n;i++)
update(1,pre[i],i-1),f[i]=query(1,0,i-1);
}
printf("%d
",f[n]);
return 0;
}