题目:https://www.luogu.org/problemnew/show/P1600
看TJ:https://blog.csdn.net/clove_unique/article/details/53427248
树上差分真好。
首先要发现向上和向下的……是定值。然后想到可以差分。
本题的差分略特殊之处在于它的对象是一条连到根的链。把链上的差分值记到非根的端点上。
总之看明白TJ之后感觉真精妙。
而且 从各种中选自己需要的 只需把所有都加进桶中!
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> using namespace std; const int N=3e5+5,M=9e5+15; int n,m,head[N],xnt,w[N],d[N],hd[N],xt,fa[N],tot0,tot1,dfn[N],tim; int cnt[M],fx=N,ans[N],nw,fatr[N]; struct Edge{ int next,to;Edge(int n=0,int t=0):next(n),to(t) {} }edge[N<<1]; struct Ed{ int next,to;bool fx;Ed(int n=0,int t=0,bool f=0):next(n),to(t),fx(f) {} }ed[N<<1]; struct Node{ int t,val,s;Node(int t=0,int v=0,int s=0):t(t),val(v),s(s) {} }a0[N<<1],a1[N<<1]; void add(int x,int y) { edge[++xnt]=Edge(head[x],y);head[x]=xnt; edge[++xnt]=Edge(head[y],x);head[y]=xnt; } void ad(int x,int y) { ed[++xt]=Ed(hd[x],y,0);hd[x]=xt; if(x!=y)ed[++xt]=Ed(hd[y],x,1);hd[y]=xt;//x!=y!! } int find(int a){return fa[a]==a?a:fa[a]=find(fa[a]);} bool cmp(Node x,Node y){return dfn[x.s]<dfn[y.s];} void build(int s,int t,int f) { a0[++tot0]=Node(0,1,s); if(f!=1)a0[++tot0]=Node(d[s]-d[f]+1,-1,fatr[f]);//f!=1 //用fatr a1[++tot1]=Node(d[s]-d[f]-d[f],-1,f); a1[++tot1]=Node(d[s]-d[f]-d[f],1,t);//不要+d[t]-d[f] } void dfs(int cr,int f) { d[cr]=d[f]+1;dfn[cr]=++tim;fatr[cr]=f; for(int i=hd[cr],v;i;i=ed[i].next) { if(dfn[v=ed[i].to]) { if(ed[i].fx)build(v,cr,find(v)); else build(cr,v,find(v)); } } for(int i=head[cr],v;i;i=edge[i].next) if((v=edge[i].to)!=f)dfs(v,cr),fa[v]=cr; } void dfs0(int cr,int f) { int pd=d[cr]+w[cr],cpy=cnt[pd]; while(nw<=tot0&&a0[nw].s==cr)cnt[a0[nw].t+d[cr]]+=a0[nw].val,nw++;//+=val for(int i=head[cr],v;i;i=edge[i].next) if((v=edge[i].to)!=f)dfs0(v,cr); ans[cr]+=cnt[pd]-cpy; } void dfs1(int cr,int f) { int pd=w[cr]-d[cr],cpy=cnt[pd+fx]; while(nw<=tot1&&a1[nw].s==cr)cnt[a1[nw].t+fx]+=a1[nw].val,nw++; for(int i=head[cr],v;i;i=edge[i].next) if((v=edge[i].to)!=f)dfs1(v,cr); ans[cr]+=cnt[pd+fx]-cpy; } int main() { scanf("%d%d",&n,&m);int x,y; for(int i=1;i<n;i++) { scanf("%d%d",&x,&y);add(x,y);fa[i]=i; } fa[n]=n; for(int i=1;i<=n;i++)scanf("%d",&w[i]); for(int i=1;i<=m;i++) { scanf("%d%d",&x,&y);ad(x,y); } d[0]=-1;dfs(1,0);sort(a0+1,a0+tot0+1,cmp); sort(a1+1,a1+tot1+1,cmp); nw=1;dfs0(1,0);memset(cnt,0,sizeof cnt); nw=1;dfs1(1,0); for(int i=1;i<=n;i++)printf("%d ",ans[i]); return 0; }
然后在洛谷上看到了“文文殿下”的代码。跑得好快!学了一下。
注意到更新cnt值的时候只用和自己节点有关的东西更新。
所以与其按dfs序排序然后在a[ ]上移动,不如给每个点记一个邻接表,指向a[ ]上的位置。
感觉人家对邻接表理解深刻。那个nxt是记在 t [ ] 的角标上的,hdhd提供一个指向 t [ ] 某个角标的入口一样的东西。
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> using namespace std; const int N=3e5+5,M=9e5+15,Mm=6e5+15; int n,m,head[N],xnt,w[N],d[N],hd[N],xt,fa[N]; int cnt0[Mm],cnt1[M],fx=N,ans[N],fatr[N]; int hdhd0[N][2],hdhd1[N][2],t[N<<2],nxt[N<<2],tot; bool vis[N]; struct Edge{ int next,to;Edge(int n=0,int t=0):next(n),to(t) {} }edge[N<<1]; struct Ed{ int next,to;bool fx;Ed(int n=0,int t=0,bool f=0):next(n),to(t),fx(f) {} }ed[N<<1]; void add(int x,int y) { edge[++xnt]=Edge(head[x],y);head[x]=xnt; edge[++xnt]=Edge(head[y],x);head[y]=xnt; } void ad(int x,int y) { ed[++xt]=Ed(hd[x],y,0);hd[x]=xt; if(x!=y)ed[++xt]=Ed(hd[y],x,1),hd[y]=xt;//x!=y!! } int find(int a){return fa[a]==a?a:fa[a]=find(fa[a]);} void adad(int &x,int z){t[++tot]=z;nxt[tot]=x;x=tot;} void build(int s,int t,int f) { adad(hdhd0[s][0],0);if(f!=1)adad(hdhd0[fatr[f]][1],d[s]-d[f]+1); adad(hdhd1[f][1],d[s]-2*d[f]);adad(hdhd1[t][0],d[s]-2*d[f]); } void dfs(int cr,int f) { d[cr]=d[f]+1;vis[cr]=1;fatr[cr]=f; for(int i=hd[cr],v;i;i=ed[i].next) if(vis[v=ed[i].to]) if(ed[i].fx)build(v,cr,find(v)); else build(cr,v,find(v)); for(int i=head[cr],v;i;i=edge[i].next) if((v=edge[i].to)!=f)dfs(v,cr),fa[v]=cr; } void dfsx(int cr,int f) { int pd0=d[cr]+w[cr],pd1=w[cr]-d[cr]; ans[cr]-=cnt0[pd0]+cnt1[pd1+fx]; for(int i=hdhd0[cr][0];i;i=nxt[i])cnt0[t[i]+d[cr]]++; for(int i=hdhd0[cr][1];i;i=nxt[i])cnt0[t[i]+d[cr]]--; for(int i=hdhd1[cr][0];i;i=nxt[i])cnt1[t[i]+fx]++; for(int i=hdhd1[cr][1];i;i=nxt[i])cnt1[t[i]+fx]--; for(int i=head[cr];i;i=edge[i].next) if(edge[i].to!=f)dfsx(edge[i].to,cr); ans[cr]+=cnt0[pd0]+cnt1[pd1+fx]; } int main() { scanf("%d%d",&n,&m);int x,y; for(int i=1;i<n;i++) { scanf("%d%d",&x,&y);add(x,y);fa[i]=i; } fa[n]=n; for(int i=1;i<=n;i++)scanf("%d",&w[i]); for(int i=1;i<=m;i++) { scanf("%d%d",&x,&y);ad(x,y); } d[0]=-1;dfs(1,0);dfsx(1,0); for(int i=1;i<=n;i++)printf("%d ",ans[i]); return 0; }