意义:
树链剖分 就是对一棵树分成几条链,把树形变为线性,减少处理难度
概念
重儿子:对于每一个非叶子节点,它的儿子中 儿子数量最多的那一个儿子 为该节点的重儿子
轻儿子:对于每一个非叶子节点,它的儿子中 非重儿子 的剩下所有儿子即为轻儿子
叶子节点没有重儿子也没有轻儿子(因为它没有儿子。。)
重边:连接任意两个重儿子的边叫做重边
轻边:剩下的即为轻边
重链:相邻重边连起来的 连接一条重儿子 的链叫重链
对于叶子节点,若其为轻儿子,则有一条以自己为起点的长度为1的链
每一条重链以轻儿子为起点
题目大意:
给定一棵有根树,给定每个点初值。 需要处理的问题:
将树从x到y结点最短路径上所有节点的值都加上z
求树从x到y结点最短路径上所有节点的值之和
将以x为根节点的子树内所有节点值都加上z
求以x为根节点的子树内所有节点值之和
分析:
树链剖分+线段树
树剖部分:
需要数组:
int root,n,m,p; int dfn[N],dfn2[N],fdfn[N]; int top[N],son[N],fa[N],dep[N],size[N];
1.dfs1:
目标:
①找到fa,重儿子(son)
②处理节点深度,子树大小(size)(dep[root]=1,fa[root]=-1,其实本题不固定)
void dfs1(int x,int f,int d) { dep[x]=d; size[x]=1; int mx=0; for(int i=head[x];i;i=bian[i].nxt) { int y=bian[i].to; if(y==f) continue;//不能回走 fa[y]=x; dfs1(y,x,d+1); size[x]+=size[y]; if(size[y]>mx) { mx=size[y],son[x]=y;//记录重儿子 } } }
2.dfs2
目标:
①找到dfn,dfn2(子树结尾dfn)便于之后线段树维护区间。
②处理fdfn,记录dfnx是几号点。便于线段树build
③注意:有重儿子,先走重儿子。
结果:
dfn数组中,一棵完整的子树,其dfn也是连续的一段。每条重链也是连续的一段。这样,用线段树很方便维护树上路径的处理。
void dfs2(int x,int f) { dfn[x]=++tot; fdfn[tot]=x;//第tot个dfn是x号 if(!top[x]) top[x]=x;//top未赋值才能赋值 if(son[x]) top[son[x]]=top[x],dfs2(son[x],x);//先走重儿子 for(int i=head[x];i;i=bian[i].nxt) { int y=bian[i].to; if(y==son[x]||y==f) continue; dfs2(y,x); } dfn2[x]=tot;//回溯之前记录下子树结尾dfn }
此处省去线段树常规操作,详见下面代码。
3.work1
利用树剖lca想法,其中一个点一边向上翻的同时,更新值。最后在同一条链上了之后,相当于已经找到了lca直接更新另一条路径。
void work1(int x,int y,int z) { while(top[x]!=top[y]) { if(dep[top[x]]>dep[top[y]]) swap(x,y);//dep[top]深度深的向上翻 add(1,1,tot,dfn[top[y]],dfn[y],z); y=fa[top[y]]; } if(dep[x]>dep[y]) swap(x,y); add(1,1,tot,dfn[x],dfn[y],z);//另一边路径 }
work2同理。
4.work3,work4,利用之前记录过的dfn2,可以直接找到子树区间。直接处理即可。
void work3(int x,int z) { add(1,1,tot,dfn[x],dfn2[x],z); } int work4(int x) { int sum=0; sum=(sum+query(1,1,tot,dfn[x],dfn2[x]))%p; return sum; }
注意事项:
1.每次dfs注意不要返祖。
2.记得取模!!!任何加减,赋值,求和都要提起注意。
3.区间add标记直接加,sum要+c×(len)必须乘区间!!(线段树不过关。。。)
4.root是原来树的根,线段树的根就是1!!(不要混了)RE无数无数无数
详见代码:
#include<bits/stdc++.h> using namespace std; const int N=1e5+10; int a[N]; int root,n,m,p; int dfn[N],dfn2[N],fdfn[N]; int top[N],son[N],fa[N],dep[N],size[N]; struct node{ int nxt,to; }bian[2*N]; int cnt,tot; int head[N]; void add(int x,int y) { bian[++cnt].nxt=head[x]; bian[cnt].to=y; head[x]=cnt; } void dfs1(int x,int f,int d) { dep[x]=d; size[x]=1; int mx=0; for(int i=head[x];i;i=bian[i].nxt) { int y=bian[i].to; if(y==f) continue; fa[y]=x; dfs1(y,x,d+1); size[x]+=size[y]; if(size[y]>mx) { mx=size[y],son[x]=y; } } } void dfs2(int x,int f) { dfn[x]=++tot; fdfn[tot]=x; if(!top[x]) top[x]=x; if(son[x]) top[son[x]]=top[x],dfs2(son[x],x); for(int i=head[x];i;i=bian[i].nxt) { int y=bian[i].to; if(y==son[x]||y==f) continue; dfs2(y,x); } dfn2[x]=tot; } //-------------------以上树剖 ----------------------------------- int mod(int x) { while(x>=p) x-=p; while(x<0) x+=p; return x; } struct tree{ int sum,add; #define s(x) t[x].sum #define ad(x) t[x].add }t[4*N]; void pushup(int x) { s(x)=mod(s(x<<1)+s(x<<1|1)); } void build(int x,int l,int r) { if(l==r) { s(x)=mod(a[fdfn[l]]);ad(x)=0; return; } int mid=(l+r)>>1; build(x<<1,l,mid); build(x<<1|1,mid+1,r); pushup(x); } void pushdown(int x,int l,int r)//change sum+=ad*len { int s1=x<<1,s2=x<<1|1; int mid=(l+r)>>1; ad(s1)=mod(ad(s1)+ad(x)); s(s1)=mod(s(s1)+ad(x)*(mid-l+1)); ad(s2)=mod(ad(s2)+ad(x)); s(s2)=mod(s(s2)+ad(x)*(r-mid)); ad(x)=0; } void add(int x,int l,int r,int L,int R,int c) { if(L<=l&&r<=R) { s(x)=mod(s(x)+mod(c*(r-l+1))); ad(x)=mod(ad(x)+c); return; } pushdown(x,l,r); int mid=(l+r)>>1; if(L<=mid) add(x<<1,l,mid,L,R,c); if(mid<R) add(x<<1|1,mid+1,r,L,R,c); pushup(x); } int query(int x,int l,int r,int L,int R) { if(L<=l&&r<=R) { return s(x); } pushdown(x,l,r); int mid=(l+r)>>1; int res=0; if(L<=mid) res=mod(res+query(x<<1,l,mid,L,R)); if(mid<R) res=mod(res+query(x<<1|1,mid+1,r,L,R)); return res; } //-------------------以上线段树 ----------------------------------- void work1(int x,int y,int z) { while(top[x]!=top[y]) { if(dep[top[x]]>dep[top[y]]) swap(x,y); add(1,1,tot,dfn[top[y]],dfn[y],z); y=fa[top[y]]; } if(dep[x]>dep[y]) swap(x,y); add(1,1,tot,dfn[x],dfn[y],z); } int work2(int x,int y) { int sum=0; while(top[x]!=top[y]) { if(dep[top[x]]>dep[top[y]]) swap(x,y); sum=(sum+query(1,1,tot,dfn[top[y]],dfn[y]))%p; y=fa[top[y]]; } if(dep[x]>dep[y]) swap(x,y); sum=(sum+query(1,1,tot,dfn[x],dfn[y]))%p; return sum; } void work3(int x,int z) { add(1,1,tot,dfn[x],dfn2[x],z); } int work4(int x) { int sum=0; sum=(sum+query(1,1,tot,dfn[x],dfn2[x]))%p; return sum; } int main() { scanf("%d%d%d%d",&n,&m,&root,&p); for(int i=1;i<=n;i++) scanf("%d",&a[i]); int x,y; for(int i=1;i<=n-1;i++) { scanf("%d%d",&x,&y); add(x,y);add(y,x); } dfs1(root,-1,1); dfs2(root,-1); fa[root]=-1; build(1,1,tot); int op,z; while(m) { scanf("%d",&op); if(op==1) { scanf("%d%d%d",&x,&y,&z); work1(x,y,z); } else if(op==2) { scanf("%d%d",&x,&y); printf("%d ",work2(x,y)); } else if(op==3) { scanf("%d%d",&x,&z); work3(x,z); } else{ scanf("%d",&x); printf("%d ",work4(x)); } m--; } return 0; }