看了题解才会……
很好的想法是把整个过程看成若干 “取一点 i ,值+=w[ i ],值-=(sum w[j])”(其中 j 是 i 的孩子)的操作组成的序列。
序列有一个限制是 “孩子的操作在父亲前面” 。把序列反一下,操作变成 “取一点 i , 值-=w[ i ],值+=(sum w[j])” ,每个点就只被其父亲的位置限制了,比较好做。
用 ( x, y ) 表示一个点的操作。 x 表示操作结束的增量, y 表示过程中最大值与初始值的差。之所以是“与初始值的差”,是为了今后合并两个操作;因为初始值不确定。
每个点的初值就是 ( (-w[i]+sumw[j],sumw[j]) )。合并 ( a, b ) 、( c, d ) 之后会变成 ( a+c , max( b, a+d ) ) 。
可以贪心地确定每个点的操作处在序列的什么位置。方法就是尝试交换相邻位置。
发现:1.同时 x<0 ,先做 y 大的;
2.同时 x>=0 ,先做 y-x 小的;
3.x<0 先于 x>=0
用堆维护,每次找最优先的。如果要做这个点的时候,其父亲还没做,就把它和父亲合并在一起放入堆,表示做完父亲就做它。可以用并查集+链表维护一个整体内部的操作顺序。
注意 x , y 相同的要按 id 区分开;都是 x<0 而 y , id 相同的要按 x 的具体值区分开。否则删除堆无法正常工作。
#include<cstdio> #include<cstring> #include<algorithm> #include<queue> #define ll long long #define ls Ls[cr] #define rs Rs[cr] using namespace std; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return fx?ret:-ret; } ll Mx(ll a,ll b){return a>b?a:b;} ll Mn(ll a,ll b){return a<b?a:b;} const int N=2e5+5; int n,yf[N],w[N],fa[N],p[N],dy[N],tot; bool vis[N]; int hd[N],mn[N],xnt,to[N<<1],nxt[N<<1],tp[N],tt; ll ans[N]; int pr[N],nt[N],st[N],en[N]; struct Node{ ll x,y; int id; Node(ll x=0,ll y=0,int i=0):x(x),y(y),id(i) {} Node operator+ (const Node &b)const {return Node(x+b.x,Mx(y,x+b.y),id);} bool operator< (const Node &b)const { if(x<0&&b.x>=0)return false; if(x>=0&&b.x<0)return true; if(x<0&&b.x<0) { if(y!=b.y)return y>b.y; if(id!=b.id)return id<b.id; return x<b.x; } if(y-x!=b.y-b.x)return y-x<b.y-b.x; if(id!=b.id)return id<b.id; return x<b.x; } bool operator== (const Node &b)const { return x==b.x&&y==b.y&&id==b.id;} }a[N],ya[N]; priority_queue<Node> q,dq; namespace T{ const int M=N*20; int tot,rt[N],Ls[M],Rs[M]; Node vl[M]; int nwnd(int pr=0) { int cr=++tot; ls=Ls[pr]; rs=Rs[pr]; vl[cr]=vl[pr]; return cr; } void build(int l,int r,int &cr,int ps) { cr=nwnd(); if(l==r){vl[cr]=ya[p[l]];return;} int mid=l+r>>1; if(ps<=mid)build(l,mid,ls,ps); else build(mid+1,r,rs,ps); vl[cr]=vl[ls]+vl[rs]; } void mrg(int l,int r,int &cr,int pr) { if(!pr)return; if(!cr){cr=pr;return;}//use is ok if(l==r){vl[cr]=vl[cr]+vl[pr];return;} int mid=l+r>>1; mrg(l,mid,ls,Ls[pr]); mrg(mid+1,r,rs,Rs[pr]); vl[cr]=vl[ls]+vl[rs]; } void dfs(int cr) { build(1,n,rt[cr],dy[cr]); for(int i=hd[cr],v;i;i=nxt[i]) { dfs(v=to[i]); mrg(1,n,rt[cr],rt[v]); } ans[cr]=w[cr]+vl[rt[cr]].y; } } void add(int x,int y) {to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;} void frs() { while(dq.size()&&dq.top()==q.top()) q.pop(), dq.pop(); } int fnd(int a){ return fa[a]==a?a:fa[a]=fnd(fa[a]);} void mrg(int x,int y) { if(fa[y]==x)exit(0); pr[st[y]]=en[x]; nt[en[x]]=st[y]; en[x]=en[y]; dq.push(a[x]); a[x]=a[x]+a[y]; q.push(a[x]); fa[y]=x; mn[x]=Mn(mn[x],mn[y]); } void solve(int x) { int cr=st[x]; while(cr) { p[++tot]=cr;dy[cr]=tot;vis[cr]=1;cr=nt[cr]; } } int main() { int op=rdn(); n=rdn(); for(int i=2;i<=n;i++)yf[i]=rdn(),add(yf[i],i); for(int i=1;i<=n;i++)w[i]=rdn(); for(int i=n;i;i--) { a[i].x=a[i].y-w[i]; a[yf[i]].y+=w[i]; a[i].id=i; fa[i]=i; mn[i]=i; q.push(a[i]); ya[i]=a[i]; st[i]=en[i]=i; } vis[0]=1; while(q.size()) { frs(); if(!q.size())break; Node k=q.top(); q.pop(); int x=k.id; if(vis[yf[mn[x]]]){solve(x);continue;} mrg(fnd(yf[mn[x]]),x); } /*for(int i=1;i<=tot;i++)printf("%d ",p[i]);puts(""); for(int i=1;i<=tot;i++)printf("%d ",dy[i]);puts("");*/ T::dfs(1); for(int i=1;i<=n;i++)printf("%lld ",ans[i]); puts(""); return 0; }