每个人可以分为往上和往下两部分,我们将向上的路径和向下的命名为第一种路径和第二种路径。
询问就是求经过某个点的,出发点深度为d-w的第一种路径数量,加上出发点深度为d+w的第二种路径的数量。
我们开个数组记录每个深度的答案。对于每条路径,我们都在下面那个点加,在上面那个点减(注意加和减都是出发点的深度上),询问就变成了求子树和。
我们可以通过dfs来求答案,dfs到一个点的时候,记下之前在所求深度上的答案,加完这棵子树后,新答案与原来的差就是这个点真正的答案。
(关于第二种路径我是先延长上去再减掉上面一段,为此还从根节点往上加了n个点)
以上是我去年(看了题解)写的做法,但是。。。
反正时间复杂度都是O(nlogn),是不是以d-t或d+t为下标,然后就变成链加单点求和,然后线段树合并或者dsu就行了?感觉这样无脑一点?
没写过不知道是不是对的。。。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define mxn 2000010
using namespace std;
int n,m,q,x,y,tot,head[mxn],p[mxn][21],dep[mxn],id[mxn],siz[mxn],a[mxn],ans[mxn];
int hd[mxn],nxt[mxn],cnt[mxn],sum[mxn],S[mxn],T[mxn],typ[mxn];
struct ed{int to,nxt;}edge[mxn<<1];
void addedge(int u,int v){
edge[++m]=(ed){v,head[u]};head[u]=m;
edge[++m]=(ed){u,head[v]};head[v]=m;
}
void dfs(int u,int fa,int d){
p[u][0]=fa;dep[u]=d;siz[u]=1;id[u]=++tot;int v;
for (int i=head[u];i;i=edge[i].nxt)
if ((v=edge[i].to)!=fa)
dfs(v,u,d+1),siz[u]+=siz[v];
}
int lca(int x,int y){
if (dep[x]<dep[y]) {int t=x;x=y;y=t;}
int k=log2(dep[x]);
for (int i=k;i>=0;--i)
if (dep[x]-(1<<i)>=dep[y]) x=p[x][i];
if (x==y) return x;
for (int i=k;i>=0;--i)
if (p[x][i]!=p[y][i]) x=p[x][i],y=p[y][i];
return p[x][0];
}
void dfs1(int u,int fa){//up
int num=cnt[dep[u]+a[u]],v;
cnt[dep[u]]+=sum[u];
for (int i=head[u];i;i=edge[i].nxt)
if ((v=edge[i].to)!=fa) dfs1(v,u);
ans[u]+=cnt[dep[u]+a[u]]-num;
for (int j=hd[u];j;j=nxt[j]) --cnt[dep[S[j]]];
}
void dfs2(int u,int fa){//down add
int num=cnt[dep[u]-a[u]],v;
for (int j=hd[u];j;j=nxt[j]) ++cnt[dep[S[j]]];
for (int i=head[u];i;i=edge[i].nxt)
if ((v=edge[i].to)!=fa) dfs2(v,u);
ans[u]+=cnt[dep[u]-a[u]]-num;
cnt[dep[u]]-=sum[u];
}
void dfs3(int u,int fa){//down minus
int num=cnt[dep[u]-a[u]],v;
for (int j=hd[u];j;j=nxt[j]) ++cnt[dep[S[j]]];
for (int i=head[u];i;i=edge[i].nxt)
if ((v=edge[i].to)!=fa) dfs3(v,u);
ans[u]-=cnt[dep[u]-a[u]]-num;
cnt[dep[u]]-=sum[u];
}
int main()
{
scanf("%d%d",&n,&q);
for (int i=1;i<n;++i)
scanf("%d%d",&x,&y),addedge(x,y);
addedge(1,n+1);
for (int i=1;i<n;++i)
addedge(n+i,n+i+1);
dfs(n*2,n*2,1);x=log2(n)+1;m=0;
for (int j=1;j<=x;++j)
for (int i=1;i<=n*2;++i)
p[i][j]=p[p[i][j-1]][j-1];
for (int i=1;i<=n;++i) scanf("%d",&a[i]);
for (int i=1;i<=q;++i){
scanf("%d%d",&x,&y);
int z=lca(x,y),d=dep[z]*2-dep[x],k=log2(dep[x]-dep[z]+1),w=z;
for (int j=k;j>=0;--j)
if (dep[w]-(1<<j)>=d) w=p[w][j];
S[++m]=x,T[m]=z,typ[m]=1;
S[++m]=w,T[m]=y,typ[m]=2;
S[++m]=w,T[m]=z,typ[m]=3;
}
for (int i=1;i<=m;++i)
if (typ[i]==1)
nxt[i]=hd[T[i]],hd[T[i]]=i,++sum[S[i]];
dfs1(n*2,n*2);
memset(cnt,0,sizeof(cnt));
memset(sum,0,sizeof(sum));
memset(hd,0,sizeof(hd));
for (int i=1;i<=m;++i)
if (typ[i]==2)
nxt[i]=hd[T[i]],hd[T[i]]=i,++sum[S[i]];
dfs2(n*2,n*2);
memset(cnt,0,sizeof(cnt));
memset(sum,0,sizeof(sum));
memset(hd,0,sizeof(hd));
for (int i=1;i<=m;++i)
if (typ[i]==3)
nxt[i]=hd[T[i]],hd[T[i]]=i,++sum[S[i]];
dfs3(n*2,n*2);
for (int i=1;i<=n;++i)
printf("%d ",ans[i]);
puts("");
return 0;
}