COCI 2020/2021 Svjetlo
题目大意
求最短的树上路径(可以重复经过点或边)长度使得经过每个点的次数满足给定的奇偶性。树的大小为
N
N
N 。
N
≤
500000
Nle 500000
N ≤ 5 0 0 0 0 0
题解
路径是可以重复的,简单的树形DP可能难以处理,考虑路径的拼接。 设
f
i
,
j
,
k
f_{i,j,k}
f i , j , k 表示第
i
i
i 个点的子树内(除了自己)的奇偶性已经满足,且子树内(包括自己)的路径端点数有
j
j
j 个,第
i
i
i 个点的奇偶性为
k
k
k 的最短路径长度,其中
j
∈
{
0
,
1
,
2
}
,
k
∈
{
0
,
1
}
jin{0,1,2},kin{0,1}
j ∈ { 0 , 1 , 2 } , k ∈ { 0 , 1 } 。 转移的时候有很多种情况,但它们都是类似的,端点个数(状态第二维)的转移有: 1、儿子子树内均为
0
0
0 个端点 –> 自己子树
0
0
0 个端点 2、儿子子树内均为
0
0
0 个端点 + 自己作为某一个端点 –> 自己子树内
1
1
1 个端点 3、儿子子树内均为
0
0
0 个端点 + 自己作为两个端点 –> 自己子树内
2
2
2 个端点 4、一个儿子子树内
1
1
1 个端点 + 其他儿子子树内均为
0
0
0 个端点 –> 自己子树内
1
1
1 个端点 5、一个儿子子树内
1
1
1 个端点 + 其他儿子子树内均为
0
0
0 个端点 + 自己作为某个端点 –> 自己子树内
2
2
2 个端点 6、一个儿子子树内
2
2
2 个端点 + 其他儿子子树内均为
0
0
0 个端点 –> 自己子树内
2
2
2 个端点 7、两个儿子子树内各
1
1
1 个端点 + 其他儿子子树内均为
0
0
0 个端点 –> 自己子树内
2
2
2 个端点 第二维的
j
j
j 可以理解为是伸出了多少个“头”,然后每个子树相连拼接上,再用剩下的“头”继续往上转移。
0
0
0 和
2
2
2 都是两个“头”,
1
1
1 是一个“头”。 儿子之间合并的时候要注意答案所求的是路径点的个数,所以不能把每个儿子所有“2”都延长
2
2
2 的长度到父亲,不然儿子之间相接时会算重,而应该少延长一个”头“,最后更新答案时再只加多
1
1
1 。 如何保证儿子子树内的奇偶性都满足条件?如果从儿子节点尚不满足奇偶性的点转移时,需要多加上
2
2
2 的长度,表示到了父亲再往下到儿子走一个来回。 至于第三维是转移到当前节点的
0
0
0 还是
1
1
1 ,需要看转移上来的偶儿子(指
j
j
j 为偶数的儿子)个数的奇偶性。 还要注意,根节点剩下的两个“头”会相连,不仅答案会减
1
1
1 ,而且对他而言奇偶性还会再多变一次。 这样就做完了吗? 写到后面可能会很容易忽略的是,这样写会默认每个节点都至少被经过一次,而其实并不然,所以根节点设为任意一个奇偶性条件为
1
1
1 的点,同时在枚举儿子转移时,若整个子树都已经满足了就直接跳过。
代码
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define N 500010
int last[ N] , nxt[ N * 2 ] , to[ N * 2 ] , len = 0 ;
int f[ N] [ 3 ] [ 2 ] , a[ N] , s[ N] ;
void add ( int x, int y) {
to[ ++ len] = y;
nxt[ len] = last[ x] ;
last[ x] = len;
}
void dfs ( int k, int fa) {
int s0 = 0 , s1 = 1e9 , s2 = 1e9 , s3 = 1e9 , s4 = 1e9 , s5 = 1e9 , s6 = 1e9 , s7 = 1e9 ;
int t0, t1;
if ( ! a[ k] ) s[ k] ++ ;
for ( int i = last[ k] ; i; i = nxt[ i] ) if ( to[ i] != fa) {
int x = to[ i] ;
dfs ( x, k) ;
if ( ! s[ to[ i] ] ) continue ;
s[ k] + = s[ to[ i] ] ;
t0 = min ( s7 + f[ x] [ 0 ] [ 1 ] + 1 , s6 + f[ x] [ 0 ] [ 0 ] + 3 ) , t1 = min ( s6 + f[ x] [ 0 ] [ 1 ] + 1 , s7 + f[ x] [ 0 ] [ 0 ] + 3 ) ;
s6 = t0, s7 = t1;
t0 = min ( s2 + f[ x] [ 1 ] [ 1 ] , s3 + f[ x] [ 1 ] [ 0 ] + 2 ) , t1 = min ( s3 + f[ x] [ 1 ] [ 1 ] , s2 + f[ x] [ 1 ] [ 0 ] + 2 ) ;
s6 = min ( s6, t0) , s7 = min ( s7, t1) ;
t0 = min ( s5 + f[ x] [ 0 ] [ 1 ] + 1 , s4 + f[ x] [ 0 ] [ 0 ] + 3 ) , t1 = min ( s4 + f[ x] [ 0 ] [ 1 ] + 1 , s5 + f[ x] [ 0 ] [ 0 ] + 3 ) ;
s4 = t0, s5 = t1;
t0 = min ( s1 + f[ x] [ 2 ] [ 1 ] + 1 , s0 + f[ x] [ 2 ] [ 0 ] + 3 ) , t1 = min ( s0 + f[ x] [ 2 ] [ 1 ] + 1 , s1 + f[ x] [ 2 ] [ 0 ] + 3 ) ;
s4 = min ( s4, t0) , s5 = min ( s5, t1) ;
t0 = min ( s3 + f[ x] [ 0 ] [ 1 ] + 1 , s2 + f[ x] [ 0 ] [ 0 ] + 3 ) , t1 = min ( s2 + f[ x] [ 0 ] [ 1 ] + 1 , s3 + f[ x] [ 0 ] [ 0 ] + 3 ) ;
s2 = t0, s3 = t1;
t0 = min ( s0 + f[ x] [ 1 ] [ 1 ] , s1 + f[ x] [ 1 ] [ 0 ] + 2 ) , t1 = min ( s1 + f[ x] [ 1 ] [ 1 ] , s0 + f[ x] [ 1 ] [ 0 ] + 2 ) ;
s2 = min ( s2, t0) , s3 = min ( s3, t1) ;
t0 = min ( s1 + f[ x] [ 0 ] [ 1 ] + 1 , s0 + f[ x] [ 0 ] [ 0 ] + 3 ) , t1 = min ( s0 + f[ x] [ 0 ] [ 1 ] + 1 , s1 + f[ x] [ 0 ] [ 0 ] + 3 ) ;
s0 = t0, s1 = t1;
}
f[ k] [ 0 ] [ a[ k] ] = s1 + 1 ;
f[ k] [ 0 ] [ a[ k] ^ 1 ] = s0 + 1 ;
f[ k] [ 1 ] [ a[ k] ] = min ( s3, s1) + 1 ;
f[ k] [ 1 ] [ a[ k] ^ 1 ] = min ( s2, s0) + 1 ;
f[ k] [ 2 ] [ a[ k] ] = min ( min ( s0 + 2 , s5 + 1 ) , min ( s6 + 2 , s2 + 2 ) ) ;
f[ k] [ 2 ] [ a[ k] ^ 1 ] = min ( min ( s1 + 2 , s4 + 1 ) , min ( s7 + 2 , s3 + 2 ) ) ;
}
int main ( ) {
int n, i, x, y;
scanf ( "%d
" , & n) ;
for ( i = 1 ; i <= n; i++ ) {
a[ i] = getchar ( ) - '0' ;
}
for ( i = 1 ; i < n; i++ ) {
scanf ( "%d%d" , & x, & y) ;
add ( x, y) , add ( y, x) ;
}
for ( i = 1 ; i <= n; i++ ) if ( ! a[ i] ) break ;
dfs ( i, 0 ) ;
printf ( "%d
" , f[ i] [ 2 ] [ 0 ] - 1 ) ;
fclose ( stdin ) ;
fclose ( stdout ) ;
return 0 ;
}