神仙级别的树形 dp。
u1s1 这种代码很短但巨难理解的题简直是我的梦魇
首先这种题目一看就非常可以 DP 的样子,但直接一维状态的 DP 显然无法表示所有情况。注意到对于这类统计一个路径上权值之和的最值这样的问题,我们可以考虑借鉴 P4383 林克卡特树 的套路,即在 DP 状态中多记录一维 (j) 存储当前路径的延伸情况。但是这道题与 林克卡特树 的不同之处在于路径并非是简单路径,即一条路径可以先向上走一段,再向下走一段,接着再向上走一段。因此考虑这样设计 DP 状态:我们考察路径的起点和终点 (x,y),并设 (dp_{u,j,k}) 表示目前 (u) 考虑到 (u) 的子树,(x,y) 中恰好有 (j) 个在 (u) 的子树中 ((jin[0,2])),目前除了 (u) 的状态为 (k) 之外,(u) 子树内其剩余所有点的状态都变为 (1) 的最短序列长度。
初始三种情况都只有一条单一的 (u o u) 的路径,即 (dp_{u,0,a_uoplus 1}=dp_{u,1,a_uoplus 1}=dp_{u,2,a_uoplus 1}=1)。考虑怎样合并两棵子树的路径,即从 (dp_{u,j,k}) 与 (dp_{v,p,q}) 推出新的 (dp_{u,j,k}),其中 (u) 的 (v) 的父亲,(p+jle 2)。这一部分的转移略有亿点点恶心,下面将分条叙述:
- 若 (p=1),那么 (v) 子树内的路径应当是一条不能再向下延伸,但可以再继续向上延伸的路径 (x o v),同理由于 (p+qle 2),(u) 子树内对应的路径必然有一个端点会向上延伸,即一条 (u o y) 的路径,其中 (y) 可以等于 (u),即 (j=0) 的情况,也可以不等于 (u),即 (j=1) 的情况,那么:
- 如果 (k=1),那么此时直接将两部分拼在一起即可,即 (x o v o u o y),(dp_{u,j,k}+dp_{v,p,q} o newdp_{u,j+p,q})
- 如果 (k=0),那么在合并时候还要花费 (2) 的代价将 (y) 的状态变为 (1),即 (x o v o u o v o u o y),此时 (y) 的状态也会改变,(dp_{u,j,k}+dp_{v,p,q}+2 o newdp_{u,j+p,qoplus 1})
- 若 (p=2),那么 (v) 子树内的路径本来应是一段 (x o y),而 (u) 子树内的路径应为一段 (u o u) 的回路,那么合并两段路径时,由于本来经过 (v) 的是一段完整的路径 (x o y),而现在我们要将 (u o u) 的路径嵌进去,因此我们要将 (x o y) 的路径拆成 (x o v) 和 (v o y),然后将两段按 (x o v o u o u o v o x) 的顺序合并,至于额外代价……首先将两个 (v) 拆开来需要有 (1) 的代价。其次如果 (k=1),那么合并时将 (v) 拆成了两段会将 (v) 的状态改变一次,变成 (0),此时我们还需要在 (u,v) 反复横跳一次之后才能将 (v) 的状态变为 (1)。而如果 (k=0),则不需要,因此:
- 如果 (k=1),那么 (dp_{u,j,k}+dp_{v,p,q} o newdp_{u,j+p,qoplus 1}+3)
- 如果 (k=0),那么 (dp_{u,j,k}+dp_{v,p,q} o newdp_{u,j+p,q}+1)
- 若 (p=0),那么 (v) 子树内的路径就是一段回路 (v o v),我们合并两条路径时,由于两个 (v) 都要向上延伸,因此我们要将 (u) 子树内一段路径拆成两段 (x o u) 和 (y o u),然后将 (v o v) 嵌进去,即 (x o u o v o v o u o y),这样 (u) 的状态首先会改变一次,产生 (1) 的基本代价,而:
- 如果 (k=1),那不用额外代价,但 (u) 的状态会改变 (1),(dp_{u,j,k}+dp_{v,p,q} o newdp_{u,j+p,qoplus 1}+1)
- 如果 (k=0),那还要在 ((u,v)) 上反复横跳一次将 (v) 的状态改为 (1),同时 (u) 的状态又改变了一次,这样一来一回贡献抵消了,因此 (dp_{u,j,k}+dp_{v,p,q} o newdp_{u,j+p,q}+3)
状态转移情况就这么多。注意一些细节的情况,比方说一个子树如果所有点的状态都是 (1),那我们大可不必进入这个子树,直接 continue
即可,还有就是要以一个 (s_i='0') 的点为根 DFS。
总复杂度 (mathcal O(n))
const int MAXN=5e5;
int n,a[MAXN+5],rt=-1;char s[MAXN+5];
int hd[MAXN+5],to[MAXN*2+5],nxt[MAXN*2+5],ec=0;
void adde(int u,int v){to[++ec]=v;nxt[ec]=hd[u];hd[u]=ec;}
int sum[MAXN+5],dp[MAXN+5][3][2];
void dfs(int x,int f){
sum[x]=(!a[x]);
dp[x][0][a[x]^1]=dp[x][1][a[x]^1]=dp[x][2][a[x]^1]=1;
for(int e=hd[x];e;e=nxt[e]){
int y=to[e];if(y==f) continue;dfs(y,x);
sum[x]|=sum[y];if(!sum[y]) continue;
static int tmp[3][2];memset(tmp,63,sizeof(tmp));
for(int i=0;i<2;i++) for(int j=0;j<2;j++)
for(int p=0;p<3;p++) for(int q=0;p+q<3;q++){
if(q==0) chkmin(tmp[p+q][i^j],dp[x][p][i]+dp[y][q][j]+3-j*2);
if(q==1) chkmin(tmp[p+q][i^j^1],dp[x][p][i]+dp[y][q][j]+2-j*2);
if(q==2) chkmin(tmp[p+q][i^j],dp[x][p][i]+dp[y][q][j]+3-(j^1)*2);
}
memcpy(dp[x],tmp,sizeof(dp[x]));
}
}
int main(){
scanf("%d%s",&n,s+1);
for(int i=1,u,v;i<n;i++) scanf("%d%d",&u,&v),adde(u,v),adde(v,u);
for(int i=1;i<=n;i++) a[i]=s[i]-'0';
for(int i=1;i<=n;i++) if(!a[i]) rt=i;
memset(dp,63,sizeof(dp));
dfs(rt,0);printf("%d
",dp[rt][2][1]);
return 0;
}