JZOJ 6866. 【2020.11.16提高组模拟】路径大小差
题目大意
问树上有多少点对之间路径边权
m
a
x
−
m
i
n
=
k
max-min=k
m a x − m i n = k ,
k
k
k 为定值。
k
≤
n
≤
2
∗
1
0
5
kleq nleq2*10^5
k ≤ n ≤ 2 ∗ 1 0 5 .
题解
其实这题比较套路,并不难想。 关于树上路径计数的问题,一般先考虑点分治能不能实现,发现是可以的。 按照一般点分治的套路,找到某个子树重心后,记录每个点到它的路径边权
m
a
x
,
m
i
n
max,min
m a x , m i n ,有两种情况,一种是重心为路径的一端,直接枚举判断;另一种是重心在路径中间。 第二种情况,按
m
a
x
max
m a x 从小到大排序,枚举一条路径和前面的另一条组合, 因为已经排好序了,所以
m
a
x
max
m a x 一定在当前这条路径上,接着再分两种情况,一种是该路径的
m
a
x
−
m
i
n
<
k
max-min<k
m a x − m i n < k ,那么查找前面
m
i
n
=
m
a
x
−
k
min=max-k
m i n = m a x − k 的数量加入答案;一种是该路径的
m
a
x
−
m
i
n
=
k
max-min=k
m a x − m i n = k ,则查找前面
m
i
n
≥
m
a
x
−
k
mingeq max-k
m i n ≥ m a x − k 的数量加入答案。用树状数组维护。 但是会发现组合的两条路径可能出现在当前根的同一子树中,那么把每棵子树的路径单独求一遍,从答案中减去即可。
代码
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define ll long long
#define N 200010
int n, K;
ll ans = 0 ;
int last[ N] , nxt[ N * 2 ] , to[ N * 2 ] , we[ N * 2 ] , len = 0 ;
int vi[ N] , si[ N] , sum[ N] , s, rt, mi;
int tot = 0 , f[ N] ;
struct node {
int mx, mi, r;
} a[ N] ;
void add ( int x, int y, int w) {
to[ ++ len] = y;
we[ len] = w;
nxt[ len] = last[ x] ;
last[ x] = len;
}
void dfs ( int k, int fa) {
si[ k] = 1 ;
for ( int i = last[ k] ; i; i = nxt[ i] ) if ( to[ i] != fa && ! vi[ to[ i] ] ) {
dfs ( to[ i] , k) ;
si[ k] + = si[ to[ i] ] ;
}
}
void find ( int k, int fa) {
int mx = s - si[ k] ;
for ( int i = last[ k] ; i; i = nxt[ i] ) if ( to[ i] != fa && ! vi[ to[ i] ] ) {
find ( to[ i] , k) ;
mx = max ( mx, si[ to[ i] ] ) ;
}
if ( mx < mi) mi = mx, rt = k;
}
void dfs1 ( int k, int fa, int t0, int t1, int r) {
if ( t1) a[ ++ tot] . mx = t1, a[ tot] . mi = t0, a[ tot] . r = r;
for ( int i = last[ k] ; i; i = nxt[ i] ) if ( to[ i] != fa && ! vi[ to[ i] ] ) {
dfs1 ( to[ i] , k, min ( t0, we[ i] ) , max ( t1, we[ i] ) , r == 0 ? to[ i] : r) ;
}
}
int cmp ( node x, node y) {
if ( x. mx == y. mx) return x. mi < y. mi;
return x. mx < y. mx;
}
int cmp1 ( node x, node y) {
return x. r < y. r;
}
int low ( int x) {
return x & ( - x) ;
}
void ins ( int k, int c) {
for ( int i = k; i <= n; i + = low ( i) ) f[ i] + = c;
}
int ct ( int k) {
int s = 0 ;
for ( int i = k; i; i - = low ( i) ) s + = f[ i] ;
return s;
}
void ds ( int l, int r, int o) {
sort ( a + l, a + r + 1 , cmp) ;
for ( int i = l; i <= r; i++ ) {
if ( a[ i] . mx - a[ i] . mi == K) {
ans + = ( i - l - ct ( a[ i] . mi - 1 ) ) * o;
}
else if ( a[ i] . mx - a[ i] . mi < K) ans + = sum[ a[ i] . mx - K] * o;
sum[ a[ i] . mi] ++ ;
ins ( a[ i] . mi, 1 ) ;
}
for ( int i = l; i <= r; i++ ) sum[ a[ i] . mi] -- , ins ( a[ i] . mi, - 1 ) ;
}
void calc ( int k) {
tot = 0 ;
dfs1 ( k, 0 , n + 1 , 0 , 0 ) ;
sort ( a + 1 , a + tot + 1 , cmp) ;
for ( int i = 1 ; i <= tot; i++ ) if ( a[ i] . mx - a[ i] . mi == K) ans++ ;
ds ( 1 , tot, 1 ) ;
sort ( a + 1 , a + tot + 1 , cmp1) ;
int la = 1 ;
for ( int i = 1 ; i <= tot; i++ ) {
if ( i == tot || a[ i] . r != a[ i + 1 ] . r) {
ds ( la, i, - 1 ) ;
la = i + 1 ;
}
}
}
void solve ( int k) {
dfs ( k, 0 ) ;
s = si[ k] , mi = n + 1 ;
find ( k, 0 ) ;
calc ( rt) ;
vi[ rt] = 1 ;
for ( int i = last[ rt] ; i; i = nxt[ i] ) if ( ! vi[ to[ i] ] ) solve ( to[ i] ) ;
}
int main ( ) {
int i, x, y, w;
scanf ( "%d%d" , & n, & K) ;
for ( i = 1 ; i < n; i++ ) {
scanf ( "%d%d%d" , & x, & y, & w) ;
add ( x, y, w) , add ( y, x, w) ;
}
solve ( 1 ) ;
printf ( "%lld
" , ans) ;
return 0 ;
}