一道树形DP
(Translate)
有一颗(n)个节点的树
第(i)个节点权值为(a_i) ((n<=10^6,-100<=a_i<=100))
问是否能够删除掉两条边,使得该树分成三个不为空的部分,并且每部分权值之和相等。
无解输出(−1) 否则输出要删除边((u−>v))的v节点序号。
(Solution)
很明显我们可以在(DFS)遍历树的时候记录子树和这个节点的权值和,我们要要将这个树分成三部分,每部分权值和都为(Sum/3),所以当我们找到当前的和为(Sum/3)时,就可以(cnt++),然后记录下这个点,千万不要忘了将现在的求和数组清零,因为如果从这里断开,这个节点对它的父节点将没有贡献,最后找完,如果(cnt<=2),说明分成的子树个数(<=2)而不是断边(<=2),这点要注意。这时候输出(-1)即可,否则说明这颗树可以分成三部分,输出(ans[1],ans[2])即可。
剪枝
这个题剪枝十分明显,也比较好想,我们只要在(dp)前判断所有的权值和能否被(3)整除即可,优化十分有效。
(Code)
#include<iostream>
#include<cstdio>
#include<cmath>
#include<algorithm>
#define LL long long
#define maxn 1000001
using namespace std;
struct Edge{
int nxt,to,from;
}edge[maxn*2];//因为这是个树
int num_edge,head[maxn],a[maxn],n,fa,total,sum,cnt,ans[maxn],ksum[maxn];
inline int qread(){
char ch=getchar();
int f=1,x=0;
while(ch>'9'||ch<'0'){if(ch=='-') f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
inline void addedge(int from,int to){ //加上from的原因是我们要输出他的from
edge[++num_edge].nxt=head[from];
edge[num_edge].from=from;
edge[num_edge].to=to;
head[from]=num_edge;
}
int dfs(int x,int fath){
ksum[x]=a[x];
for(int i=head[x];i;i=edge[i].nxt){//找祖先
int y=edge[i].to;
if(y!=fath){
dfs(y,x);
ksum[x]+=ksum[y];
}
}
if(ksum[x]==sum){ //当满足了之后
cnt++;
ans[cnt]=x;
ksum[x]=0;
}
}
int main(){
n=qread();//快读,否则的话大数据卡死
for(int i=1;i<=n;i++){
int x;
x=qread();
a[i]=qread();
if(x!=0) {
addedge(i,x);
addedge(x,i);
}
else fa=i;
total+=a[i];
}
if(total%3!=0) printf("-1
");
else{
sum=total/3;
dfs(fa,0);
if(cnt>=3) printf("%d %d
",ans[2],ans[1]);
else printf("-1
");
}
return 0;
}
(Addition)
在写这里之前还有一个问题不太理解,那就是为什么要写(cnt>=3)而不能写(cnt==3),不是只会分成三个块吗?
下面这段代码我倒是能理解
#include<bits/stdc++.h>
using namespace std;
struct Edge
{
int to,nxt;
Edge(){}
Edge(int to,int nxt):to(to),nxt(nxt){}
}e[1000010];
int head[1000010],cnt;
void addedge(int u,int v)
{
e[++cnt]=Edge(v,head[u]);
head[u]=cnt;
}
int n,m;
int siz[1000010];
int ans=0;
bool cut=0;
void dfs(int now,int fa)
{
for(int i=head[now];i;i=e[i].nxt)
{
int vs=e[i].to;
if(vs==fa) continue;
dfs(vs,now);
siz[now]+=siz[vs];
}
if(siz[now]==m)
{
if(cut)
{
cout<<ans<<' '<<now<<'
';//找到第二处直接输出
exit(0);
}
else
{
ans=now;
cut=1;
}
siz[now]=0;
}
}
int main()
{
ios::sync_with_stdio(0);
cin.tie(0),cout.tie(0);
cin>>n;
int root;
for(int i=1;i<=n;i++)
{
int tmp;
cin>>tmp>>siz[i];
m+=siz[i];
if(tmp) addedge(tmp,i);
else root=i;
}
if(m%3) return cout<<-1<<'
',0;
m/=3;
dfs(root,0);
cout<<-1<<'
';
}