题目大意
将一棵树分为三部分,使最大部分与最小部分的差最小。
正解
O(n^2)做法,枚举两条边,并用dfs序判断是否有祖先关系。
O(nlogn)做法,考虑用权值线段树来维护,记住是绝对值,要用前驱后继查询。
仍分为两种情况:
1.有祖先关系。统计答案时取最接近(n+size)/2的。dfs时将size丢进权值线段树线段树,遍历后将其从树中删除。
2.无祖先关系。统计答案时取最接近(n-size)/2的。遍历后将size丢进权值线段树线段树,无需删除。
代码
#include<iostream> #include<cstdio> #include<algorithm> #include<cmath> using namespace std; int n,m=200000,sx,sy,bb[500005],size[200005],h[200005],fa[200005],dep[200005],cnt=0,tot=0,ans=2000000007; struct node { int next,to,from; }e[400005]; struct segment_tree { int val; }a[10000005]; void read(int &x){ char c=getchar(); for(;c<33;c=getchar()); for(x=0;47<c&&c<58;x=(x<<3)+(x<<1)+c-48,c=getchar()); } inline void add(int x,int y) { e[++cnt].to=y; e[cnt].from=x; e[cnt].next=h[x]; h[x]=cnt; } inline void dfs(int x) { size[x]=1; for(int i=h[x];i;i=e[i].next) { int y=e[i].to; if(y!=fa[x]) { fa[y]=x; dfs(y); size[x]+=size[y]; } } } inline void add(int l,int r,int k,int x,int val) { if(l==r) { a[k].val+=val; } else { int mid=(l+r)>>1; if(x<=mid)add(l,mid,k<<1,x,val); else add(mid+1,r,(k<<1)+1,x,val); a[k].val=a[k<<1].val+a[(k<<1)+1].val; } } inline int ub(int l,int r,int k,int x) { if(l==r) { return l; } else { int mid=(l+r)>>1,mx=0; if((r<=x||x>mid)&&a[(k<<1)+1].val) { mx=ub(mid+1,r,(k<<1)+1,x); } if(a[k<<1].val&&!mx)mx=ub(l,mid,k<<1,x); return mx; } return -2000000007; } inline int lb(int l,int r,int k,int x) { if(l==r) { return l; } else { int mid=(l+r)>>1,mx=0; if((l>=x||x<=mid)&&a[k<<1].val) { mx=lb(l,mid,k<<1,x); } if(a[(k<<1)+1].val&&!mx)mx=lb(mid+1,r,(k<<1)+1,x); return mx; } return 2000000007; } inline void dfs1(int x) { sx=ub(1,m,1,(n+size[x])>>1); sy=lb(1,m,1,(n+size[x])>>1); int c1=size[x],c2=sx-size[x],c3=n-sx; ans=min(ans,max(abs(c1-c2),max(abs(c1-c3),abs(c2-c3)))); c1=size[x],c2=sy-size[x],c3=n-sy; ans=min(ans,max(abs(c1-c2),max(abs(c1-c3),abs(c2-c3)))); add(1,m,1,size[x],1); for(int i=h[x];i;i=e[i].next) { int y=e[i].to; if(y!=fa[x]) { dfs1(y); } } add(1,m,1,size[x],-1); } inline void dfs2(int x) { sx=ub(1,m,1,(n-size[x])>>1); sy=lb(1,m,1,(n-size[x])>>1); int c3=n-size[x]-sx; ans=min(ans,max(abs(size[x]-sx),max(abs(size[x]-c3),abs(sx-c3)))); c3=n-size[x]-sy; ans=min(ans,max(abs(size[x]-sy),max(abs(size[x]-c3),abs(sy-c3)))); for(int i=h[x];i;i=e[i].next) { if(e[i].to!=fa[x]) { dfs2(e[i].to); } } add(1,m,1,size[x],1); } int main() { freopen("chilli.in","r",stdin); freopen("chilli.out","w",stdout); read(n); for(int i=1,x,y;i<n;i++) { read(x);read(y); add(x,y);add(y,x); } dfs(1); dfs1(1); dfs2(1); printf("%d ",ans); }