- 题意: 给一棵树,每次查询u到v路径上有多少不同的点权
首先需要证明这类题目符合区间加减性质
摘选一段vfk大牛的证明
用S(v, u)代表 v到u的路径上的结点的集合。
用root来代表根结点,用lca(v, u)来代表v、u的最近公共祖先。
那么S(v, u) = S(root, v) xor S(root, u) xor lca(v, u)
其中xor是集合的对称差。
简单来说就是节点出现两次消掉。
lca很讨厌,于是再定义
T(v, u) = S(root, v) xor S(root, u)
观察将curV移动到targetV前后T(curV, curU)变化:
T(curV, curU) = S(root, curV) xor S(root, curU)
T(targetV, curU) = S(root, targetV) xor S(root, curU)
取对称差:
T(curV, curU) xor T(targetV, curU)= (S(root, curV) xor S(root, curU)) xor (S(root, targetV) xor S(root, curU))
由于对称差的交换律、结合律:
T(curV, curU) xor T(targetV, curU)= S(root, curV) xorS(root, targetV)
两边同时xor T(curV, curU):
T(targetV, curU)= T(curV, curU) xor S(root, curV) xor S(root, targetV)
发现最后两项很爽……哇哈哈
T(targetV, curU)= T(curV, curU) xor T(curV, targetV)
(有公式恐惧症的不要走啊 T_T)
也就是说,更新的时候,xor T(curV, targetV)就行了。
即,对curV到targetV路径(除开lca(curV, targetV))上的结点,将它们的存在性取反即可
其实也就是说对于先前位置(preu,prev) 和当前位置 已知T(preu,prev) 可以很方便的计算T(curu,curv)
具体做法如下
求出T(preu,curu) 和 T(prev,curv) 暴力遍历两点之间路径即可
此时T(curu,curv) = T(preu,prev) xor T(preu,curu) xor T(prev,curv)
S(curu,curv) = T(curu,curv) xot lca(curu,curv)
下面考虑如何对树上序列进行分块了
有两种方法
- 第一种方法:dfs的时候对每个点记录 进栈时间戳f(x) 和 出栈时间戳g(x),得到一个2n的序列
- 对于查询(x,y) 令f(x) < f(y)
- 如果 x 是 y 的祖先,考虑从x向下走到y 即区间[f(x) , f(y)]
显然除了x到y路径上的点 之外 其他在区间[f(x),f(y)]出现的点都出现了两次- 如果x 不是 y 的祖先,那么必然是先往上走 再往下,即区间[g(x),f(y)] 再加上lca(x,y)
- 第二种方法: 考虑对树上关键点的划分,详情见分块有关论文,证明我也没太看懂,大概的理解就是把一些距离相近的点划分成一块,减少块与块之间需要跨越的距离。
第一种方法 序列长度为2n 看起来常数似乎比第二种要大,而且每个点记录两次处理起来麻烦一点,所以我用的是第二种
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10;
int col[N],vis[N],cnt[N];
int pos[N],dfn[N];
int head[N],EN,tot;
int n , m, Siz;
inline void read(int &x){
char c = getchar();
x = 0;
while(c < '0' || c > '9') c = getchar();
while(c >= '0' && c <= '9') x = x * 10 + c - 48,c = getchar();
}
struct Q{
int l , r, id, b;
int x , y;
Q(){};
bool operator < (const Q&rhs)const{
if(b == rhs.b) return dfn[r] < dfn[rhs.r];///左端点先按块排序,再右端点按时间戳排序
return b < rhs.b;
}
}q[N];
int ans[N] , res;
struct edge{
int v,nxt;
edge(){};
edge(int v,int nxt):v(v),nxt(nxt){};
}e[N];
void add(int u,int v){
e[EN] = edge(v,head[u]);
head[u] = EN++;
}
int f[N][22],dep[N];
int lca(int u,int v){
if(dep[u] < dep[v]) swap(u,v);
int d = dep[u] - dep[v];
for(int i = 0;i <= 20;i++)
if((1<<i) & d) u = f[u][i];
if(u == v) return u;
for(int i = 20;i >= 0;i--)
if(f[u][i] != f[v][i]) u = f[u][i],v = f[v][i];
return f[u][0];
}
int stk[N],top,b_cnt;
int dfs(int u,int fa,int d){
for(int i = 1;i <= 20;i++) f[u][i] = f[f[u][i-1]][i-1];
dep[u] = d;
dfn[u] = tot++;
int siz = 0;
for(int i = head[u];~i;i = e[i].nxt){
if(i == fa) continue;
int v = e[i].v;
f[v][0] = u;
siz += dfs(v,i ^ 1,d + 1);
if(siz >= Siz){
while(siz--) pos[stk[top--]] = b_cnt;
b_cnt++;
}
}
stk[++top] = u;
return siz + 1;
}
void init(){
memset(head,-1,sizeof(head));
b_cnt = top = EN = tot = 0;
}
inline void up(int u){
if(!vis[u]) {
if(++cnt[col[u]] == 1) res++;
vis[u] = 1;
}else{
vis[u] = 0;
if(--cnt[col[u]] == 0) res--;
}
}
inline void work(int u,int v){
while(u != v){
if(dep[u] < dep[v]) swap(u,v);
up(u),u = f[u][0];
}
}
map<int,int> mp;
int ID;
/*
8 2
1000000000000 2 9 3 8 5 1000001 1000001
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5
7 8
*/
int main(){
read(n),read(m);
ID = 1;mp.clear();
for(int i = 1;i <= n;i++) {
read(col[i]);
if(!mp[col[i]]) mp[col[i]] = ID++;
col[i] = mp[col[i]];
}
init();
int rt = 0;
for(int i = 1;i < n;i++){
int u , v;
read(u),read(v);
if(!u) rt = v;
else if(!v) rt = u;
else add(u,v),add(v,u);
}
Siz = sqrt(n + 0.5);
dfs(1,-1,1);
while(top) pos[stk[top--]] = b_cnt;
for(int i = 0;i < m;i++){
read(q[i].l),read(q[i].r);//read(q[i].x),read(q[i].y);
if(dfn[q[i].l] > dfn[q[i].r]) swap(q[i].l,q[i].r);
q[i].id = i;
q[i].b = pos[q[i].l];
}
sort(q,q + m);
memset(vis,0,sizeof(vis));
memset(cnt,0,sizeof(cnt));
res = 0;
int LCA = lca(q[0].l,q[0].r);
work(q[0].l,q[0].r);
up(LCA);
ans[q[0].id] = res;
up(LCA);
for(int i = 1;i < m;i++){
work(q[i-1].l,q[i].l);
work(q[i-1].r,q[i].r);
LCA = lca(q[i].l,q[i].r);
up(LCA);
ans[q[i].id] = res;
up(LCA);
}
for(int i = 0;i < m;i++) printf("%d
",ans[i]);
return 0;
}