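// Heavy-light decomposition: any tree path is split into O(log n) heavy chains,
// and a segment tree over the chain-ordered indices answers the queries.
// In dfs2, visit the heavy child first; when visiting the light children, skip
// the heavy child that was already handled, and mind the newly assigned indices.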
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
// Capacity (100100) used to size the fixed global arrays of this file.
const int N = 100000 + 100;
// One directed edge stored in a linked adjacency list.
struct Edge {
    int v;    // vertex this edge leads to
    int nxt;  // index of the next edge leaving the same vertex; 0 ends the list
};

// Edge pool with room for 2 * N directed edges.
Edge e[N << 1];
// One segment-tree node covering the index interval [l, r].
struct Node {
    int l;    // left end of the covered interval (inclusive)
    int r;    // right end of the covered interval (inclusive)
    int tag;  // pending "add to every element" value not yet pushed to children
    int sum;  // sum over the interval, kept modulo `mod` by the update routines
};

// Segment-tree storage, indexed heap-style (children of u are 2u and 2u + 1).
Node tr[N << 2];
// Input parameters, read in main().
int n;          // number of vertices
int m;          // number of operations
int cnt;        // number of edge slots used in e[]
int head[N];    // head[u]: index of the first edge leaving u, 0 if none
int weight[N];  // weight[u]: initial value of vertex u
int wgt[N];     // wgt[i]: initial value of the vertex whose DFS index is i
int rt;         // root vertex
int mod;        // modulus applied to all sums

// Per-vertex data filled in by dfs1 / dfs2.
int siz[N];     // subtree size
int fa[N];      // parent vertex (0 for the root)
int dep[N];     // depth, root has depth 1
int son[N];     // child chosen as the heavy child, 0 for leaves
int top[N];     // topmost vertex of the chain containing the vertex
int id[N];      // DFS index assigned by dfs2 (1-based)
int idx;        // last DFS index handed out
// Prepends the directed edge u -> v to u's adjacency list.
void AddEdge(int u, int v) {
    const int slot = ++cnt;  // edge slots are 1-based so that 0 can terminate a list
    e[slot].v = v;
    e[slot].nxt = head[u];
    head[u] = slot;
}
// First pass of the decomposition.
// For every vertex in the subtree of `u` it records the parent (fa), the depth
// (dep, parent's depth + 1), the subtree size (siz) and the heavy child (son),
// i.e. the child with the largest subtree.
//   u  - vertex being visited
//   ff - parent of u (0 for the root)
void dfs1(int u, int ff) {
    fa[u] = ff;
    dep[u] = dep[ff] + 1;
    siz[u] = 1;
    int maxson = 0;  // largest child subtree size seen so far
    for (int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].v;
        if (v == fa[u]) continue;  // do not walk back up the tree
        dfs1(v, u);
        siz[u] += siz[v];
        if (siz[v] > maxson) {
            // Remember the new maximum; without this every child would win the
            // comparison and the last child, not the heaviest, would be chosen.
            maxson = siz[v];
            son[u] = v;
        }
    }
}
// Second pass of the decomposition.
// Assigns consecutive DFS indices, records the chain top of every vertex and
// copies each vertex weight to its DFS position in wgt[].
//   u    - vertex being visited
//   topf - top vertex of the chain that u belongs to
void dfs2(int u, int topf) {
    const int pos = ++idx;
    id[u] = pos;
    top[u] = topf;
    wgt[pos] = weight[u];

    // The heavy child goes first so that a chain occupies contiguous indices.
    const int heavy = son[u];
    if (heavy != 0) dfs2(heavy, topf);

    // Every remaining child starts a chain of its own.
    for (int k = head[u]; k != 0; k = e[k].nxt) {
        const int child = e[k].v;
        if (child != fa[u] && child != heavy) dfs2(child, child);
    }
}
// Recomputes the sum of node `u` from its two children, modulo `mod`.
void pushup(int u) {
    // Add in 64 bits: two child sums together may not fit in an int.
    const long long total =
        static_cast<long long>(tr[u << 1].sum) + tr[u << 1 | 1].sum;
    tr[u].sum = static_cast<int>(total % mod);
}
void build(int u, int l, int r) {
tr[u].l = l, tr[u].r = r;
if( l == r) {
tr[u].sum = wgt[l];
return;
}
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void pushdown(int u) {
if( tr[u].tag) {
Node & ls = tr[u << 1], &rs = tr[u << 1 | 1];
ls.sum = (ls.sum + (ls.r - ls.l + 1) * tr[u].tag) % mod;
ls.tag = (tr[u].tag + ls.tag) % mod;
rs.sum = (rs.sum + (rs.r - rs.l + 1) * tr[u].tag) % mod;
rs.tag = (tr[u].tag + rs.tag) % mod;
tr[u].tag = 0;
}
}
// Adds `v` to every position in [l, r] (modulo `mod`).
//   u    - current segment-tree node (call with 1 for the root)
//   l, r - inclusive range of positions to update
//   v    - value to add
void change(int u, int l, int r, int v) {
    if (l <= tr[u].l && r >= tr[u].r) {
        // Node fully covered: update it lazily. Both expressions are evaluated
        // in 64 bits because v * length and tag + v can exceed the int range.
        const long long len = tr[u].r - tr[u].l + 1;
        tr[u].tag = static_cast<int>((static_cast<long long>(tr[u].tag) + v) % mod);
        tr[u].sum = static_cast<int>((tr[u].sum + v * len) % mod);
        return;
    }
    pushdown(u);  // children must be up to date before descending
    int mid = (tr[u].l + tr[u].r) >> 1;
    if (l <= mid) change(u << 1, l, r, v);
    if (r > mid) change(u << 1 | 1, l, r, v);
    pushup(u);
}
int ask(int u, int l, int r) {
if( l <= tr[u].l && r >= tr[u].r) {
return tr[u].sum;
}
pushdown(u);
int res = 0;
int mid = tr[u].l + tr[u].r >> 1;
if( l <= mid) res += ask(u << 1, l, r);
if( r > mid) res = (res + ask(u << 1 | 1, l, r));
return res % mod;
}
// Returns the sum of the values on the tree path between x and y, modulo `mod`.
int query(int x, int y) {
    int total = 0;

    // Climb chain by chain until both endpoints share a chain; the endpoint
    // whose chain top is deeper is always the one that moves up.
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
        const int chainTop = top[x];
        total = (total + ask(1, id[chainTop], id[x])) % mod;
        x = fa[chainTop];
    }

    // Same chain now: the shallower endpoint carries the smaller DFS index.
    const bool xIsShallower = dep[x] < dep[y];
    const int upper = xIsShallower ? x : y;
    const int lower = xIsShallower ? y : x;
    return (total + ask(1, id[upper], id[lower])) % mod;
}
// Adds v to every vertex on the tree path between x and y.
void modify(int x, int y, int v) {
    // Climb chain by chain until both endpoints share a chain; the endpoint
    // whose chain top is deeper is always the one that moves up.
    for (; top[x] != top[y]; x = fa[top[x]]) {
        if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
        change(1, id[top[x]], id[x], v);
    }

    // Same chain now: the shallower endpoint carries the smaller DFS index.
    const bool xIsShallower = dep[x] < dep[y];
    const int upper = xIsShallower ? x : y;
    const int lower = xIsShallower ? y : x;
    change(1, id[upper], id[lower], v);
}
int main()
{
scanf("%d%d%d%d", &n, &m, &rt, &mod);
for(int i = 1; i <= n; i ++)
scanf("%d", &weight[i]);
for(int i = 1; i < n; i ++) {
int u, v;
scanf("%d%d", &u, &v);
AddEdge(u, v), AddEdge(v, u);
}
dfs1(rt, 0);
dfs2(rt, rt);
build(1, 1, n);
while(m --) {
int opt, x, y, z;
scanf("%d", &opt);
if( opt == 1) {
scanf("%d%d%d", &x, &y, &z);
modify(x, y, z);
}
else if( opt == 2) {
scanf("%d%d", &x, &y);
printf("%d
", query(x, y));
}
else if( opt == 3) {
scanf("%d%d", &x, &z);
change(1, id[x], id[x] + siz[x] - 1, z);
}
else {
scanf("%d", &x);
printf("%d
", ask(1, id[x], id[x] + siz[x] - 1));
}
}
return 0;
}