题目链接:洛谷
又来做 Ynoi 里的水题了……
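换根是一个套路:先以 $1$ 为根做一遍 DFS,画一画就能知道以 $rt$ 为根时 $x$ 的子树是什么:若 $rt=x$,就是整棵树 $[1,n]$;若 $rt$ 在以 $1$ 为根时 $x$ 的子树内(且 $rt\ne x$),设 $c$ 为 $x$ 通往 $rt$ 的那个儿子,则是整棵树去掉 $c$ 的子树,可以拆分为至多 $2$ 个 DFS 序连续段;否则仍是原来的子树区间。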
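然后,若要计算 DFS 序区间 $[l_1,r_1]$ 与 $[l_2,r_2]$ 之间的答案,设 $F(p,q)$ 为 $[1,p]$ 与 $[1,q]$ 之间的答案,做一次二维差分即可:$\text{ans}=F(r_1,r_2)-F(l_1-1,r_2)-F(r_1,l_2-1)+F(l_1-1,l_2-1)$,从而都转化成 $[1,\cdot]$ 与 $[1,\cdot]$ 的答案。把每个 $F(p,q)$ 当作一个以 $(p,q)$ 为端点的询问做莫队就可以过了。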
注意一点:要去除那些不必要的询问,即某一维前缀为空的项(例如 $r_1=0$ 或 $r_2=0$,以及 $l_1=1$、$l_2=1$ 时出现的 $l_1-1=0$、$l_2-1=0$),这些项的答案恒为 $0$。这样可以去掉大量询问,否则会 TLE 3 个点。
#include<bits/stdc++.h>
// Loop-counter shorthand. `register` was removed from the language in C++17
// (clang rejects it outright), so the macro now expands to a plain int; every
// existing `Rint` use keeps compiling unchanged.
#define Rint int
using namespace std;
typedef long long LL;
// Array capacity for vertices (assumes n < N; presumably n <= 1e5 — confirm
// against the problem statement).
const int N = 100003;
// n: vertices, m: operations, tot: number of Mo queries generated,
// len: number of distinct colours, blo: Mo block size, rt: current root.
int n, m, tot, len, blo, rt;
// Adjacency list: head[v] -> edge id, nxt[e] -> next edge, to[e] -> endpoint.
// Sized for two directed edges per tree edge.
int head[N], to[N << 1], nxt[N << 1];
// a[i]: colour of vertex i (compressed in place to 1..len); b: sorted colours.
int a[N], b[N];
// Heavy-light decomposition rooted at vertex 1: dfn/pre are the dfs order and
// its inverse, then depth, parent, subtree size, heavy child, chain top, and
// the running dfs timestamp.
int dfn[N], pre[N], dep[N], fa[N], siz[N], wson[N], top[N], tim;
// ans[k]: answer of the k-th counting operation (indices 1..pos).
LL ans[N * 5];
// Pushes the directed edge a -> b onto the front of a's adjacency list
// (head/nxt/to). Edge ids start at 1, so id 0 terminates a list.
inline void add(int a, int b){
    static int edges = 0;
    const int e = ++ edges;
    to[e] = b;
    nxt[e] = head[a];
    head[a] = e;
}
inline void dfs1(int x){
siz[x] = 1;
for(Rint i = head[x];i;i = nxt[i])
if(to[i] != fa[x]){
dep[to[i]] = dep[x] + 1; fa[to[i]] = x;
dfs1(to[i]);
siz[x] += siz[to[i]];
if(siz[to[i]] > siz[wson[x]]) wson[x] = to[i];
}
}
inline void dfs2(int x, int tp){
top[x] = tp; dfn[x] = ++ tim; pre[tim] = x;
if(wson[x]){
dfs2(wson[x], tp);
for(Rint i = head[x];i;i = nxt[i])
if(to[i] != fa[x] && to[i] != wson[x])
dfs2(to[i], to[i]);
}
}
// Returns the child of y on the path from y down to x. The caller only invokes
// it when x lies strictly inside y's subtree (rooted at 1). Walks up heavy
// chains from x until the chain containing, or hanging directly off, y is met.
inline int calc(int x, int y){
    for(; dep[x] > dep[y]; x = fa[top[x]]){
        // y sits on x's current heavy chain: the next vertex down is its heavy child.
        if(dep[top[x]] <= dep[y]) return wson[y];
        // x's chain is attached directly below y: its top is the wanted child.
        if(fa[top[x]] == y) return top[x];
    }
    return x;
}
struct Query {
int l, r, id;
bool flag;
inline bool operator < (const Query &o) const {
if(l / blo != o.l / blo) return l / blo < o.l / blo;
if((l / blo) & 1) return r > o.r;
return r < o.r;
}
} que[N * 80];
// Expresses the answer for dfs ranges [l1, r1] x [l2, r2] as prefix terms via
// 2-D inclusion-exclusion:
//   F(r1, r2) - F(l1-1, r2) - F(r1, l2-1) + F(l1-1, l2-1),
// appending one Query per term and dropping terms whose prefix is empty.
inline void add(int l1, int r1, int l2, int r2, int id){
    const int px[2] = {r1, l1 - 1};
    const int py[2] = {r2, l2 - 1};
    const bool usex[2] = {r1 != 0, l1 > 1};
    const bool usey[2] = {r2 != 0, l2 > 1};
    for(int j = 0; j < 2; ++ j)
        for(int i = 0; i < 2; ++ i){
            if(!usex[i] || !usey[j]) continue;
            ++ tot;
            que[tot].l = px[i];
            que[tot].r = py[j];
            que[tot].id = id;
            que[tot].flag = (i ^ j) != 0;   // mixed terms carry a minus sign
        }
}
// Mo's algorithm window: two dfs-order prefixes [1, ql] and [1, qr].
// cnt1[c] / cnt2[c] count colour c in each prefix, and
// qans = sum over c of cnt1[c] * cnt2[c].
int ql = 0, qr = 0;
int cnt1[N], cnt2[N];
LL qans = 0;
// Extend / shrink the first prefix by one vertex of colour x.
inline void add1(int x){ qans += cnt2[x]; cnt1[x] += 1; }
inline void del1(int x){ qans -= cnt2[x]; cnt1[x] -= 1; }
// Extend / shrink the second prefix by one vertex of colour x.
inline void add2(int x){ qans += cnt1[x]; cnt2[x] += 1; }
inline void del2(int x){ qans -= cnt1[x]; cnt2[x] -= 1; }
int main(){
scanf("%d%d", &n, &m); blo = sqrt(n);
for(Rint i = 1;i <= n;i ++) scanf("%d", a + i), b[i] = a[i];
sort(b + 1, b + n + 1);
len = unique(b + 1, b + n + 1) - b - 1;
for(Rint i = 1;i <= n;i ++) a[i] = lower_bound(b + 1, b + len + 1, a[i]) - b;
for(Rint i = 1;i < n;i ++){
int a, b; scanf("%d%d", &a, &b); add(a, b); add(b, a);
}
dfs1(1); dfs2(1, 1); rt = 1;
int pos = 0;
while(m --){
int opt, x, y;
scanf("%d", &opt);
if(opt == 1) scanf("%d", &rt);
else {
scanf("%d%d", &x, &y); ++ pos;
int l1[2], r1[2], l2[2], r2[2], cnt1 = 0, cnt2 = 0;
if(rt == x) l1[0] = 1, r1[0] = n, cnt1 = 1;
else if(dfn[rt] > dfn[x] && dfn[rt] < dfn[x] + siz[x]){
int tmp = calc(rt, x);
l1[0] = 1; r1[0] = dfn[tmp] - 1; l1[1] = dfn[tmp] + siz[tmp]; r1[1] = n; cnt1 = 2;
} else l1[0] = dfn[x], r1[0] = dfn[x] + siz[x] - 1, cnt1 = 1;
if(rt == y) l2[0] = 1, r2[0] = n, cnt2 = 1;
else if(dfn[rt] > dfn[y] && dfn[rt] < dfn[y] + siz[y]){
int tmp = calc(rt, y);
l2[0] = 1; r2[0] = dfn[tmp] - 1; l2[1] = dfn[tmp] + siz[tmp]; r2[1] = n; cnt2 = 2;
} else l2[0] = dfn[y], r2[0] = dfn[y] + siz[y] - 1, cnt2 = 1;
for(Rint i = 0;i < cnt1;i ++)
for(Rint j = 0;j < cnt2;j ++)
add(l1[i], r1[i], l2[j], r2[j], pos);
}
}
for(Rint i = 1;i <= tot;i ++)
if(que[i].l > que[i].r) swap(que[i].l, que[i].r);
sort(que + 1, que + tot + 1);
for(Rint i = 1;i <= tot;i ++){
while(ql < que[i].l) add1(a[pre[++ ql]]);
while(ql > que[i].l) del1(a[pre[ql --]]);
while(qr < que[i].r) add2(a[pre[++ qr]]);
while(qr > que[i].r) del2(a[pre[qr --]]);
if(que[i].flag) ans[que[i].id] -= qans;
else ans[que[i].id] += qans;
}
for(Rint i = 1;i <= pos;i ++) printf("%lld
", ans[i]);
}