给你一棵 $n$ 个点的树,边有边权。$m$ 次询问,每次给出 $l$ 、$r$ 、$x$ ,求 $ ext{Min}_{i=l}^r ext{dis}(i,x)$ 。
$n,mle 10^5$ 。
题解
动态点分治+线段树
分块做法太傻逼了我们把它丢到垃圾桶里。树上距离考虑动态点分治。
求出这棵树的点分树,对每一棵点分树子树开一棵动态开点编号线段树,维护编号在某区间内的点到当前点距离的最大值。
对于一次查询,我们在点分树从 $x$ 到根的路径上所有点对应的线段树上查询 $[l,r]$ 的最大值,$dis(i,x)+query(l,r,root_i)$ 的最大值极为答案。
这样做的正确性比较显然:
1. 每个 $[l,r]$ 内的点都属于这些子树的一个部分内,都被正确统计了一次。
2. 多余统计时,距离只会统计大,不会统计小,没有影响。
时间复杂度 $O(nlog^2 n)$
#include <cstdio> #include <cstring> #include <algorithm> #define N 100010 #define inf 1 << 30 using namespace std; int head[N] , to[N << 1] , len[N << 1] , next[N << 1] , cnt , deep[N] , pos[N] , md[N << 1][20] , log[N << 1] , tot , si[N] , ms[N] , sum , root , vis[N] , fa[N]; int ls[N * 300] , rs[N * 300] , mn[N * 300] , rt[N] , tp; inline void add(int x , int y , int z) { to[++cnt] = y , len[cnt] = z , next[cnt] = head[x] , head[x] = cnt; } void dfs(int x , int pre) { int i; md[++tot][0] = deep[x] , pos[x] = tot; for(i = head[x] ; i ; i = next[i]) if(to[i] != pre) deep[to[i]] = deep[x] + len[i] , dfs(to[i] , x) , md[++tot][0] = deep[x]; } inline int dis(int x , int y) { int t = deep[x] + deep[y] , k; x = pos[x] , y = pos[y]; if(x > y) swap(x , y); k = log[y - x + 1]; return t - 2 * min(md[x][k] , md[y - (1 << k) + 1][k]); } void getroot(int x , int pre) { int i; si[x] = 1 , ms[x] = 0; for(i = head[x] ; i ; i = next[i]) if(!vis[to[i]] && to[i] != pre) getroot(to[i] , x) , si[x] += si[to[i]] , ms[x] = max(ms[x] , si[to[i]]); ms[x] = max(ms[x] , sum - si[x]); if(ms[x] < ms[root]) root = x; } void solve(int x) { int i; vis[x] = 1; for(i = head[x] ; i ; i = next[i]) if(!vis[to[i]]) sum = si[to[i]] , root = 0 , getroot(to[i] , 0) , fa[root] = x , solve(root); } void update(int p , int a , int l , int r , int &x) { if(!x) x = ++tp , mn[x] = inf; mn[x] = min(mn[x] , a); if(l == r) return; int mid = (l + r) >> 1; if(p <= mid) update(p , a , l , mid , ls[x]); else update(p , a , mid + 1 , r , rs[x]); } int query(int b , int e , int l , int r , int x) { if(!x) return inf; if(b <= l && r <= e) return mn[x]; int mid = (l + r) >> 1 , ans = inf; if(b <= mid) ans = min(ans , query(b , e , l , mid , ls[x])); if(e > mid) ans = min(ans , query(b , e , mid + 1 , r , rs[x])); return ans; } int main() { int n , m , i , j , x , y , z , ans; scanf("%d" , &n); for(i = 1 ; i < n ; i ++ ) scanf("%d%d%d" , &x , &y , &z) , add(x , y , z) , add(y , x , z); dfs(1 , 0); for(i = 2 ; i <= tot ; i ++ ) log[i] = log[i >> 1] + 1; for(i = 1 ; i <= log[tot] ; i ++ ) for(j = 1 ; j <= tot - (1 << i) + 1 ; j ++ ) md[j][i] = min(md[j][i - 1] , md[j + (1 << (i - 1))][i - 1]); ms[0] = sum = n , root = 0 , getroot(1 , 0) , solve(root); for(i = 1 ; i <= n ; i ++ ) for(j = i ; j ; j = fa[j]) update(i , dis(i , j) , 1 , n , rt[j]); scanf("%d" , &m); while(m -- ) { scanf("%d%d%d" , &x , &y , &z) , ans = inf; for(i = z ; i ; i = fa[i]) ans = min(ans , dis(i , z) + query(x , y , 1 , n , rt[i])); printf("%d " , ans); } return 0; }