题目描述
题目
这题的差分和一般的树上差分写法差好远,参考了dalao的题解还磨了好久才写出来
主要要注意的有以下几点:
1.起点s和终点t千万不要弄错(被它卡了半天的我QAQ)
2.记深度为d的起点的总数为cnt[d]:对于一条向上走的路,在起点处cnt[d]++,搜完终点时cnt[d]--;对于向下走的路,在终点处cnt[d]++,在起点处cnt[d]--(向上和向下的路分别用两个桶计数)
给这道题的细节处理跪了ORZ,磨了三天才终于A了
代码
#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
#include<vector>
using namespace std;
// Capacity constants: nodes are indexed below N; M = 2N sizes the arrays that
// hold paired arcs / query entries and the per-segment data.
const int N=300010;
const int M=N*2;
int n;      // number of tree nodes
int m;      // number of player routes
int w[N];   // observation time of each node
int rt;     // root used by tarjan() / pushup()
int ne;     // next free index in e[]
int he[N];  // head of each node's edge list
int nq;     // next free index in q[]
int hq[N];  // head of each node's query list
// Tree edge stored in a linked adjacency list: `to` is the neighbour and
// `next` is the previously added edge of the same node (-1 ends the list).
struct E {int to, next;} e[M];
// Inserts the undirected edge u-v as two arcs at consecutive indices.
void build(int u, int v)
{
    e[ne].to=v;
    e[ne].next=he[u];
    he[u]=ne++;
    e[ne].to=u;
    e[ne].next=he[v];
    he[v]=ne++;
}
// Query-list entry for one endpoint of a route: `to` is the other endpoint,
// `next` the previous entry of the same node (-1 ends the list), and `flag`
// marks the pair as already handled.
struct Q {int to, next, flag;} q[M];
// Records the route u -> v as two entries at consecutive indices: one in
// hq[u] pointing to v, then one in hq[v] pointing to u (partner = index ^ 1).
void add(int u, int v)
{
    q[nq].to=v;
    q[nq].next=hq[u];
    q[nq].flag=0;
    hq[u]=nq++;
    q[nq].to=u;
    q[nq].next=hq[v];
    q[nq].flag=0;
    hq[v]=nq++;
}
// Per-node lists of segment ids.
vector<int> upS[N];    // upward segments that start at the node
vector<int> upT[N];    // upward segments that end at the node
vector<int> downS[N];  // downward segments that start at the node
vector<int> downT[N];  // downward segments that end at the node
int cntr;              // number of segment ids handed out so far
int rlen[M];           // segment length in edges
int prelen[M];         // edges walked before the segment starts (0 unless set)
int f[N];              // union-find parent used by tarjan()
int vis[N];            // 1 once tarjan() has entered the node
int dep[N];            // node depth assigned by tarjan()
int lca[M];            // LCA, stored on the upward half of a split route
int S[M];              // segment start node
int T[M];              // segment end node
// Union-find lookup with full path compression: returns the representative of
// v's set and points every node on the walked path directly at it.
int find(int v)
{
    int root=v;
    while(f[root] != root) root=f[root];
    while(f[v] != root)
    {
        const int up=f[v];
        f[v]=root;
        v=up;
    }
    return root;
}
// Offline (Tarjan) LCA fused with route splitting.
// Assigns dep[u] = dep[fa] + 1 and, for every route s -> t in the query lists,
// creates one or two vertical segments (ids from ++cntr):
//   S[id], T[id] : segment start / end node
//   rlen[id]     : segment length in edges
//   prelen[id]   : edges walked before the segment starts (only set for the
//                  downward half of a route whose LCA is neither s nor t)
//   lca[id]      : the LCA, stored on the upward half of such a split route
// Upward segments go into upS[start] / upT[end]; downward ones into
// downS[start] / downT[end].
// NOTE(review): recursion depth equals the tree height; a long chain may need a large stack.
void tarjan(int u,int fa)
{
dep[u]=dep[fa]+1; vis[u]=1; f[u]=u; int v;
for(int i=he[u]; i != -1; i=e[i].next)
{
if((v=e[i].to) == fa) continue;
// Merge the child into u only after its whole subtree has been processed.
tarjan(v,u); f[v]=u;
}
for(int i=hq[u]; i != -1; i=q[i].next)
{
// Each route is handled once, at the first endpoint whose scan finds the
// other endpoint already entered. If v is an ancestor of u it is still
// unmerged (f[v] == v), so find(v) correctly yields v.
if(!vis[v=q[i].to] || q[i].flag) continue;
q[i].flag=q[i^1].flag=1;cntr++;
// This local `m` is the LCA (it shadows the global route count m).
int m=find(v), s, t;
// add(s,t) puts the even-indexed entry in hq[s] and the odd one in hq[t],
// so the parity of i tells which endpoint is the start.
if(i&1) s=v,t=u;else s=u,t=v;
if(m == s)
{
// s is an ancestor of t (or s == t): one downward segment.
S[cntr]=s;T[cntr]=t;rlen[cntr]=dep[t]-dep[s];
downS[s].push_back(cntr);downT[t].push_back(cntr);
}
else if(m == t)
{
// t is an ancestor of s: one upward segment.
S[cntr]=s;T[cntr]=t;rlen[cntr]=dep[s]-dep[t];
upS[s].push_back(cntr);upT[t].push_back(cntr);
}
else
{
// Split into s -> m (upward, id cntr) and m -> t (downward, id cntr+1);
// the downward half begins after dep[s]-dep[m] steps.
lca[cntr]=m;
S[cntr]=s;T[cntr]=m;rlen[cntr]=dep[s]-dep[m];
upS[s].push_back(cntr);upT[m].push_back(cntr);
prelen[++cntr]=dep[s]-dep[m];
S[cntr]=m;T[cntr]=t;rlen[cntr]=dep[t]-dep[m];
downS[m].push_back(cntr);downT[t].push_back(cntr);
}
}
}
int ans[N];   // per-node answer filled by pushup()
int cnt1[M];  // upward-segment buckets keyed by dep[start] + N
int cnt2[M];  // downward-segment buckets keyed by (dep[T]-rlen-prelen) + N
void pushup(int u,int fa)
{
int dep1=dep[u]+w[u]+N,ori1=cnt1[dep1],dep2=dep[u]-w[u]+N,ori2=cnt2[dep2],now,v;
for(unsigned int i=0; i < upS[u].size(); i++)
now=upS[u][i],cnt1[dep[S[now]]+N]++;
for(unsigned int i=0; i < downT[u].size(); i++)
now=downT[u][i],cnt2[dep[T[now]]-rlen[now]-prelen[now]+N]++;
for(int i=he[u]; i != -1; i=e[i].next)
if((v=e[i].to) != fa)
pushup(v,u);
ans[u]=cnt1[dep1]-ori1+cnt2[dep2]-ori2;
for(unsigned int i=0; i < upT[u].size(); i++)
{
now=upT[u][i];
cnt1[dep[S[now]]+N]--;
if(lca[now] == u && dep[S[now]]+N == dep1) ans[u]--;
}
for(unsigned int i=0; i < downS[u].size(); i++)
now=downS[u][i],cnt2[dep[T[now]]-rlen[now]-prelen[now]+N]--;
}
int siz[N];   // subtree sizes computed by dfs()
int mind=N;   // best score seen so far by dfs(); N means none yet
void dfs(int u,int fa)
{
int v, minn=N, maxn=-N;siz[u]=1;
for(int i=he[u]; i != -1; i=e[i].next)
{
if((v=e[i].to) == fa) continue;
dfs(v,u); siz[u]+=siz[v];
if(minn > siz[v]) minn=siz[v];
}
if(minn == N) return ;
if(n-siz[u] < minn && fa) minn=n-siz[u];
if(maxn < n-siz[u]) maxn=n-siz[u];
if(mind > maxn-minn) mind=maxn-minn,rt=u;
}
// Runs the pipeline: choose a root, resolve every route's LCA and split it
// into vertical segments, accumulate the per-node answers, and print
// ans[1..n] separated by spaces.
void solve()
{
    // dfs() only assigns rt at nodes that have a child. For a single-node tree
    // it never does, so tarjan()/pushup() would run on the nonexistent node 0.
    // Default to node 1; dfs() overrides this whenever it picks a root.
    rt=1;
    dfs(1,0);
    tarjan(rt,0);
    pushup(rt,0);
    for(int i=1;i<=n;i++) printf("%d ",ans[i]);
}
// Reads the next unsigned decimal integer from stdin, skipping any non-digit
// characters before it (a '-' is skipped like any other non-digit).
// Returns 0 if the input ends before a digit is found.
int read(){
    // int, not char: getchar() returns EOF as an int. The skip loop must stop
    // there, or truncated input makes it spin forever.
    int c=getchar();
    while(c != EOF && (c < '0' || c > '9')) c=getchar();
    int out=0;
    while(c >= '0' && c <= '9')
    {
        out=out*10+(c-'0');
        c=getchar();
    }
    return out;
}
// Reads n and m, the n-1 tree edges, the observation times w[1..n], and the
// m routes (start, end). Each route becomes a paired query entry via add().
void init()
{
    memset(he, -1, sizeof(he));
    memset(hq, -1, sizeof(hq));
    n=read();
    m=read();
    for(int k=1; k<n; ++k)
    {
        const int a=read();
        const int b=read();
        build(a,b);
    }
    for(int k=1; k<=n; ++k) w[k]=read();
    for(int k=1; k<=m; ++k)
    {
        const int from=read();
        const int to=read();
        add(from,to);
    }
}
// Entry point: load the tree and routes, then compute and print the answers.
int main()
{
    init();
    solve();
    return 0;
}