树形dp之换根dp
换根dp是树形dp这一类中我觉得比较难的一类。
一般的树形dp都只需要从子树往父亲推,然而换根dp则需要从父亲往子树推,接下来写写我学习换根dp的几个例题。
例题1 Computer
题目大意:给你一棵树,然后问你每一个点具体其他点最远的距离是多少。
解题:这个题目首先任意找一个点为根节点,然后求出这个点的最远距离,除了这个最远距离之外,还需要存一个次远距离,这是第一个dfs。第二个dfs就比较重要了,这个才是换根dp的精髓,对于一个节点从父亲节点推过来,如果该节点是父亲节点最长链的一环,那么这个节点只能从父亲节点的次大值更新(值得注意的是,我们开始求最大值和次大值的同时就应该保证了最大值和次大值不是同一条链),如果不是最长链的一环,那么直接可以从父亲节点的最大值更新。
#include<queue>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cmath>
#define inf64 0x3f3f3f3f3f3f3f3f
using namespace std;
typedef long long ll;
const int maxn=1e5+7;
const int mod=1e9+7;
int head[maxn],cnt;
struct node{
ll v,w,nxt;
node(ll v=0,ll w=0,ll nxt=0):v(v),w(w),nxt(nxt){}
}e[maxn*2];
void add(int u,int v,ll w){
e[++cnt]=node(v,w,head[u]);
head[u]=cnt;
e[++cnt]=node(u,w,head[v]);
head[v]=cnt;
}
ll dp1[maxn],dp2[maxn];
ll dfs1(int u,int pre){
ll res1=0,res2=0;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].v;
if(v==pre) continue;
ll f=dfs1(v,u)+e[i].w;
if(f>=res1){
res2=res1;
res1=f;
}
else res2=max(res2,f);
}
dp1[u]=res1,dp2[u]=res2;
return res1;
}
void dfs2(int u,int pre){
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].v;
if(v==pre) continue;
if(dp1[u]==dp1[v]+e[i].w){
if(dp2[u]+e[i].w>=dp1[v]){
dp2[v]=dp1[v];
dp1[v]=dp2[u]+e[i].w;
}
else dp2[v]=max(dp2[v],dp2[u]+e[i].w);
}
else{
if(dp1[u]+e[i].w>=dp1[v]){
dp2[v]=dp1[v];
dp1[v]=dp1[u]+e[i].w;
}
else dp2[v]=max(dp2[v],dp1[u]+e[i].w);
}
dfs2(v,u);
}
}
int main(){
int n;
while(scanf("%d",&n)!=EOF){
cnt=0;
for(int i=1;i<=n;i++) head[i]=0;
for(int i=2;i<=n;i++){
ll v,w;
scanf("%lld%lld",&v,&w);
add(i,v,w);
}
dfs1(1,0);
dfs2(1,0);
for(int i=1;i<=n;i++) printf("%lld
",dp1[i]);
}
return 0;
}
例2 湖南2019省赛2019
题目大意:给你一棵树,树的每一条边都有一个边权值,然后问,任意两个节点之间的距离是2019的偶数倍,这样的节点对数有多少,提示(u,v)和(v,u)是一样的。
这个题目我也是用换根dp求的,dp[ u ] [ i ] 表示以节点u为一个端点,u到其他节点的距离对2019取模为i的路径数。
所以按照换根dp的套路,第一个dfs求以1为根节点的路径数量,第二个dfs由父亲向儿子转移,求出儿子的dp。
这里有一个地方要很注意,这个我在代码中强调了。
#include<queue>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cmath>
#define inf64 0x3f3f3f3f3f3f3f3f
using namespace std;
typedef long long ll;
const int maxn=2e4+7;
const int mod=2019;
int head[maxn],cnt;
struct node{
int v,w,nxt;
node(int v=0,int w=0,int nxt=0):v(v),w(w),nxt(nxt){}
}e[maxn*2];
void add(int u,int v,int w){
e[++cnt]=node(v,w,head[u]);
head[u]=cnt;
e[++cnt]=node(u,w,head[v]);
head[v]=cnt;
}
ll dp[maxn][2022],ans=0,num[2022];
void dfs2(int u,int pre){
for(int i=0;i<2019;i++) dp[u][i]=0;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].v,w=e[i].w;
if(v==pre) continue;
dfs2(v,u);
dp[u][w]++;
for(int j=0;j<2019;j++){
dp[u][(j+w)%mod]+=dp[v][j];
}
}
}
void dfs3(int u,int pre){
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].v,w=e[i].w;
if(v==pre) continue;
for(int j=0;j<2019;j++) num[(j+w)%mod]=dp[u][(j+w)%mod]-dp[v][j];
num[w]--,dp[v][w]++;//!!! 这个地方要超级注意,我在这里wa了一天,很容易漏掉的。
//这里就是对于节点(u,v) 首先num[w]-- 因为对于u这个节点,不能和v的子树相连
//那么肯定也不能和v相连 所以num[w]--,对于v这个节点(v,u)这也是一条路径,所以要加上。
for(int j=0;j<2019;j++) dp[v][(j+w)%mod]+=num[j];
dfs3(v,u);
}
ans+=dp[u][0];
}
void init(int n){
cnt=ans=0;
for(int i=0;i<=2*n;i++) head[i]=0;
}
int main(){
int n;
while(scanf("%d",&n)!=EOF){
init(n);
for(int i=1;i<n;i++){
int u,v,w;
scanf("%d%d%d",&u,&v,&w);
add(u,v,w%mod);
}
dfs2(1,0),dfs3(1,0);
printf("%lld
",ans/2);
}
return 0;
}