题意
给一颗(n)个结点的树,树的根为(1),你最多选择(K)个关键点(根节点必须选择),一个点(x)的最远距离定义为(x)到根节点的路径上遇到的第一个关键点的距离,树的权值定义为所有点的最远距离的最大值。
问当(Kin { 1,2,dots,n })时,树的权值的最小值的和为多少。
分析
当固定(K)时,我们可以二分答案(dist),贪心的去check,每次找深度最大的点(x)向上跳(dist)步到祖先(y),然后将(y)的子树删除,一直这样找下去,直到将整棵树删除,就能得到所需的最少的关键点的数量。
一个很显然的结论:当答案为(x)时,关键点的数量最多为(frac{n}{x})个,树为一条链的情况下关键点数量最多。
这样我们就能反过来枚举答案(x),每次check得到的关键点数量设为(k),更新(ans[k]=min(ans[k],x)),最后累加起来就得到了最终答案。
check找最深的点和删除子树操作可以用按(dfs)序建线段树来完成,找深度最大的点就是查询区间最大值,删除子树就是区间赋值,最后还要还原,我们可以提前将线段树复制一份,然后区间更新时用(vector)存下被修改的点,最后还原时将这些点赋值为初始线段树的值。
这样一次(check)的时间复杂度为关键点数量乘(logn)。
从(1sim n)枚举答案,关键点的总数量为(n+frac{n}{2}+ frac{n}{3}+dots+frac{n}{n}),约等于(nlogn),
总的时间复杂度即为(O(nlog^2n))。
Code
#include<algorithm>
#include<iostream>
#include<cstring>
#include<iomanip>
#include<sstream>
#include<cstdio>
#include<string>
#include<vector>
#include<bitset>
#include<queue>
#include<cmath>
#include<stack>
#include<set>
#include<map>
#define rep(i,x,n) for(int i=x;i<=n;i++)
#define per(i,n,x) for(int i=n;i>=x;i--)
#define sz(a) int(a.size())
#define rson mid+1,r,p<<1|1
#define pii pair<int,int>
#define lson l,mid,p<<1
#define ll long long
#define pb push_back
#define mp make_pair
#define se second
#define fi first
using namespace std;
const double eps=1e-8;
const int mod=1e9+7;
const int N=2e5+10;
const int inf=1e9;
int n;
vector<int>g[N];
int d[N],f[N][20],L[N],R[N],pos[N],tot;
int A[N<<2],B[N<<2],tag[N<<2];
int ans[N],lg[N];
vector<int>v;
int mx(int x,int y){
if(d[x]>d[y]) return x;
else return y;
}
void bd(int l,int r,int p){
tag[p]=0;
if(l==r) return A[p]=B[p]=pos[l],void();
int mid=l+r>>1;
bd(lson);bd(rson);
A[p]=B[p]=mx(A[p<<1],A[p<<1|1]);
}
void pd(int p){
v.pb(p<<1);v.pb(p<<1|1);
A[p<<1]=A[p<<1|1]=0;
tag[p<<1]=tag[p<<1|1]=1;
tag[p]=0;
}
void up(int dl,int dr,int l,int r,int p){
v.pb(p);
if(l==dl&&r==dr){
tag[p]=1;
A[p]=0;
return;
}
int mid=l+r>>1;
if(tag[p]) pd(p);
if(dr<=mid) up(dl,dr,lson);
else if(dl>mid) up(dl,dr,rson);
else up(dl,mid,lson),up(mid+1,dr,rson);
A[p]=mx(A[p<<1],A[p<<1|1]);
}
void dfs(int u,int fa){
d[u]=d[fa]+1;
f[u][0]=fa;
for(int i=1;(1<<i)<=n;i++){
f[u][i]=f[f[u][i-1]][i-1];
}
L[u]=++tot;pos[tot]=u;
for(int x:g[u]){
if(x==fa) continue;
dfs(x,u);
}
R[u]=tot;
}
int find(int x,int k){
for(int i=lg[k];i>=0;i--){
if(k>=(1<<i)){
k-=(1<<i);
x=f[x][i];
}
}
if(!x) x=1;
return x;
}
int ck(int mid){
v.clear();
int cnt=0;
while(1){
int x=A[1];
if(x==0) break;
int y=find(x,mid);
++cnt;
up(L[y],R[y],1,n,1);
}
for(int p:v){
A[p]=B[p];
tag[p]=0;
}
return cnt;
}
int main(){
//ios::sync_with_stdio(false);
//freopen("in","r",stdin);
for(int i=2;i<N;i++) lg[i]=lg[i-1]+(1<<(lg[i-1]+1)==i);
while(~scanf("%d",&n)){
tot=0;
rep(i,2,n){
int x;
scanf("%d",&x);
g[x].pb(i);
}
dfs(1,0);
bd(1,n,1);
rep(i,1,n) ans[i]=n+1;
ll Ans=0;
for(int i=n;i>=0;i--) ans[ck(i)]=i;
for(int i=2;i<=n;i++) ans[i]=min(ans[i-1],ans[i]);
for(int i=1;i<n;i++) Ans+=ans[i];
printf("%lld
",Ans);
rep(i,1,n) g[i].clear();
}
return 0;
}