做这道题真的是涨姿势了,一般的CDQ分治都是在序列上进行的,这次是把CDQ分治放树上跑了~
考虑一半的 CDQ 分治怎么进行:
递归处理左区间,处理左区间对右区间的影响,然后再递归处理右区间.
所以,如果是有坐标不递增的斜率优化的话就用 CDQ 分治先处理出左半部分答案,然后将处理好的左区间答案用来更新右区间.
那么,将序列问题拓展到树上后,我们也要选择一个合适的中点来保证分治层数不多,且区间大小均匀.
而树中这个"中点"就是一棵树的重心!!
即当我们处理以 $x$ 为根的子树时(分治区间),先找到重心,然后扣掉重心为根的子树(右区间),然后递归处理 $x$ 为根子树抛去重心为根子树的答案 (递归处理左区间).
递归处理完“左区间”后,计算左对右的影响,那么对重心为根子树的影响就是 $x$ 到重心这条链上所有点.
然后这一部分就不难了,分别按照影响的坐标范围排一下序,然后双指针扫一扫就行了.
处理完对于右区间的贡献后,我们再递归处理右区间:再递归处理重心的每一个儿子即可.
据说这个时间复杂度是 $O(nlog^2n)$ 的
code:
#include <bits/stdc++.h> #define N 2000006 #define ll long long #define setIO(s) freopen(s".in","r",stdin) using namespace std; ll f[N],dep[N],p[N],q[N],l[N]; int edges,root,sn,la,lb,tot,sta[N]; int hd[N],to[N],nex[N],size[N],mx[N],vis[N],A[N],Fa[N],B[N],S[N]; void add(int u,int v) { nex[++edges]=hd[u],hd[u]=edges,to[edges]=v; } bool cmp(int a,int b) { return dep[a]-l[a]>dep[b]-l[b]; } void getroot(int u) { size[u]=1,mx[u]=0; for(int i=hd[u];i;i=nex[i]) { int v=to[i]; if(vis[v]) continue; getroot(v); size[u]+=size[v]; mx[u]=max(mx[u],size[v]); } mx[u]=max(mx[u],sn-size[u]); if(mx[u]<mx[root]) root=u; } void dfs(int u) { B[++lb]=u; for(int i=hd[u];i;i=nex[i]) if(!vis[to[i]]) dfs(to[i]); } double slope(int a,int b) { return (double) (f[a]-f[b])/(dep[a]-dep[b]); } void update(int x) { if(!tot) return; int l=1,r=tot,mid,ret=tot; while(l<=r) { mid=(l+r)>>1; if(slope(sta[mid],sta[mid+1])<p[x]) ret=mid,r=mid-1; else l=mid+1; } f[x]=min(f[x], f[sta[ret]]-dep[sta[ret]]*p[x]+q[x]); } void solve(int u) { int i,j,rt; root=0,sn=size[u],getroot(u),vis[rt=root]=1; if(root!=u) size[u]-=size[rt], solve(u); la=lb=tot=0; A[++la]=rt; for(i=rt;i!=u;i=Fa[i]) { if(dep[rt]-l[rt]<=dep[Fa[i]]) f[rt]=min(f[rt],f[Fa[i]]-dep[Fa[i]]*p[rt]+q[rt]); A[++la]=Fa[i]; } for(int i=hd[rt];i;i=nex[i]) if(!vis[to[i]]) dfs(to[i]); sort(B+1,B+lb+1,cmp); for(i=j=1;i<=la;++i) { while(j<=lb&&dep[A[i]]<dep[B[j]]-l[B[j]]) update(B[j++]); while(tot>1&&slope(sta[tot-1],sta[tot])<=slope(sta[tot],A[i])) --tot; sta[++tot]=A[i]; } while(j<=lb) update(B[j++]); for(i=hd[rt];i;i=nex[i]) if(!vis[to[i]]) solve(to[i]); } int main() { // setIO("input"); int i,j,n,ty; scanf("%d%d",&n,&ty); for(i=2;i<=n;++i) { scanf("%d%lld%d%lld%lld",&Fa[i],&dep[i],&p[i],&q[i],&l[i]); add(Fa[i],i); dep[i]=dep[Fa[i]]+dep[i]; q[i]+=dep[i]*p[i]; } memset(f,0x3f,sizeof(f)), f[1]=0; mx[0]=sn=n, size[1]=n, solve(1); for(i=2;i<=n;++i) printf("%lld ",f[i]); return 0; }