题意:给定一个n个数的数列,要求把原数列恰好分成k个区间,每个区间的价值等于区间内不同的数字个数,问最优划分完之后总的价值最大是多少。
N<=35000,K<=min(n,50)
首先,想到是个DP应该来说是自然的,我们令dp[i][j]表示前i个数恰好划分成j个区间的最优价值。
那么显然有dp[i][j]=max(dp[l][j-1]+num(l+1,i)),其中0<=l<i,num(L,R)表示区间[L,R]范围内不同的数字个数。
最后我们需要的就是dp[n][k]。
如果我们暴力枚举转移,然后区间内不同的数字个数用主席树来做,每次查询是O(logN)的
然后这个复杂度是O(N^2*K*logN)的。显然会TLE。
我们考虑如何优化这个转移的过程,我一开始往决策单调性那边有去考虑过,不过本身那一块不是很熟,也并没有什么想法,
在转移的式子中出现了一个比较复杂的东西,也并不是很好算出来,
然后这里关键在于询问了一个区间范围内的max,这就启发我们对每个位置的这个整体的值去进行维护,我们每个位置维护一个dp[i][j-1]+num(i,cur),cur表示当前正在考虑的位置,
当cur—>cur+1时,我们可以考虑一下哪些位置的num(i,cur)会发生变化,我们很自然地注意到,如果我们记每个数上一次出现的位置是last[i]的话,那么i属于区间[last[a[cur]]+1,cur]的num(i,cur)都会发生变化,也就是说我们的[last[a[cur]],cur-1]处的值都会增加1,也就是一个区间加的事情,我们很自然地能同样通过线段树来完成。这样整体的复杂度就变成了O(N*K*logN),就可以过了。
这道题打破了我个人的一个误区是,之前我做的数据结构加速DP,往往都是比较简单的在一个范围内直接找dp值的最大值,于是在想这道题的时候我也沿着数据结构存dp值,然后就卡住了。
应该要根据找最大值的那个属性,然后去用数据结构维护那些值。
然后还有一个经验是可以去考虑根据题意维护一些到当前位置的值,然后在转移的过程中考虑当前位置的下标移动后在原先的值的基础上会发生哪些修改。
#include<bits/stdc++.h> #define lson(p) p<<1 #define rson(p) p<<1|1 using namespace std; const int maxn=4e4; int dp[maxn+5][55]; int a[maxn+5]; int pre[maxn+5],p[maxn+5]; int n,k; struct node { int mx; int lazy; }b[4*maxn+5]; void pushup(int p) { b[p].mx=max(b[lson(p)].mx,b[rson(p)].mx); } void build(int p,int l,int r,int k) { b[p].mx=b[p].lazy=0; if (l==r) { b[p].mx=dp[l][k]; return ; } int mid=(l+r)>>1; build(lson(p),l,mid,k); build(rson(p),mid+1,r,k); pushup(p); } void pushdown(int p) { if (b[p].lazy) { b[lson(p)].mx+=b[p].lazy; b[rson(p)].mx+=b[p].lazy; b[lson(p)].lazy+=b[p].lazy; b[rson(p)].lazy+=b[p].lazy; b[p].lazy=0; } } void modify(int p,int l,int r,int L,int R,int v) { if (L<=l&&r<=R) { b[p].mx+=v; b[p].lazy+=v; return ; } int mid=(l+r)>>1; pushdown(p); if (L<=mid) modify(lson(p),l,mid,L,R,v); if (R>mid) modify(rson(p),mid+1,r,L,R,v); pushup(p); } int getMx(int p,int l,int r,int L,int R) { if (L<=l&&r<=R) return b[p].mx; int mid=(l+r)>>1; pushdown(p); int ans=0; if (L<=mid) ans=max(ans,getMx(lson(p),l,mid,L,R)); if (R>mid) ans=max(ans,getMx(rson(p),mid+1,r,L,R)); return ans; } int main() { scanf("%d%d",&n,&k); for (int i=1;i<=n;i++) scanf("%d",&a[i]); for (int i=1;i<=n;i++) { p[i]=pre[a[i]]; pre[a[i]]=i; } //for (int i=1;i<=n;i++) printf("%d %d ",i,p[i]); for (int j=1;j<=k;j++) { build(1,0,n,j-1); for (int i=1;i<=n;i++) { modify(1,0,n,p[i],i-1,1); dp[i][j]=getMx(1,0,n,0,i-1); } } printf("%d ",dp[n][k]); return 0; }