简要题意
给定一棵以1为根的有根树,点可能是黑色或白色,操作如下。
1. 选定一个点x,将x的子树中所有到x的距离为奇数的点的颜色反转。
2. 选定一个点x,将点x的颜色反转。
3. 选定一个点x,询问所有黑点y(包括点x)与点x的lca(最近公共祖先)的和。
果然自己码一码收获挺大的....
首先考虑怎么回答3操作,不妨考虑枚举$lca$
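如果 $lca = x$,可以发现 $x$ 子树内的每个黑点与 $x$ 的 lca 都是 $x$,这部分的贡献为($x$ 子树内的黑点数)$\times x$
否则,根节点到 $x$ 形成了一条链,令其为 $1 = x_0 \to x_1 \to x_2 \to \cdots \to x_k = x$
可以发现,对 $0 \le i < k$,$x_i$ 对答案的贡献为 $(sz[x_i] - sz[x_{i + 1}]) \times x_i$,其中 $sz[v]$ 表示 $v$ 子树内黑点的个数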
由于答案分布在一条链上,考虑使用轻重链剖分
考虑到2操作和1操作
用线段树动态的维护
$sz[i][2][2], col[i]$
分别表示
1.$i$节点子树中,深度为奇 / 偶数,颜色为 黑 / 白的节点数
2.$i$节点的颜色
以及
$sum[2][2]$表示区间内所有$g[i][2][2]$的和
其中$g[i][x][y] = (sz[i][x][y] - sz[son_i][x][y]) \times i$,$son_i$ 为 $i$ 的重儿子(没有重儿子时该项为 $0$),即只统计 $i$ 本身及其轻儿子子树中的点
还有一大堆的细节,直接看代码吧,无法用言语来描述
51nod rank1,hhh
还是有$8kb$......应该还能短点的
// Heavy-light decomposition + lazy segment tree answering, for "opt x":
//   1 x : flip the colour of every node in subtree(x) at odd distance from x
//   2 x : flip the colour of x
//   3 x : print the sum of lca(x, y) over all black nodes y (colour 1 = black)
//
// Layout:
//   * Positions are assigned heavy-child first, so every heavy chain and every
//     subtree occupies a contiguous range of positions.
//   * f[v][d][c] = nodes in subtree(v) with depth parity d and colour c.
//   * The segment tree stores, per position v, g[v][d][c] * v where g[v] is
//     f[v] minus f[heavy[v]] (v itself plus its light subtrees).
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;

/* ------------------------------------------------------------ fast I/O */

// Buffered byte reader over stdin. Returns EOF once input is exhausted; the
// buffer end tracks the number of bytes actually read, so stale data from an
// earlier chunk is never returned.
static inline int gc() {
    static char buf[1 << 16];
    static size_t len = 0, pos = 0;
    if (pos == len) {
        len = fread(buf, 1, sizeof buf, stdin);
        pos = 0;
        if (len == 0) return EOF;
    }
    return (unsigned char)buf[pos++];
}

// Reads the next decimal integer, skipping any non-digit prefix (a '-' seen in
// that prefix makes the result negative). Returns 0 if EOF comes first.
static int read_int() {
    int value = 0, sign = 1;
    int c = gc();
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') sign = -1;
        c = gc();
    }
    while (c >= '0' && c <= '9') {
        value = value * 10 + (c - '0');
        c = gc();
    }
    return value * sign;
}

enum { OUT_CAP = 1 << 16 };
static char out_buf[OUT_CAP];
static size_t out_len = 0;

// Writes the buffered output to stdout and empties the buffer.
static void flush_out() {
    fwrite(out_buf, 1, out_len, stdout);
    out_len = 0;
}

// Appends x in decimal followed by '\n', flushing first if the buffer could
// not hold the longest possible number (sign + 19 digits + newline).
static void write_ll(long long x) {
    if (out_len + 24 > OUT_CAP) flush_out();
    if (x == 0) out_buf[out_len++] = '0';
    if (x < 0) {
        out_buf[out_len++] = '-';
        x = -x;
    }
    char digits[20];
    int k = 0;
    while (x) {
        digits[k++] = (char)('0' + x % 10);
        x /= 10;
    }
    while (k) out_buf[out_len++] = digits[--k];
    out_buf[out_len++] = '\n';
}

/* ------------------------------------------------------------ tree data */

const int MAXN = 200050;    // node array capacity (nodes are 1-based)
const int MAXE = 2 * MAXN;  // each undirected edge is stored in both directions

static int n, m;
static int first_edge[MAXN], edge_to[MAXE], edge_next[MAXE], edge_cnt;
static int fa[MAXN], dep[MAXN], sz[MAXN];
static int heavy[MAXN];     // heavy child; 0 for a leaf
static int top[MAXN];       // head of the heavy chain containing v
static int ord[MAXN];       // ord[v]   : position of node v (1..n)
static int dfn[MAXN];       // dfn[pos] : node at position pos
static int dfs_clock;
// col[v]: colour of v. It is flipped at the leaf level of the segment tree,
// so it is only current after settle(v) has pushed pending tags down to v.
static int col[MAXN];
// f[v][d][c]: see the file header; exact for v only after settle(v).
// g[v][d][c]: f[v][d][c] - f[heavy[v]][d][c]; used only to build the tree,
// afterwards the leaf sums in t[] are the maintained copy.
static int f[MAXN][2][2], g[MAXN][2][2];

static void add_edge(int u, int v) {
    edge_next[++edge_cnt] = first_edge[u];
    first_edge[u] = edge_cnt;
    edge_to[edge_cnt] = v;
}

// First pass: parent, depth, subtree size and heavy child of every node.
// NOTE(review): recursive with depth equal to the tree height (up to n);
// relies on the judge providing a large enough stack.
static void dfs_size(int v) {
    sz[v] = 1;
    for (int e = first_edge[v]; e; e = edge_next[e]) {
        int u = edge_to[e];
        if (u == fa[v]) continue;
        fa[u] = v;
        dep[u] = dep[v] + 1;
        dfs_size(u);
        sz[v] += sz[u];
        if (sz[heavy[v]] < sz[u]) heavy[v] = u;
    }
}

// Second pass: assigns positions, visiting the heavy child first so that each
// heavy chain (head `head`) is contiguous.
static void dfs_chain(int v, int head) {
    top[v] = head;
    dfn[++dfs_clock] = v;
    ord[v] = dfs_clock;
    if (!heavy[v]) return;
    dfs_chain(heavy[v], head);
    for (int e = first_edge[v]; e; e = edge_next[e]) {
        int u = edge_to[e];
        if (u != fa[v] && u != heavy[v]) dfs_chain(u, u);
    }
}

// Fills f and g bottom-up from the initial colours.
static void dp(int v) {
    f[v][dep[v] & 1][col[v]] = 1;
    for (int e = first_edge[v]; e; e = edge_next[e]) {
        int u = edge_to[e];
        if (u == fa[v]) continue;
        dp(u);
        for (int d = 0; d < 2; ++d)
            for (int c = 0; c < 2; ++c) f[v][d][c] += f[u][d][c];
    }
    for (int d = 0; d < 2; ++d)
        for (int c = 0; c < 2; ++c) g[v][d][c] = f[v][d][c] - f[heavy[v]][d][c];
}

/* ------------------------------------------------------------ segment tree */

// Pending tags describe work not yet applied to the children; a node's own
// `s` is always current. Composition order is: flip first, then shift.
struct Seg {
    long long s[2][2];  // sum over the range of g[x][d][c] * x
    int flip[2];        // pending colour flip for depth parity d
    // Pending net move: f[x][d][1] -= shift[d], f[x][d][0] += shift[d].
    // A single net value per parity stays bounded by n, unlike separate
    // per-colour counters whose sum can grow without bound.
    int shift[2];
};
static Seg t[MAXN * 4];

static void pull_up(int o) {
    for (int d = 0; d < 2; ++d)
        for (int c = 0; c < 2; ++c) t[o].s[d][c] = t[o << 1].s[d][c] + t[o << 1 | 1].s[d][c];
}

static void build(int o, int l, int r) {
    if (l == r) {
        int x = dfn[l];
        for (int d = 0; d < 2; ++d)
            for (int c = 0; c < 2; ++c) t[o].s[d][c] = 1LL * g[x][d][c] * x;
        return;
    }
    int mid = (l + r) >> 1;
    build(o << 1, l, mid);
    build(o << 1 | 1, mid + 1, r);
    pull_up(o);
}

// Flips colour of depth-parity-d nodes for every position of segment o=[l,r].
// At a leaf the node's f and col are updated directly; at an internal node the
// flip is recorded and the older pending shift is negated so it stays valid
// when applied after the flip.
static void apply_flip(int o, int l, int r, int d) {
    swap(t[o].s[d][0], t[o].s[d][1]);
    if (l == r) {
        int x = dfn[l];
        if ((dep[x] & 1) == d) col[x] ^= 1;
        swap(f[x][d][0], f[x][d][1]);
        return;
    }
    t[o].flip[d] ^= 1;
    t[o].shift[d] = -t[o].shift[d];
}

// Moves v parity-d nodes from colour c to colour c^1 in f[] of every position
// of segment o=[l,r]. Sums `s` are untouched: used for chain nodes whose change
// lies in the heavy subtree, where g does not change.
static void apply_shift(int o, int l, int r, int d, int c, int v) {
    if (l == r) {
        int x = dfn[l];
        f[x][d][c] -= v;
        f[x][d][c ^ 1] += v;
        return;
    }
    t[o].shift[d] += (c == 1) ? v : -v;
}

static void push_down(int o, int l, int r) {
    int mid = (l + r) >> 1, lc = o << 1, rc = o << 1 | 1;
    for (int d = 0; d < 2; ++d) {
        if (t[o].flip[d]) {
            apply_flip(lc, l, mid, d);
            apply_flip(rc, mid + 1, r, d);
            t[o].flip[d] = 0;
        }
        if (t[o].shift[d]) {
            apply_shift(lc, l, mid, d, 1, t[o].shift[d]);
            apply_shift(rc, mid + 1, r, d, 1, t[o].shift[d]);
            t[o].shift[d] = 0;
        }
    }
}

// Range flip of parity d over positions [ml, mr] (no-op if empty).
static void range_flip(int o, int l, int r, int ml, int mr, int d) {
    if (ml > mr || ml > r || mr < l) return;
    if (ml <= l && r <= mr) {
        apply_flip(o, l, r, d);
        return;
    }
    push_down(o, l, r);
    int mid = (l + r) >> 1;
    range_flip(o << 1, l, mid, ml, mr, d);
    range_flip(o << 1 | 1, mid + 1, r, ml, mr, d);
    pull_up(o);
}

// Range version of apply_shift over positions [ml, mr] (no-op if empty).
static void range_shift(int o, int l, int r, int ml, int mr, int d, int c, int v) {
    if (ml > mr || ml > r || mr < l) return;
    if (ml <= l && r <= mr) {
        apply_shift(o, l, r, d, c, v);
        return;
    }
    push_down(o, l, r);
    int mid = (l + r) >> 1;
    range_shift(o << 1, l, mid, ml, mr, d, c, v);
    range_shift(o << 1 | 1, mid + 1, r, ml, mr, d, c, v);
    pull_up(o);
}

// Moves v parity-d nodes of node p from colour c to c^1 in both f[p] and the
// leaf sum (the change lies in p itself or one of its light subtrees, so g[p]
// changes too).
static void point_shift(int o, int l, int r, int p, int d, int c, int v) {
    if (l == r) {
        f[p][d][c ^ 1] += v;
        f[p][d][c] -= v;
        t[o].s[d][c ^ 1] += 1LL * v * p;
        t[o].s[d][c] -= 1LL * v * p;
        return;
    }
    push_down(o, l, r);
    int mid = (l + r) >> 1;
    if (ord[p] <= mid)
        point_shift(o << 1, l, mid, p, d, c, v);
    else
        point_shift(o << 1 | 1, mid + 1, r, p, d, c, v);
    pull_up(o);
}

// Pushes every pending tag on the root-to-leaf path of node x, after which
// f[x][*][*] and col[x] are exact.
static void settle(int x) {
    int o = 1, l = 1, r = n, p = ord[x];
    while (l < r) {
        push_down(o, l, r);
        int mid = (l + r) >> 1;
        if (p <= mid) {
            o = o << 1;
            r = mid;
        } else {
            o = o << 1 | 1;
            l = mid + 1;
        }
    }
}

// Number of black nodes in subtree(x).
static int black_count(int x) {
    settle(x);
    return f[x][0][1] + f[x][1][1];
}

// Sum of g[x][*][black] * x over positions [ml, mr] (0 if empty).
static long long range_black_sum(int o, int l, int r, int ml, int mr) {
    if (ml > mr || ml > r || mr < l) return 0;
    if (ml <= l && r <= mr) return t[o].s[0][1] + t[o].s[1][1];
    push_down(o, l, r);
    int mid = (l + r) >> 1;
    return range_black_sum(o << 1, l, mid, ml, mr) +
           range_black_sum(o << 1 | 1, mid + 1, r, ml, mr);
}

// Applies "v parity-d nodes move from colour c to c^1 inside subtree(x)" to
// every proper ancestor of x: chain nodes above the entry point only change f
// (range_shift); the node reached through a light edge also changes g
// (point_shift).
static void update_ancestors(int x, int d, int c, int v) {
    int head = top[x];
    range_shift(1, 1, n, ord[head], ord[x] - 1, d, c, v);
    for (int j = fa[head]; j; j = fa[top[j]]) {
        point_shift(1, 1, n, j, d, c, v);
        range_shift(1, 1, n, ord[top[j]], ord[j] - 1, d, c, v);
    }
}

/* ------------------------------------------------------------ operations */

// Operation 1: flip nodes of subtree(x) at odd distance from x, i.e. those
// whose depth parity is (dep[x] + 1) & 1.
static void flip_odd_descendants(int x) {
    int d = (dep[x] + 1) & 1;
    settle(x);
    int delta = f[x][d][1] - f[x][d][0];  // black minus white, before the flip
    range_flip(1, 1, n, ord[x], ord[x] + sz[x] - 1, d);
    // After the flip subtree(x) has `delta` fewer black parity-d nodes.
    update_ancestors(x, d, 1, delta);
}

// Operation 2: flip the colour of x alone.
static void flip_single(int x) {
    settle(x);
    int d = dep[x] & 1, c = col[x];
    point_shift(1, 1, n, x, d, c, 1);
    update_ancestors(x, d, c, 1);
    col[x] ^= 1;  // no flip tag can reach x's leaf during the updates above
}

// Operation 3: sum of lca(x, y) over black y.
// Black nodes in subtree(x) contribute x each. Walking up the heavy chains,
// every node k strictly above the current node j on j's chain contributes
// k * (black nodes counted in g[k]), read directly from the tree; after a
// light edge, the new j contributes j * (black(subtree j) - black(subtree of
// the chain head we came from)).
static long long query(int x) {
    long long ans = 1LL * black_count(x) * x;
    for (int j = x, below = x; j; below = top[j], j = fa[top[j]]) {
        if (j != below) ans += 1LL * (black_count(j) - black_count(below)) * j;
        ans += range_black_sum(1, 1, n, ord[top[j]], ord[j] - 1);
    }
    return ans;
}

int main() {
    n = read_int();
    m = read_int();
    if (n <= 0) {  // build() would not terminate on an empty range
        flush_out();
        return 0;
    }
    // Colours index the f/g arrays, so normalise them to 0/1.
    for (int i = 1; i <= n; ++i) col[i] = read_int() != 0;
    for (int i = 1; i < n; ++i) {
        int u = read_int(), v = read_int();
        add_edge(u, v);
        add_edge(v, u);
    }
    dfs_size(1);
    dfs_chain(1, 1);
    dp(1);
    build(1, 1, n);
    for (int i = 1; i <= m; ++i) {
        int opt = read_int(), x = read_int();
        switch (opt) {
        case 1:
            flip_odd_descendants(x);
            break;
        case 2:
            flip_single(x);
            break;
        case 3:
            write_ll(query(x));
            break;
        default:
            break;
        }
    }
    flush_out();
    return 0;
}