二叉搜索树 题解
题意
(~~~~) 给出一个排列,并允许将 ([l,r]) 自行重排,求在这之后将所有数依次加入一棵 BST 的最小总深度。
(~~~~) (1leq nleq 10^5,1leq r-l+1leq 200)
题解
(~~~~) 首先大喊三声:“Reanap yyds!”
(~~~~) 首先有一种朴素的方法,即枚举 ([l,r]) 内的排列,然后依次插入二叉搜索树。
(~~~~) 考虑一种插入二叉搜索树的办法,首先用 set
记录一下已经插入的所有数,若新加入的某个数 (w) 在 (x) 和 (y) 之间,则这个数要么是 (x) 的右儿子,要么是 (y) 的左儿子,同时其左儿子的值 (in) ([x+1,w-1]) ,右儿子的值 (in) ([w+1,y]) 。为了方便维护,我们默认为 (x) 的儿子,则每次加入 (w) 后 (dep_w leftarrow dep_x+1) ,同时 (dep_x leftarrow dep_x+1) 以保证之后加入的正确性,这样可以对任何已知序列做到 (mathcal{O(nlog n)}) 求出答案。
(~~~~) 现在考虑重排的问题,在重排之前 ([1,l-1]) 已经插入,同时分割出了若干区间。同时每个区间都是与其他区间相互独立的。所以我们现在考虑怎么在每个区间插入值使得当前及之后插入的代价最小。
(~~~~) 这里我们使用代价提前计算的技巧,即对于一次在 ([x,y]) 之间插入 (w) ,则将 ([x,w-1]) 和 ([w+1,y]) 的深度提前 (+1) ,这样的话我们可以用一个区间DP来计算每个区间。
(~~~~) 定义 (dp_{l,r}) ,表示在某个小区间的 ([l,r]) 插入后的最小代价。则枚举转移点可以得到 (dp_{l,r}=max_{i=l}^r dp_{l,i-1}+dp_{i+1,r}+r-l) ,同时记录转移点以还原重排后的序列再插入即可。
代码
查看代码
#include <set>
#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
#define ll long long
using namespace std;
int n;
set<int>S;
set<int>::iterator it;
int arr[100005];
int dep[100005];
vector <int> V[100005];
ll Solve()
{
S.clear();S.insert(0);
ll ret=0;
for(int i=1;i<=n;i++) dep[i]=0;dep[0]=0;
for(int i=1;i<=n;i++)
{
int x=arr[i];
it=prev(S.lower_bound(x));
dep[x]=dep[*it]+1;dep[*it]++;ret+=dep[x];S.insert(x);
}
return ret;
}
int tot=0,cnt=0;
int dp[205][205],P[205],from[205][205],ord[205],nxt[100005];
void Rev(int l,int r)
{
if(l>r) return;
if(l==r)
{
ord[++tot]=P[l];
return;
}
ord[++tot]=P[from[l][r]];
Rev(l,from[l][r]-1); Rev(from[l][r]+1,r);
}
void DP(int R)
{
memset(dp,0,sizeof(dp));
for(int len=1;len<=R;len++)
{
for(int l=1;l+len-1<=R;l++)
{
int r=l+len-1;dp[l][r]=1e9;
for(int x=l;x<=r;x++)
{
if(dp[l][x-1]+dp[x+1][r]+(P[r+1]-1)-P[l-1]<dp[l][r])
{
dp[l][r]=dp[l][x-1]+dp[x+1][r]+(P[r+1]-1)-P[l-1];
from[l][r]=x;
}
}
}
}
Rev(1,R);
}
bool Beg[100005];
int l,r;
void Pre()
{
Beg[0]=true;S.insert(0);
for(int i=1;i<l;i++)
{
int x=arr[i];
it=prev(S.lower_bound(x));
Beg[x]=true;S.insert(x);
}
for(int i=l;i<=r;i++)
{
int x=arr[i];
it=prev(S.lower_bound(x));
V[*it].push_back(x);it++;
if(it==S.end()) nxt[x]=n+1;
else nxt[x]=*it;
}
for(int i=0;i<=n;i++)
{
if(Beg[i]&&!V[i].empty())
{
sort(V[i].begin(),V[i].end());
P[0]=i;cnt=0;for(int j=0;j<(int)V[i].size();j++) P[++cnt]=V[i][j];
P[cnt+1]=nxt[V[i][0]];
DP(cnt);
}
}
}
int main() {
scanf("%d",&n);
for(int i=1;i<=n;i++) scanf("%d",&arr[i]);
scanf("%d %d",&l,&r);
Pre();
for(int i=l;i<=r;i++) arr[i]=ord[i-l+1];
printf("%lld",Solve());
return 0;
}