I.V.[FJOI2018]领导集团问题
这题的难点主要是在状态的设计上。
首先,一个naive的想法是设 \(f_i\) 表示节点 \(i\) 子树中,强制节点 \(i\) 选择的最优答案,然后使用线段树合并转移。
但是这样在合并不同子树时会出大问题。于是我们不得不更换状态。
于是我们设 \(f_{i,j}\) 表示 \(i\) 子树中只选择不小于 \(j\) 的点的最优答案。
则显然,\(i\) 不选时的 \(f\) 可以直接通过所有儿子的 \(f\) 求和得到。
而 \(i\) 选择的时候,就是 \(f_{i,a_i}\) 增加 \(1\)。
于是我们的操作目前便是两种:线段树合并,以及单点加一。
但是,我们发现,在单点加一后,按照DP数组的意义,所有 \(j<a_i\) 的 \(f_{i,j}\) 都要与 \(f_{i,a_i}\) 取 \(\max\)。
于是我们操作一共三种:线段树合并区间和,单点加,区间取 \(\max\)。
但是线段树合并要适配区间操作就会很麻烦。有没有更好的方法?
我们发现,\(f_i\) 数组是不增的。故我们差分后,区间取 \(\max\) 操作就变成了一头一尾的单点修改。
尾部的单点修改位置很明确,就是 \(a_i\);但是头部的单点修改,就需要在线段树上二分出第一个 \(>f_{i,a_i}\) 的位置,然后再修改。
我们总结一下,需要支持:线段树合并,单点修改,树上二分,后缀求和。是常规操作,不说了。
时间复杂度 \(O(n\log n)\)。
代码:
#include<bits/stdc++.h>
using namespace std;
int n,a[200100],m,cnt,f[200100],rt[200100];
vector<int>v[200100],u;
#define mid ((l+r)>>1)
struct node{int lson,rson,sum;}seg[6401000];
void pushup(int x){seg[x].sum=seg[seg[x].lson].sum+seg[seg[x].rson].sum;}
void modify(int &x,int l,int r,int P,int val){
if(l>P||r<P)return;if(!x)x=++cnt;
if(l==r){seg[x].sum=val;return;}
modify(seg[x].lson,l,mid,P,val),modify(seg[x].rson,mid+1,r,P,val),pushup(x);
}
void merge(int &x,int y,int l,int r){
if(!x){x=y;return;}if(!y)return;
seg[x].sum+=seg[y].sum;
if(l!=r)merge(seg[x].lson,seg[y].lson,l,mid),merge(seg[x].rson,seg[y].rson,mid+1,r);
}
int query(int x,int l,int r,int P){
if(r<P)return 0;
if(l>=P)return seg[x].sum;
return query(seg[x].lson,l,mid,P)+query(seg[x].rson,mid+1,r,P);
}
int pos(int x,int l,int r,int k){
if(l==r)return l;
if(k>=seg[seg[x].rson].sum)return pos(seg[x].lson,l,mid,k-seg[seg[x].rson].sum);
else return pos(seg[x].rson,mid+1,r,k);
}
void erase(int &x,int l,int r,int L,int R){
if(l>R||r<L)return;
if(L<=l&&r<=R){x=0;return;}
erase(seg[x].lson,l,mid,L,R),erase(seg[x].rson,mid+1,r,L,R),pushup(x);
}
void iterate(int x,int l,int r){
if(!x)return;
printf("%d:[%d,%d]:%d\n",x,l,r,seg[x].sum);
if(l!=r)iterate(seg[x].lson,l,mid),iterate(seg[x].rson,mid+1,r);
}
void dfs(int x){
for(auto y:v[x])dfs(y),merge(rt[x],rt[y],1,m);
f[x]=query(rt[x],1,m,a[x])+1;
if(seg[rt[x]].sum>f[x]){
int L=pos(rt[x],1,m,f[x]);
int vl=query(rt[x],1,m,L),vr=query(rt[x],1,m,a[x]+1);
modify(rt[x],1,m,L,vl-f[x]),erase(rt[x],1,m,L+1,a[x]-1),modify(rt[x],1,m,a[x],f[x]-vr);
}else{
int vr=query(rt[x],1,m,a[x]+1);
erase(rt[x],1,m,1,a[x]-1),modify(rt[x],1,m,a[x],f[x]-vr);
}
// printf("%d:%d\n",x,f[x]),iterate(rt[x],1,m);
}
int main(){
scanf("%d",&n);
for(int i=1;i<=n;i++)scanf("%d",&a[i]),u.push_back(a[i]);
sort(u.begin(),u.end()),u.resize(m=unique(u.begin(),u.end())-u.begin());
for(int i=2,x;i<=n;i++)scanf("%d",&x),v[x].push_back(i);
for(int i=1;i<=n;i++)a[i]=lower_bound(u.begin(),u.end(),a[i])-u.begin()+1;
dfs(1),printf("%d\n",seg[rt[1]].sum);
return 0;
}