BZOJ1036
到此为止我们已经熟悉了静态区间的询问操作,我们还在此基础上将问题拓展到了二维并给出了部分的解决方案
然后,我把把区间变成一棵树,对于树上的询问,比如求树上两点之间最值,树上两点之间的点权之和,我们有固定的解决方案:树链剖分
这两个问题,也是题目问我们的
把树打成链再在链上查询是一种解决问题的方法,那么打成链的方式里面,最最最常见的是轻重链剖分,当然也可能有别的类型的剖分,因题而异
我们对于打成的链,用线段树来维护,因为我们的树的形态不会发生改变,那么我们的线段树和原始的树是对应起来的
int n,cnt,sz,q; int head[maxn],v[maxn],size[maxn],fa[maxn],dep[maxn],pos[maxn],bl[maxn]; struct edge{int to,next;}e[maxm]; struct seg{int l,r,mx,sum;}t[100005];
这里n是树上节点总数,cnt是边的数量(一定注意,无向图二倍于点的边数),sz是一个计数变量,它和pos数组配合使用,它的作用是使pos[x]对应于线段树所映射区间的一个点,这样的话,pos[x]=sz++
这样我们在线段树中找树上点x时,直接pos[x]就好了
然后head是我们邻接表的配套数组,然后后面的v是树上节点点权,size是每一个节点所引出子树的节点数,用于轻重链剖分,然后dep是节点深度,bl[x]是节点x所在重链的根节点(链就是说的头节点了)
edge不用说,邻接链表的结构体,seg是线段树
接下来,就要把树拆成链了
两个DFS解决,我们先看第一个DFS:
void dfs1(int x) { size[x]=1; for(int i=head[x];i;i=e[i].next) { if(e[i].to==fa[x]) continue; dep[e[i].to]=dep[x]+1; fa[e[i].to]=x; dfs1(e[i].to); size[x]+=size[e[i].to]; } }
第一遍dfs求出树每个结点的深度dep[x],其为根的子树大小size[x]以及每一个节点的daddy:fa[x]
void dfs2(int x,int chain) { int k=0;sz++; pos[x]=sz; bl[x]=chain; //x节点所在重链的根 for(int i=head[x];i;i=e[i].next) if(dep[e[i].to]>dep[x]&&size[e[i].to]>size[k]) k=e[i].to; if(k==0) return; dfs2(k,chain); for(int i=head[x];i;i=e[i].next) if(dep[e[i].to]>dep[x]&&k!=e[i].to) dfs2(e[i].to,e[i].to); }
第二遍,根节点为起点,向下拓展构建重链,选择最大的一个子树的根继承当前重链,其余节点都以该节点为起点向下重新拉一条重链
接下来是线段树的建树(这里建了个空树)和点修改,这里不再赘述
我们给出求树上区间点权的最值以及和的函数:
int solvemx(int x,int y) { int mx=-INF; while(bl[x]!=bl[y]) { if(dep[bl[x]]<dep[bl[y]]) swap(x,y); //x在上,y在下 mx=max(mx,querymx(1,pos[bl[x]],pos[x])); x=fa[bl[x]]; } if(pos[x]>pos[y]) swap(x,y); mx=max(mx,querymx(1,pos[x],pos[y])); return mx; } int solvesum(int x,int y) { int sum=0; while(bl[x]!=bl[y]) { if(dep[bl[x]]<dep[bl[y]]) swap(x,y); sum+=querysum(1,pos[bl[x]],pos[x]); x=fa[bl[x]]; } if(pos[x]>pos[y]) swap(x,y); sum+=querysum(1,pos[x],pos[y]); return sum; }
分两种情况,若u和v在同一条重链上,直接用数据结构修改pos[u]至pos[v]间的值,若u和v不在同一条重链上,一边进行修改,一边将u和v往同一条重链上靠,u和v就在同一条重链上面了,然后就显然了
下面给出完整的实现,线段树部分,可以换成带lazy tag的版本,我在其他文章中有介绍:
1 #include<cstdio> 2 #include<algorithm> 3 using namespace std; 4 const int maxn=30005; 5 const int maxm=60005; 6 const int INF=0x7fffffff; 7 int n,cnt,sz,q; 8 int head[maxn],v[maxn],size[maxn],fa[maxn],dep[maxn],pos[maxn],bl[maxn]; 9 struct edge{int to,next;}e[maxm]; 10 struct seg{int l,r,mx,sum;}t[100005]; 11 void insert(int u,int v) 12 { 13 e[++cnt].to=v;e[cnt].next=head[u];head[u]=cnt; 14 e[++cnt].to=u;e[cnt].next=head[v];head[v]=cnt; 15 } 16 void dfs1(int x) 17 { 18 size[x]=1; 19 for(int i=head[x];i;i=e[i].next) 20 { 21 if(e[i].to==fa[x]) continue; 22 dep[e[i].to]=dep[x]+1; 23 fa[e[i].to]=x; 24 dfs1(e[i].to); 25 size[x]+=size[e[i].to]; 26 } 27 } 28 void dfs2(int x,int chain) 29 { 30 int k=0;sz++; 31 pos[x]=sz; 32 bl[x]=chain; //x节点所在重链的根 33 for(int i=head[x];i;i=e[i].next) 34 if(dep[e[i].to]>dep[x]&&size[e[i].to]>size[k]) k=e[i].to; 35 if(k==0) return; 36 dfs2(k,chain); 37 for(int i=head[x];i;i=e[i].next) 38 if(dep[e[i].to]>dep[x]&&k!=e[i].to) dfs2(e[i].to,e[i].to); 39 } 40 void build(int k,int l,int r)//建线段树 41 { 42 t[k].l=l;t[k].r=r; 43 if(l==r)return; 44 int mid=(l+r)>>1; 45 build(k<<1,l,mid); 46 build(k<<1|1,mid+1,r); 47 } 48 void change(int k,int x,int y)//线段树单点修改 49 { 50 int l=t[k].l,r=t[k].r,mid=(l+r)>>1; 51 if(l==r){t[k].sum=t[k].mx=y;return;} 52 if(x<=mid)change(k<<1,x,y); 53 else change(k<<1|1,x,y); 54 t[k].sum=t[k<<1].sum+t[k<<1|1].sum; 55 t[k].mx=max(t[k<<1].mx,t[k<<1|1].mx); 56 } 57 int querysum(int k,int x,int y)//线段树区间求和 58 { 59 int l=t[k].l,r=t[k].r,mid=(l+r)>>1; 60 if(l==x&&y==r)return t[k].sum; 61 if(y<=mid)return querysum(k<<1,x,y); 62 else if(x>mid)return querysum(k<<1|1,x,y); 63 else {return querysum(k<<1,x,mid)+querysum(k<<1|1,mid+1,y);} 64 } 65 int querymx(int k,int x,int y)//线段树区间求最大值 66 { 67 68 int l=t[k].l,r=t[k].r,mid=(l+r)>>1; 69 if(l==x&&y==r)return t[k].mx; 70 if(y<=mid)return querymx(k<<1,x,y); 71 else if(x>mid)return querymx(k<<1|1,x,y); 72 else {return max(querymx(k<<1,x,mid),querymx(k<<1|1,mid+1,y));} 73 } 74 int solvemx(int x,int y) 75 { 76 int mx=-INF; 77 while(bl[x]!=bl[y]) 78 { 79 if(dep[bl[x]]<dep[bl[y]]) swap(x,y); //x在上,y在下 80 mx=max(mx,querymx(1,pos[bl[x]],pos[x])); 81 x=fa[bl[x]]; 82 } 83 if(pos[x]>pos[y]) swap(x,y); 84 mx=max(mx,querymx(1,pos[x],pos[y])); 85 return mx; 86 } 87 int solvesum(int x,int y) 88 { 89 int sum=0; 90 while(bl[x]!=bl[y]) 91 { 92 if(dep[bl[x]]<dep[bl[y]]) swap(x,y); 93 sum+=querysum(1,pos[bl[x]],pos[x]); 94 x=fa[bl[x]]; 95 } 96 if(pos[x]>pos[y]) swap(x,y); 97 sum+=querysum(1,pos[x],pos[y]); 98 return sum; 99 } 100 int main() 101 { 102 scanf("%d",&n); 103 for(int i=1;i<n;i++) 104 { 105 int x,y; 106 scanf("%d%d",&x,&y); 107 insert(x,y); 108 } 109 for(int i=1;i<=n;i++) scanf("%d",&v[i]); 110 dfs1(1); 111 dfs2(1,1); 112 build(1,1,n); 113 for(int i=1;i<=n;i++) change(1,pos[i],v[i]); 114 scanf("%d",&q); 115 char ch[10]; 116 for(int i=1;i<=q;i++) 117 { 118 int x,y;scanf("%s%d%d",ch,&x,&y); 119 if(ch[0]=='C') {v[x]=y;change(1,pos[x],y);} 120 else 121 { 122 if(ch[1]=='M') printf("%d ",solvemx(x,y)); 123 else printf("%d ",solvesum(x,y)); 124 } 125 } 126 return 0; 127 }