题解
令f[x][w]表示从x出发,当前已经死了w次,要再死一次的最少步数。
然而开不下。。。
或者令g[x][w]表示从x出发,走w步会在死一次时,最多会死多少次。
然而好像也开不下。。。
不过我们发现,答案是一个阶梯型分段函数。
因为其中隐藏着单调性,死的越多,那么下一次死的时候走的路程肯定是不降的。
差不多长这样。。。
所以我们可以维护这个分段函数,然后对于询问,我们把询问挂在s上,此时t在s的子树里,我们就可以处理询问了。
但是怎么维护呢?
前置知识:长链剖分
下面说的其实和这题没半毛钱关系。
长链剖分相比于重链剖分,每个点的重儿子指的不是size最大的儿子,而是maxdeep最大的子树。
长链剖分可以做这样一个性质:O(1)查找某个点的k极祖先。
怎么找呢?我们先长链剖分,然后对于每一条长链开一个vector,先令这条链长度为len,里面存的是这条链上面的len极祖先,这个是O(n)的。
然后再把这棵树倍增一下。
对于一个询问,我们先跳k的第一个二进制位,这个是O(1)的。然后我们跳的距离一定>k/2,然后有一个性质,就是跳上去的那个点所在的脸长一定>k/2。
这个比较显然,然后k极祖先所在点就在那条链里了。
那么长链剖分在这道题里有什么用呢。
观察到这个分段函数我们可以用队列存下来它的每一个拐点,然后每次从子树向父亲转移时用队列合并一下。
因为我们的队列是以deep为下标的,所以队列长度是不超过maxdeep的,所以我们可以用长链剖分来均摊复杂度。
这里还是用了DSU on tree的思想,先把长链加入队列,再把其他子树加进来,这样貌似是nlogn?
对于队列,因为是一颗树,所以我们可以采用dfs序分配内存,这样我们只需要开一个O(n)的队列了。
算贡献极麻烦。。。、
代码
#include<iostream> #include<cstdio> #include<queue> #include<cstring> #include<vector> #define N 300009 using namespace std; typedef long long ll; int n,fa[N],tot,head[N],p[N][20],ma[N][20],len[N],son[N],h[N],t[N],dfnn,top,dfn[N]; ll w[N],sum[N],deep[N],ans[N]; inline ll rd(){ ll x=0;char c=getchar();bool f=0; while(!isdigit(c)){if(c=='-')f=1;c=getchar();} while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();} return f?-x:x; } struct edge{int n,to;}e[N]; struct node{ll dep,val;}q[N],st[N]; vector<node>vec[N]; inline void add(int u,int v){e[++tot].n=head[u];e[tot].to=v;head[u]=tot;} void dfs(int u){ for(int i=1;(1<<i)<=deep[u];++i)p[u][i]=p[p[u][i-1]][i-1],ma[u][i]=max(ma[u][i-1],ma[p[u][i-1]][i-1]); for(int i=head[u];i;i=e[i].n){ int v=e[i].to;p[v][0]=u;ma[v][0]=w[u]; deep[v]=deep[u]+1; dfs(v); if(len[v]>len[son[u]])son[u]=v; } len[u]=len[son[u]]+1; } int getlca(int s,int t){ ll ans=0; for(int i=19;i>=0;--i)if(deep[s]-(1<<i)>deep[t]){ ans=max(ans,(ll)ma[s][i]);s=p[s][i]; } return ans; } inline void push(int u,node y){ while(h[u]<=t[u]&&y.val>=q[h[u]].val)h[u]++; if(h[u]<=t[u]&&y.dep==q[h[u]].dep)return; q[--h[u]]=y; if(h[u]!=t[u])sum[h[u]]=q[h[u]+1].dep*(q[h[u]+1].val-q[h[u]].val)+sum[h[u]+1]; else sum[h[u]]=0; } inline void merge(int x,int y){ top=0; while(h[x]<=t[x]&&q[h[x]].dep<=q[t[y]].dep)st[++top]=q[h[x]++]; while(top&&h[y]<=t[y]) if(st[top].dep>q[t[y]].dep)push(x,st[top--]); else push(x,q[t[y]--]); while(top)push(x,st[top--]); while(h[y]<=t[y])push(x,q[t[y]--]); } void solve(int u,int tt,int id){ ll val=getlca(tt,u); int l=h[u],r=t[u],as=l; while(l<=r){ int mid=(l+r)>>1; if(val<=q[mid].val){ as=mid;r=mid-1; }else l=mid+1; } int tag=val>=q[h[u]].val; if(tag)ans[id]=sum[h[u]]-sum[as]+q[h[u]].dep*q[h[u]].val+(val-q[as].val)*q[as].dep; else ans[id]=val*q[h[u]].dep; ans[id]+=deep[tt]-deep[u]-deep[u]*val; } void ddfs(int u){ dfn[u]=++dfnn; if(son[u])ddfs(son[u]),h[u]=h[son[u]],t[u]=t[son[u]]; else h[u]=dfnn+1,t[u]=dfnn; for(int i=head[u];i;i=e[i].n){ int v=e[i].to;if(v==son[u])continue; ddfs(v); merge(u,v); } for(int i=0;i<vec[u].size();++i){ int v=vec[u][i].dep,id=vec[u][i].val; solve(u,v,id); } push(u,node{deep[u],w[u]}); } int main(){ n=rd(); for(int i=1;i<=n;++i)w[i]=rd(); for(int i=2;i<=n;++i){ fa[i]=rd();add(fa[i],i); } dfs(1); int qu=rd();int ss,tt; for(int i=1;i<=qu;++i){ ss=rd();tt=rd(); vec[ss].push_back(node{tt,i}); } ddfs(1); for(int i=1;i<=qu;++i)printf("%lld ",ans[i]); return 0; }