题目描述:给出一棵 (n) 个点的树,点有颜色 (C_i),长度为 (m) 的数组 (V) 和长度为 (n) 的数组 (W)。有两种操作:
-
将 (C_x) 修改为 (y)。
-
求 (u) 到 (v) 的链的 (sumlimits_{i=1}^msumlimits_{j=1}^{cnt_i}W_j),其中 (cnt_i) 表示颜色 (i) 的出现次数。
数据范围:(n,mle 10^5,1le C_ile m),时限6s(洛谷)或8s(UOJ)。
这是树上带修莫队的模板题。
首先我们看树分块怎么做,所以要来先做这道题。
直接讲做法了:我们是尽可能在底层把大小 (ge B) 的联通块作为一块,剩下的扔给父亲合并。就是要开一个stack,维护当前还没有被分块的点。不停递归儿子,一旦已经有一个 (ge B) 的连通块了,就把它们作为一块,设首都为 (x)(当前dfs的点)。最后把 (x) 放进栈中。最后递归完还要把栈中剩下的点放入最后一个块,并把首都设为 (1)。
inline void dfs(int x, int f){
int t = top;
for(Rint i = head[x];i;i = nxt[i])
if(to[i] != f){
dfs(to[i], x);
if(top >= t + B){
rt[++ num] = x;
while(top > t) in[stk[top --]] = num;
}
}
stk[++ top] = x;
}
int main(){
// ...
dfs(1, 0);
if(!num) num = 1;
rt[num] = 1;
while(top) in[stk[top --]] = num;
// ...
}
我们知道没有合并,即大小 (<B),合并之后 (<2B)。就算最后一个连通块也至多有 (<3B),所以是比较均匀的。
然后我们看看如何从链 ((u,v)) 变为 ((u',v')) 并且时间复杂度为 (O( ext{len}(u,u')+ ext{len}(v,v')))。
首先我们改为维护 ((u,v)) 中抠掉 (lca) 的答案,(cnt) 和每个点是否在里面的 (vis)。设 (L(u,v)) 为 (u) 到 (v) 这条链上抠掉 (lca) 的点集,(oplus) 为集合对称差((Aoplus B=(Acup B)-(Acap B))),(S(u)) 为 (1) 到 (u) 的这条链的点集。则 (L(u,v)=S(u)oplus S(v)),且集合对称差肯定是有交换、结合律的。
[egin{aligned}
L(u',v')&=L(u',v')oplus L(u,v)oplus L(u,v) \
&=S(u')oplus S(v')oplus S(u)oplus S(v)oplus S(u) oplus S(v) \
&=(S(u')oplus S(u))oplus(S(v)oplus S(v'))oplus(S(u)oplus S(v)) \
&=L(u,u')oplus L(v,v')oplus L(u,v)
end{aligned}
]
于是就直接是将 (L(u,u')) 和 (L(v,v')) 全部 (vis) 取反就是 (L(u',v')),然后把 (lca) 取反就是 (u') 到 (v') 这条链。
至于带修莫队怎么做,就去看这道题。取 (B) 比 (n^frac{2}{3}) 少一点点就可以,时间复杂度 (O(n^frac{5}{3}+nlog n))。
code
```cpp
#include
#define Rint register int
using namespace std;
typedef long long LL;
const int N = 100003;
int n, m, B, q, ql, qr, qnow, qnum, cnum, V[N], W[N], C[N], head[N], to[N << 1], nxt[N << 1], dfn[N], cnt[N];
bool vis[N];
LL ans[N], qans;
inline void add(int a, int b){
static int cnt = 0;
to[++ cnt] = b; nxt[cnt] = head[a]; head[a] = cnt;
}
int dep[N], top[N], fa[N], siz[N], wson[N], stk[N], tp, bnum, bel[N];
inline void dfs1(int x){
int tmp = tp;
siz[x] = 1;
for(Rint i = head[x];i;i = nxt[i])
if(to[i] != fa[x]){
fa[to[i]] = x; dep[to[i]] = dep[x] + 1;
dfs1(to[i]);
siz[x] += siz[to[i]];
if(siz[to[i]] > siz[wson[x]]) wson[x] = to[i];
if(tp >= tmp + B){
++ bnum;
while(tp > tmp) bel[stk[tp --]] = bnum;
}
}
stk[++ tp] = x;
}
inline void dfs2(int x, int topf){
top[x] = topf;
if(wson[x]) dfs2(wson[x], topf);
for(Rint i = head[x];i;i = nxt[i])
if(to[i] != wson[x] && to[i] != fa[x])
dfs2(to[i], to[i]);
}
inline int lca(int u, int v){
while(top[u] != top[v]){
if(dep[top[u]] < dep[top[v]]) swap(u, v);
u = fa[top[u]];
}
return dep[u] < dep[v] ? u : v;
}
struct Query {
int u, v, id, tim;
inline bool operator < (const Query &o) const {
if(bel[u] != bel[o.u]) return bel[u] < bel[o.u];
if(bel[v] != bel[o.v]) return bel[v] < bel[o.v];
return tim < o.tim;
}
} que[N];
struct Change {
int u, val;
} cha[N];
inline void work(int x){
if(vis[x]) qans -= (LL) V[C[x]] * W[cnt[C[x]]], -- cnt[C[x]];
else ++ cnt[C[x]], qans += (LL) V[C[x]] * W[cnt[C[x]]];
vis[x] ^= 1;
}
inline void workpath(int u, int v){
if(dep[u] < dep[v]) swap(u, v);
while(dep[u] > dep[v]){work(u); u = fa[u];}
while(u != v){work(u); u = fa[u]; work(v); v = fa[v];}
}
inline void change(int i){
int u = cha[i].u;
if(vis[u]){
work(u); swap(C[u], cha[i].val); work(u);
} else swap(C[u], cha[i].val);
}
int main(){
scanf("%d%d%d", &n, &m, &q); B = pow(n, 2.0 / 3);
for(Rint i = 1;i <= m;i ++) scanf("%d", V + i);
for(Rint i = 1;i <= n;i ++) scanf("%d", W + i);
for(Rint i = 1;i < n;i ++){
int a, b;
scanf("%d%d", &a, &b);
add(a, b); add(b, a);
}
dfs1(1);
if(!bnum) bnum = 1;
while(tp) bel[stk[tp --]] = bnum;
dfs2(1, 1);
for(Rint i = 1;i <= n;i ++) scanf("%d", C + i);
for(Rint i = 1;i <= q;i ++){
int opt, x, y;
scanf("%d%d%d", &opt, &x, &y);
if(opt == 0) cha[++ cnum] = (Change){x, y};
else ++ qnum, que[qnum] = (Query){x, y, qnum, cnum};
}
sort(que + 1, que + qnum + 1);
ql = qr = 1; qnow = 0;
for(Rint i = 1;i <= qnum;i ++){
int tl = que[i].u, tr = que[i].v;
workpath(ql, tl); workpath(qr, tr);
ql = tl; qr = tr;
while(qnow < que[i].tim) change(++ qnow);
while(qnow > que[i].tim) change(qnow --);
int LCA = lca(tl, tr);
work(LCA); ans[que[i].id] = qans; work(LCA);
}
for(Rint i = 1;i <= qnum;i ++) printf("%lld
", ans[i]);
}
```