HDU-6110 路径交 线段树维护区间交集 LCA求树上路径交
题意
给定一颗(n)个点的带边权的树,以及其中的(m)条路径,每次询问其中的第(L)条到第(R)条路径的交集长度
[nleq 500,000\
mle 500,000\
Qleq 500,000
]
分析
容易想到区间交集可以用线段树维护
所以问题其实就是给定两条树上的路径,如何求出他们的交集长度
我们有结论:对于两条树上路径(a->b,c->d) 他们的交集就是(lca(a,c),lca(a,d),lca(b,c),lca(b,d))中深度较大的两点
这样就可以用线段树进行区间合并
对于树上两点的距离则可以用公式计算:
[dis(a,b) = dis(a) + dis(b) - 2 imes dis(lca(a,b))
]
其中(dis(a))表示根到(a)的距离
代码
#include<bits/stdc++.h>
using namespace std;
const int maxn = 2e5 + 5;
typedef long long ll;
#define pii pair<int,int>
#define fi first
#define se second
int readint(){
int x = 0;
int f = 1;
char ch = getchar();
while(ch < '0' || ch > '9'){
if(ch == '-') f = -1;
ch = getchar();
}
while(ch >= '0' && ch <= '9'){
x = x * 10 + ch - '0';
ch = getchar();
}
return x * f;
}
int fa[maxn],son[maxn],dep[maxn],siz[maxn];
int pre[maxn],id[maxn],top[maxn];
int tt;
int d[maxn];
map<pii,int> mp;
vector<int> e[maxn];
void dfs1(int u){
siz[u] = 1;
for(auto v:e[u]){
if(v == fa[u]) continue;
dep[v] = dep[u] + 1;
fa[v] = u;
d[v] = d[u] + mp[make_pair(u,v)];
dfs1(v);
siz[u] += siz[v];
if(siz[v] > siz[son[u]])
son[u] = v;
}
}
void dfs2(int u,int x){
pre[u] = ++tt;
id[tt] = u;
top[u] = x;
if(!son[u]) return;
dfs2(son[u],x);
for(auto v:e[u]){
if(v == fa[u] || v == son[u]) continue;
dfs2(v,v);
}
}
int lca(int x,int y){
while(top[x] != top[y]){
if(dep[top[x]] > dep[top[y]])
x = fa[top[x]];
else y = fa[top[y]];
}
return dep[x] > dep[y] ? y : x;
}
int dis(int x,int y){
return d[x] + d[y] - 2 * d[lca(x,y)];
}
bool cmp(int a,int b){
return dep[a] > dep[b];
}
int aa[10];
pii merge(pii a,pii b){
aa[0] = lca(a.fi,b.fi);
aa[1] = lca(a.fi,b.se);
aa[2] = lca(a.se,b.fi);
aa[3] = lca(a.se,b.se);
sort(aa,aa + 4,cmp);
return make_pair(aa[0],aa[1]);
}
pii node[maxn << 2];
void build(int i,int l,int r){
if(l == r){
node[i].fi = readint();
node[i].se = readint();
return;
}
int mid = l + r >> 1;
build(i << 1,l,mid);
build(i << 1|1,mid + 1,r);
node[i] = merge(node[i << 1],node[i << 1|1]);
}
pii query(int i,int l,int r,int L,int R){
if(l == L && r == R) return node[i];
int mid = l + r >> 1;
if(R <= mid) return query(i << 1,l,mid,L,R);
else if(L > mid) return query(i << 1|1,mid + 1,r,L,R);
else return merge(query(i << 1,l,mid,L,mid),query(i << 1|1,mid + 1,r,mid + 1,R));
}
int main(){
int n = readint();
for(int i = 1;i < n;i++){
int x = readint();
int y = readint();
int z = readint();
e[x].push_back(y);
e[y].push_back(x);
mp[make_pair(x,y)] = z;
mp[make_pair(y,x)] = z;
}
int m = readint();
dfs1(1);
dfs2(1,1);
build(1,1,m);
int q = readint();
while(q--){
int l = readint();
int r = readint();
pii res = query(1,1,m,l,r);
printf("%d
",dis(res.fi,res.se));
}
}