树链剖分真是难写极了:你需要先熟练掌握深搜、线段树、倍增。。。(都是不好写的东西啊,一手滑就wa)
一、简介
树链剖分通常用于维护静态树上路径信息的问题。树链剖分的核心就是将数分为若干重链、轻链,然后把他们当做序列,按顺序拼接起来,处理序列上的区间问题
二、相关量
Fa【x】:x在树中的父亲(用于倍增)
Dep【x】:x在树中的深度(用于倍增)
Size【x】:x的子树节点数(用于线段树操作)
Son【x】:x的重儿子,u --> son【u】为重边(用于建重边)
Top【x】:x所在重路径的顶部节点(深度最小)(用于倍增)
Id【x】:x在线段树中的位置(下标)(用于查询回答等各种操作)
Nw【id【x】】:线段树中第id【x】个位置对应树中节点编号,即nw【id【x】】 = x(用于查询回答等各种操作)
三、流程
第一遍dfs处理出fa,dep,size,son
第二遍dfs处理出top,id,nw(先处理重边再处理轻边)
int dep[maxn * 4],fa[maxn * 4],siz[maxn * 4],son[maxn * 4]; void dfs1(int x,int f,int deep) { dep[x] = deep; fa[x] = f; siz[x] = 1; int maxson = -1; for(int i = head[x]; i; i = nxt[i]) { int too = to[i]; if(too == f)continue; dfs1(too,x,deep + 1); siz[x] += siz[too]; if(siz[too] > maxson) { son[x] = too; maxson = siz[too]; } } }int id[maxn * 4],idx,top[maxn * 4],nw[maxn * 4]; void dfs2(int x,int topf) { id[x] = ++idx; nw[idx] = a[x]; top[x] = topf; if(!son[x])return; dfs2(son[x],topf); for(int i = head[x]; i; i = nxt[i]) { int too = to[i]; if(too == fa[x] || too == son[x])continue; dfs2(too,too); } }
拆分成若干重路径倍增处理成若干个线段树上区间操作
int qrange(int ll,int rr) { int anss = 0; while(top[ll] != top[rr]) { if(dep[top[ll]] < dep[top[rr]])swap(ll,rr); ans = 0; query(1,n,id[top[ll]],id[ll],1); anss += ans; anss %= mod; ll = fa[top[ll]]; } if(dep[ll] > dep[rr])swap(ll,rr); ans = 0; query(1,n,id[ll],id[rr],1); anss += ans; anss %= mod; return anss; } void addrange(int ll,int rr,int x) { x %= mod; while(top[ll] != top[rr])//倍增求和 { if(dep[top[ll]] < dep[top[rr]])swap(ll,rr); addchange(1,n,id[top[ll]],id[ll],x,1); ll = fa[top[ll]]; } if(dep[ll] > dep[rr])swap(ll,rr); addchange(1,n,id[ll],id[rr],x,1); } int qson(int ro) { ans = 0; query(1,n,id[ro],id[ro] + siz[ro] - 1,1); return ans; } void addson(int ro,int x) { addchange(1,n,id[ro],id[ro] + siz[ro] - 1,x,1); }
线段树操作
struct node { int w,laz; } tr[maxn * 8];int ans; void pushdown(int ro,int len) { tr[ro * 2].laz += tr[ro].laz; tr[ro * 2 + 1].laz += tr[ro].laz; tr[ro * 2].w += tr[ro].laz * (len - len / 2); tr[ro * 2 + 1].w += tr[ro].laz * (len / 2); tr[ro * 2].w %= mod; tr[ro * 2 + 1].w %= mod; tr[ro].laz = 0; } void build(int ll,int rr,int ro) { if(ll == rr) { tr[ro].w = nw[ll]; tr[ro].w %= mod; return ; } else { int mid = (ll + rr) / 2; build(ll,mid,ro * 2); build(mid + 1,rr,ro * 2 + 1); tr[ro].w = tr[ro * 2].w + tr[ro * 2 + 1].w; tr[ro].w %= mod; } } void query(int ll,int rr,int al,int ar,int ro) { if(al <= ll && rr <= ar) { ans += tr[ro].w; ans %= mod; return ; } else { int mid = (ll + rr) / 2; if(tr[ro].laz)pushdown(ro,rr - ll + 1); if(al <= mid)query(ll,mid,al,ar,ro * 2); if(ar > mid)query(mid + 1,rr,al,ar,ro * 2 + 1); } }
四、复杂度
o(log^2n)
五、题
1 #include<cstdio> 2 #include<iostream> 3 using namespace std; 4 #define maxn 200005 5 int n,m,r,mod; 6 int to[maxn * 4],head[maxn * 4],nxt[maxn * 4],cnt; 7 int a[maxn]; 8 void add(int u,int v) 9 { 10 to[++cnt] = v; 11 nxt[cnt] = head[u]; 12 head[u] = cnt; 13 } 14 int dep[maxn * 4],fa[maxn * 4],siz[maxn * 4],son[maxn * 4]; 15 void dfs1(int x,int f,int deep) 16 { 17 dep[x] = deep; 18 fa[x] = f; 19 siz[x] = 1; 20 int maxson = -1; 21 for(int i = head[x]; i; i = nxt[i]) 22 { 23 int too = to[i]; 24 if(too == f)continue; 25 dfs1(too,x,deep + 1); 26 siz[x] += siz[too]; 27 if(siz[too] > maxson) 28 { 29 son[x] = too; 30 maxson = siz[too]; 31 } 32 } 33 } 34 int id[maxn * 4],idx,top[maxn * 4],nw[maxn * 4]; 35 void dfs2(int x,int topf) 36 { 37 id[x] = ++idx; 38 nw[idx] = a[x]; 39 top[x] = topf; 40 if(!son[x])return; 41 dfs2(son[x],topf); 42 for(int i = head[x]; i; i = nxt[i]) 43 { 44 int too = to[i]; 45 if(too == fa[x] || too == son[x])continue; 46 dfs2(too,too); 47 } 48 } 49 struct node 50 { 51 int w,laz; 52 } tr[maxn * 8]; 53 int ans; 54 void pushdown(int ro,int len) 55 { 56 tr[ro * 2].laz += tr[ro].laz; 57 tr[ro * 2 + 1].laz += tr[ro].laz; 58 tr[ro * 2].w += tr[ro].laz * (len - len / 2); 59 tr[ro * 2 + 1].w += tr[ro].laz * (len / 2); 60 tr[ro * 2].w %= mod; 61 tr[ro * 2 + 1].w %= mod; 62 tr[ro].laz = 0; 63 } 64 void build(int ll,int rr,int ro) 65 { 66 if(ll == rr) 67 { 68 tr[ro].w = nw[ll]; 69 tr[ro].w %= mod; 70 return ; 71 } 72 else 73 { 74 int mid = (ll + rr) / 2; 75 build(ll,mid,ro * 2); 76 build(mid + 1,rr,ro * 2 + 1); 77 tr[ro].w = tr[ro * 2].w + tr[ro * 2 + 1].w; 78 tr[ro].w %= mod; 79 } 80 } 81 void query(int ll,int rr,int al,int ar,int ro) 82 { 83 if(al <= ll && rr <= ar) 84 { 85 ans += tr[ro].w; 86 ans %= mod; 87 return ; 88 } 89 else 90 { 91 int mid = (ll + rr) / 2; 92 if(tr[ro].laz)pushdown(ro,rr - ll + 1); 93 if(al <= mid)query(ll,mid,al,ar,ro * 2); 94 if(ar > mid)query(mid + 1,rr,al,ar,ro * 2 + 1); 95 } 96 97 } 98 void addchange(int ll,int rr,int al,int ar,int x,int ro) 99 { 100 if(al <= ll && rr <= ar) 101 { 102 tr[ro].laz += x; 103 tr[ro].w += x * (rr - ll + 1); 104 return ; 105 } 106 else 107 { 108 int mid = (ll + rr) / 2; 109 if(tr[ro].laz)pushdown(ro,rr - ll + 1); 110 if(al <= mid)addchange(ll,mid,al,ar,x,ro * 2); 111 if(ar > mid)addchange(mid + 1,rr,al,ar,x,ro * 2 + 1); 112 tr[ro].w = tr[ro * 2].w + tr[ro * 2 + 1].w; 113 tr[ro].w %= mod; 114 } 115 116 } 117 int qrange(int ll,int rr) 118 { 119 int anss = 0; 120 while(top[ll] != top[rr]) 121 { 122 if(dep[top[ll]] < dep[top[rr]])swap(ll,rr); 123 ans = 0; 124 query(1,n,id[top[ll]],id[ll],1); 125 anss += ans; 126 anss %= mod; 127 ll = fa[top[ll]]; 128 } 129 if(dep[ll] > dep[rr])swap(ll,rr); 130 ans = 0; 131 query(1,n,id[ll],id[rr],1); 132 anss += ans; 133 anss %= mod; 134 return anss; 135 } 136 void addrange(int ll,int rr,int x) 137 { 138 x %= mod; 139 while(top[ll] != top[rr])//倍增求和 140 { 141 if(dep[top[ll]] < dep[top[rr]])swap(ll,rr); 142 addchange(1,n,id[top[ll]],id[ll],x,1); 143 ll = fa[top[ll]]; 144 } 145 if(dep[ll] > dep[rr])swap(ll,rr); 146 addchange(1,n,id[ll],id[rr],x,1); 147 } 148 int qson(int ro) 149 { 150 ans = 0; 151 query(1,n,id[ro],id[ro] + siz[ro] - 1,1); 152 return ans; 153 } 154 void addson(int ro,int x) 155 { 156 addchange(1,n,id[ro],id[ro] + siz[ro] - 1,x,1); 157 } 158 int main() 159 { 160 scanf("%d%d%d%d",&n,&m,&r,&mod); 161 for(int i = 1; i <= n; i ++)scanf("%d",&a[i]); 162 for(int i = 1; i < n; i ++) 163 { 164 int u,v; 165 scanf("%d%d",&u,&v); 166 add(u,v); 167 add(v,u); 168 } 169 dfs1(r,0,1); 170 dfs2(r,r); 171 build(1,n,1); 172 for(int i = 1; i <= m; i ++) 173 { 174 int k; 175 scanf("%d",&k); 176 if(k == 1)//x到y的最短路径上所有节点加z 177 { 178 int x,y,z; 179 scanf("%d%d%d",&x,&y,&z); 180 addrange(x,y,z); 181 } 182 else if(k == 2)//求x到y的最短路径上所有节点之和 183 { 184 int x,y; 185 scanf("%d%d",&x,&y); 186 printf("%d ",qrange(x,y)); 187 } 188 else if(k == 3)//以x为根节点的子树所有节点加z 189 { 190 int x,z; 191 scanf("%d%d",&x,&z); 192 addson(x,z); 193 } 194 else if(k == 4)//求以x为根节点的子树所有节点值之和 195 { 196 int x; 197 scanf("%d",&x); 198 printf("%d ",qson(x)); 199 } 200 } 201 return 0; 202 }