题面
分析
上次用树链剖分写了一次..现在回头重新来看,就感觉到树上差分真的是好东西了。
这个题在树上差分里面level已经很高了,但是其实还是蛮套路的,和普通的树上差分区分度就在于每个点有个观察时间w[u]
然而比较温情的是出发时间都是确定的,从一开始就出发,且边权为1.
于是我们可以发现有两种情况可以链接这个w[u]
- 存在点x处于s到lca的路上(包括lca),则有dep[s]-dep[x]=w[x],移项得dep[s]=w[x]+dep[x]
- 存在点x处于lca到t的路上(为避免重复计算不包括lca),则有dep[s]+dep[x]-2*dep[lca]=w[x],移项得dep[s]-2*dep[lca]=w[x]-dep[x]。
稍微解释一下这两个式子,每个等式移项前的左右两边都表示s-->x的距离(因为边权为1,所以距离也等于时间)
于是,我们需要在每个s-->t的路径中找到有多少个满足条件的x即可。
为什么可以差分?因为标记影响到的仅仅是以lca为根的这棵子树,符合树上差分的答案是子树和的思想。
而我们只需要计算以某点为根的子树在累积答案前后的差值,就可以得到这个点的答案。
具体做法:当我们遇到每一组s和t时,就映射等式左边部分,打上标记,在dfs过程中查询映射等式右边部分的数量。
即通过已知的dep[s]和dep[lca]加标记,而在dfs过程中w也变为已知量,再引入w求解。
需要注意的是lca不能重复计算,一个删除标记打在lca的父亲上,另一个要打在lca上。
再次提醒,别用map,这玩意儿常数神大,读入输出优化,inline,register都救不了。上次做分块的时候就已经被map教做人了。
而且这题貌似卡常??我把map映射改成数组映射过后那个被卡T的点直接跑成600ms...
代码
- #include<bits/stdc++.h>
- using namespace std;
- #define N 300030
- #define RT register
- int n,m,cnt;
- int d[N],w[N],dep[N],ans[N],first[N],fa[N][20];
- int A[N*10],B[N*10];
- struct email
- {
- int u,v;
- int nxt;
- }e[N*4];
- struct update
- {
- int c,tag,id;
- };
- vector<update>v[N];
- template<class T>
- inline void read(T &x)
- {
- x=0;int f=1;static char c=getchar();
- while(c<'0'||c>'9') {if(c=='-')f=-1;c=getchar();}
- while(c>='0'&&c<='9'){x=x*10+c-'0',c=getchar();}
- x*=f;
- }
- void print(int x)
- {
- if(x>9)print(x/10);
- putchar(x%10+'0');
- }
- inline void add(int u,int v)
- {
- e[++cnt].nxt=first[u];first[u]=cnt;
- e[cnt].u=u;e[cnt].v=v;
- }
- inline void pre(int u,int f)
- {
- for(RT int i=1;(1<<i)<=dep[u];++i)
- fa[u][i]=fa[fa[u][i-1]][i-1];
- for(RT int i=first[u];i;i=e[i].nxt)
- {
- int v=e[i].v;
- if(v==f)continue;
- dep[v]=dep[u]+1;
- fa[v][0]=u;
- pre(v,u);
- }
- }
- inline int lca(int x,int y)
- {
- if(dep[x]<dep[y])swap(x,y);
- int t=dep[x]-dep[y];
- for(RT int i=0;(1<<i)<=t;++i)
- if((1<<i)&t)
- x=fa[x][i];
- if(x==y)return x;
- for(RT int i=19;i>=0;--i)
- if(fa[x][i]!=fa[y][i])
- x=fa[x][i],y=fa[y][i];
- return fa[x][0];
- }
- inline void dfs(int u,int f)
- {
- int st1=A[w[u]+dep[u]],st2=B[w[u]-dep[u]+n];
- for(RT int i=first[u];i;i=e[i].nxt)
- {
- int v=e[i].v;
- if(v==f)continue;
- dfs(v,u);
- }
- for(RT int i=0;i<v[u].size();i++)
- {
- update x=v[u][i];
- if(x.c==1)A[x.id]+=x.tag;
- else B[x.id+n]+=x.tag;
- }
- ans[u]=A[w[u]+dep[u]]+B[w[u]-dep[u]+n]-st1-st2;
- }
- int main()
- {
- read(n);read(m);
- for(RT int i=1;i<n;++i)
- {
- int u,v,w;
- read(u),read(v);
- add(u,v);add(v,u);
- }
- pre(1,0);
- for(RT int i=1;i<=n;++i)read(w[i]);
- for(RT int i=1;i<=m;++i)
- {
- int s,t,LCA;
- read(s),read(t);LCA=lca(s,t);
- v[s].push_back({1,1,dep[s]});v[fa[LCA][0]].push_back({1,-1,dep[s]});
- v[t].push_back({2,1,dep[s]-2*dep[LCA]});v[LCA].push_back({2,-1,dep[s]-2*dep[LCA]});
- }
- dfs(1,0);
- for(RT int i=1;i<=n;++i)print(ans[i]),putchar(' ');
- return 0;
- }