题意:
给出N个节点的一棵树,每个节点有一个权值。
有两种操作:
1) 0 i j, 询问从节点i到节点j的路径上所有节点的权值之和。
2) 1 i v, 把节点i的权值改变成v。
思路:
树链剖分(权值在点上)+线段树单点更新+线段树区间求和查询
/*
 * Path-sum queries on a tree whose weights live on the nodes, with point
 * updates of a single node's weight.
 *
 * Input, as read by main(): T test cases. Each case gives n, then n node
 * weights, then n-1 edges, then q followed by q operations "op x y":
 *   op == 0 : print the sum of the node weights on the path x -> y
 *   op == 1 : set the weight of node x to y
 * Node ids read from the input are shifted by +1, so nodes are 1..n here.
 *
 * Approach: heavy-light decomposition maps every heavy chain to a contiguous
 * range of positions; a sum segment tree over those positions supports point
 * assignment and range-sum queries.
 */
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
using namespace std;

// Array capacity; main() accepts 1 <= n < maxn.
const int maxn = 30010;

int n, q;
int siz[maxn];         // subtree size (set by dfs1)
int top[maxn];         // first node of the chain containing the node (set by dfs2)
int fa[maxn];          // parent; -1 for the root (set by dfs1)
int son[maxn];         // heavy child; -1 when the node has no child
int dep[maxn];         // depth; the root has depth 1 (set by dfs1)
int w[maxn];           // node -> segment-tree position (set by dfs2)
int fw[maxn];          // segment-tree position -> node (inverse of w)
int A[maxn];           // initial node weights, indexed by 1-based node id
vector<int> mp[maxn];  // adjacency lists
int pos;               // number of positions handed out by dfs2 so far
int sum[maxn << 2];    // segment tree: sum[rt] is the total of rt's range

// Recomputes node rt from its two children.
void PushUp(int rt)
{
    sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
}

// Builds the segment tree for positions [l, r] rooted at rt; the leaf for
// position p takes the initial weight of node fw[p].
void build(int l, int r, int rt)
{
    if (l == r)
    {
        sum[rt] = A[fw[l]];
        return;
    }
    int m = (l + r) >> 1;
    build(l, m, rt << 1);
    build(m + 1, r, rt << 1 | 1);
    PushUp(rt);
}

// Assigns val to position p; [l, r] is the range covered by node rt.
void update(int p, int val, int l, int r, int rt)
{
    if (l == r)
    {
        sum[rt] = val;
        return;
    }
    int m = (l + r) >> 1;
    if (p <= m)
        update(p, val, l, m, rt << 1);
    else
        update(p, val, m + 1, r, rt << 1 | 1);
    PushUp(rt);
}

// Returns the sum over the part of [L, R] that lies inside [l, r], the range
// covered by node rt.
int query_sum(int L, int R, int l, int r, int rt)
{
    if (L <= l && R >= r)
    {
        return sum[rt];
    }
    int m = (l + r) >> 1;
    int ret = 0;
    if (L <= m) ret += query_sum(L, R, l, m, rt << 1);
    if (R > m) ret += query_sum(L, R, m + 1, r, rt << 1 | 1);
    return ret;
}

// First pass: records size, depth and parent of every node in u's subtree and
// picks the heavy child (the child with the largest subtree). Expects son[]
// to be filled with -1 beforehand. Returns the size of u's subtree.
int dfs1(int u, int pre, int deep)
{
    siz[u] = 1;
    dep[u] = deep;
    fa[u] = pre;
    int mmax = 0;
    for (size_t i = 0; i < mp[u].size(); ++i)
    {
        int v = mp[u][i];
        if (v == pre) continue;
        int temp = dfs1(v, u, deep + 1);
        siz[u] += temp;
        if (son[u] == -1 || temp >= mmax)
        {
            son[u] = v;
            mmax = temp;
        }
    }
    return siz[u];
}

// Second pass: gives u the next free position and records val as the top of
// its chain. The heavy child is visited first so that a chain occupies
// consecutive positions, top to bottom; every other child starts a new chain.
void dfs2(int u, int val)
{
    top[u] = val;
    w[u] = ++pos;
    fw[pos] = u;
    if (son[u] == -1) return;

    dfs2(son[u], val);
    for (size_t i = 0; i < mp[u].size(); ++i)
    {
        int v = mp[u][i];
        if (v != son[u] && v != fa[u]) dfs2(v, v);
    }
}

// Returns the sum of the node weights on the path between u and v.
int find_sum(int u, int v)
{
    int f1 = top[u], f2 = top[v];
    int temp = 0;
    while (f1 != f2)
    {
        // Always lift the side whose chain top is deeper.
        if (dep[f1] < dep[f2])
        {
            swap(f1, f2);
            swap(u, v);
        }
        temp += query_sum(w[f1], w[u], 1, pos, 1);
        u = fa[f1];
        f1 = top[u];
    }
    // Same chain now: the shallower node has the smaller position.
    if (dep[u] > dep[v]) swap(u, v);
    temp += query_sum(w[u], w[v], 1, pos, 1);
    return temp;
}

// Reads the test cases and answers the queries. For every case a
// "Case k:" header line is printed, followed by one line per type-0 query.
// Stops quietly if the input ends early or n does not fit the arrays.
int main()
{
    int T;
    if (scanf("%d", &T) != 1) return 0;
    int cast = 0;
    while (T--)
    {
        if (scanf("%d", &n) != 1 || n < 1 || n >= maxn) return 0;
        cast++;
        for (int i = 1; i <= n; i++) mp[i].clear();
        pos = 0;
        memset(son, -1, sizeof(son));
        for (int i = 1; i <= n; i++)
        {
            if (scanf("%d", &A[i]) != 1) return 0;
        }
        for (int i = 1; i <= n - 1; i++)
        {
            int a, b;
            if (scanf("%d%d", &a, &b) != 2) return 0;
            a++;
            b++;
            mp[a].push_back(b);
            mp[b].push_back(a);
        }
        dfs1(1, -1, 1);
        dfs2(1, 1);
        build(1, pos, 1);

        if (scanf("%d", &q) != 1) return 0;
        printf("Case %d:\n", cast);
        while (q--)
        {
            int op, u, v;
            if (scanf("%d%d%d", &op, &u, &v) != 3) return 0;
            if (op == 1)
                update(w[u + 1], v, 1, pos, 1);
            else if (op == 0)
                printf("%d\n", find_sum(u + 1, v + 1));
        }
    }
    return 0;
}