题目描述
思路
使用树链剖分。更新以 x 为根的子树时，由于线段树的下标顺序由 DFS 序（dfn）产生，同一棵子树内所有节点的 dfn 是连续的，所以需要更新的线段树区间为 [dfn[x], dfn[x] + size[x] - 1]。
代码
#include <cstdio>
#include <cstring>
// Left / right child index of segment-tree node k. These expand textually,
// so a variable named k must be in scope wherever they are used.
#define lc k<<1
#define rc k<<1|1
// Capacity of every per-node array.
const int MAX = 1e5 + 10;
// n, m: the first two integers read in main (n bounds the node loops, m the
// operation loop). ot / oa: digit stack used by write().
int n, m, ot, oa[100];
// Adjacency list filled by add_edge(): head[x] is the most recently added edge
// of x, ver[e] is the endpoint of edge e, nt[e] is the next edge in the same
// list, ht is the number of edges stored so far (edge index 0 means "none").
int head[MAX], ver[MAX << 1], nt[MAX << 1], ht;
// wt[i]: value of node i as read in main; used by build() for the leaves.
int wt[MAX];
// Filled by dfs1(): parent, depth, subtree size, and the child with the
// largest subtree (0 when there is no child).
int fa[MAX], dep[MAX], size[MAX], son[MAX];
// Filled by dfs2(): top[x] is the head of the chain containing x, dfn[x] is the
// visit order of x, tr[] is the inverse of dfn[], dt is the visit counter.
int top[MAX], dfn[MAX], tr[MAX], dt;
// Segment tree over dfn positions: sum[k] is the range sum of node k, add[k]
// the pending increment not yet pushed to its children. ans accumulates the
// result of query().
long long sum[MAX << 2], add[MAX << 2], ans;
// Label printed by showArray() before the array contents.
char showStr[100];
// Reads the next decimal integer from stdin.
//
// Every character that is not a digit is skipped; if a '-' is seen while
// skipping, the result is negated. Digits are then accumulated until the first
// non-digit, which is consumed.
//
// Returns the parsed value, or 0 if end of input is reached before any digit
// (the previous version looped forever in that case because EOF was stored in
// a char and never tested).
//
// Declared static so the name has internal linkage and cannot collide with the
// POSIX read() symbol.
static inline int read() {
    int s = 0, f = 1;
    int ch = getchar();  // int, so EOF stays distinguishable from real characters
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        s = s * 10 + (ch - '0');
        ch = getchar();
    }
    return s * f;
}
// Prints x to stdout in decimal, without a trailing newline.
//
// Digits are pushed onto the global stack oa[] (index ot) least-significant
// first and then popped, so ot is 0 again on return.
inline void write(long long x) {
    ot = 0;
    if (x == 0) { putchar('0'); return; }
    // Work on the magnitude in unsigned arithmetic: negating LLONG_MIN as a
    // signed value would overflow, while unsigned negation is well defined.
    unsigned long long mag = (unsigned long long)x;
    if (x < 0) {
        putchar('-');
        mag = 0ULL - mag;
    }
    while (mag) {
        oa[++ot] = (int)(mag % 10) + '0';
        mag /= 10;
    }
    while (ot) putchar(oa[ot--]);
}
// Appends a directed edge x -> y to the front of x's adjacency list.
void add_edge(int x, int y) {
    const int e = ++ht;   // index of the new edge
    ver[e] = y;
    nt[e] = head[x];      // link to the previous first edge of x
    head[x] = e;
}
// First DFS pass: records parent, depth and subtree size of x, and picks
// son[x] as the child with the strictly largest subtree.
//   x - node being visited
//   u - node we arrived from (stored as the parent of x)
void dfs1(int x, int u) {
    fa[x] = u;
    dep[x] = dep[u] + 1;
    size[x] = 1;
    int e = head[x];
    while (e) {
        const int child = ver[e];
        e = nt[e];
        if (child == u) continue;   // do not walk back along the edge to the parent
        dfs1(child, x);
        size[x] += size[child];
        // son[x] starts at 0 and size[0] is 0, so the first child always wins.
        if (size[son[x]] < size[child]) son[x] = child;
    }
}
// Second DFS pass: assigns visit numbers and chain heads.
//   x - node being visited
//   u - head of the chain that x belongs to
// son[x] is visited first and stays on the same chain; every other unvisited
// neighbour starts a new chain headed by itself.
void dfs2(int x, int u) {
    dt += 1;
    dfn[x] = dt;
    tr[dt] = x;
    top[x] = u;
    const int heavy = son[x];
    if (heavy != 0) dfs2(heavy, u);
    for (int e = head[x]; e != 0; e = nt[e]) {
        const int v = ver[e];
        if (dfn[v] != 0) continue;   // already numbered (parent or son[x])
        dfs2(v, v);
    }
}
// Builds the segment-tree node k covering positions [l, r]; the leaf at
// position p holds wt[tr[p]].
void build(int k, int l, int r) {
    if (l == r) {
        sum[k] = wt[tr[l]];
        return;
    }
    const int mid = (l + r) >> 1;
    const int left = k << 1;
    const int right = left | 1;
    build(left, l, mid);
    build(right, mid + 1, r);
    sum[k] = sum[left] + sum[right];
}
// Pushes the pending increment of node k (covering [l, r], split at mid) down
// to both children and clears it. Does nothing when no increment is pending.
void pushdown(int k, int l, int r, int mid) {
    const long long pending = add[k];
    if (!pending) return;
    const int left = k << 1;
    const int right = left | 1;
    add[left] += pending;
    sum[left] += (mid - l + 1) * pending;   // left child covers [l, mid]
    add[right] += pending;
    sum[right] += (r - mid) * pending;      // right child covers [mid + 1, r]
    add[k] = 0;
}
// Adds z to every position in [x, y]. k is the current node and [l, r] the
// range it covers.
void change(int k, int l, int r, int x, int y, int z) {
    if (l >= x && y >= r) {
        // Node fully inside the target range: update it and defer the children.
        add[k] += z;
        sum[k] += (long long)(r - l + 1) * z;
        return;
    }
    const int mid = (l + r) >> 1;
    const int left = k << 1;
    const int right = left | 1;
    pushdown(k, l, r, mid);
    if (mid >= x) change(left, l, mid, x, y, z);
    if (mid < y) change(right, mid + 1, r, x, y, z);
    sum[k] = sum[left] + sum[right];
}
// Exchanges the values of x and y.
void swap(int &x, int &y) {
    const int old_y = y;
    y = x;
    x = old_y;
}
// Adds the sum of positions [x, y] to the global ans. k is the current node
// and [l, r] the range it covers. The caller is responsible for resetting ans.
//
// The commented-out debug printf that used to sit here had been split across
// two lines, leaving a stray unterminated string literal; it has been removed.
void query(int k, int l, int r, int x, int y) {
    if (x <= l && r <= y) { ans += sum[k]; return; }
    int mid = (l + r) >> 1;
    pushdown(k, l, r, mid);   // children must see pending increments before being read
    if (x <= mid) query(lc, l, mid, x, y);
    if (y > mid) query(rc, mid + 1, r, x, y);
}
void ask(int x, int y) {
ans = 0LL;
int fx = top[x], fy = top[y];
while (fx != fy) {
if (dep[fx] < dep[fy]) swap(x, y), swap(fx, fy);
query(1, 1, n, dfn[fx], dfn[x]);
x = fa[fx], fx = top[x];
}
if (dep[x] > dep[y]) swap(x, y);
query(1, 1, n, dfn[x], dfn[y]);
}
// Debug helper: prints showStr on its own line, then arr[1..n] on the next.
void showArray(int * arr) {
    puts(showStr);
    int i = 1;
    while (i <= n) {
        printf("%2d ", arr[i]);
        ++i;
    }
    puts("");
}
void show() {
printf("n:%d m:%d
", n, m);
for (int i = 1; i <= n; ++i) {
printf("%d:", i);
for (int j = head[i]; j; j = nt[j]) {
printf("%d ", ver[j]);
}
puts("");
}
strcpy(showStr, "wt :"), showArray(wt);
strcpy(showStr, "fa :"), showArray(fa);
strcpy(showStr, "size:"), showArray(size);
strcpy(showStr, "dep :"), showArray(dep);
strcpy(showStr, "son :"), showArray(son);
strcpy(showStr, "dfn :"), showArray(dfn);
strcpy(showStr, "tr :"), showArray(tr);
strcpy(showStr, "top :"), showArray(top);
}
int main() {
n = read(), m = read();
for (int i = 1; i <= n; ++i) wt[i] = read();
for (int i = 1, a, b; i < n; ++i) {
a = read(), b = read(),
add_edge(a, b), add_edge(b, a);
}
dfs1(1, 0);
dfs2(1, 1);
// show();
build(1, 1, n);
for (int i = 1, j, a, b; i <= m; ++i) {
j = read();
// printf("%d ", j);
switch(j) {
case 1:
a = read(), b = read();
// printf("%d %d
", a, b);
change(1, 1, n, dfn[a], dfn[a], b);
break;
case 2:
// printf("%d %d
", a, b);
a = read(), b = read();
change(1, 1, n, dfn[a], dfn[a] + size[a] - 1, b);
break;
case 3:
a = read();
// printf("%d
", a);
ask(1, a);
printf("%lld
", ans);
// write(ans);
puts("");
break;
}
}
return 0;
}