题干:
有一棵点数为 N 的树,树边有边权。给你一个在 0∼N 之内的正整数 K,你要在这棵树中选择 K 个点,将其染成黑色,并将其他的 N−K 个点染成白色。将所有点染色后,你会获得黑点两两之间的距离加上白点两两之间距离的和的收益。问收益最大值是多少。
题解:
首先,看到黑点到黑点、白点与白点,其实这就告诉我们黑点与白点是可以整体互换的(可以由黑点染成白色变为白点染成黑色),这就代表我们的 K 一定 <=(n/2)。
然后看到本题在树上进行操作,那么就应该是树上dp(合理分析。。。)
相应地,我们先 dfs 出所有子树的大小,然后想一下 dp 状态转移方程。
第一维不出意外应是以 i 为根节点的树(然后怎么怎么样~~~)。
第二维呢?我们好像只能是定义为以 i 为根节点的树,这个树中有 j 个节点被染色的情况。
我们每染两个点,就会对全局都造成影响,所以 dp 的结果值应定义为对全局的贡献,而不是以这个节点为根节点的子树的贡献。
再看一下这两个相同颜色的点是如何作出贡献的。我们可以发现,我们要输出的结果就是所有边权乘上各自被经过的次数。
这每一个边各自被经过的次数怎么算呢?—— 我们可以看一下每一个边两旁的节点,发现两旁每有一对相同颜色的节点,它就会被经过一次。
最后用乘法分配律可得:针对每一条边,它被经过的次数就是:
左边白点个数×右边白点个数+左边黑点个数×右边黑点个数
最后利用树上背包来解决(体积为染色点的个数,价值就为对全局的贡献)。
(注意本题的背包与以往经典例题的方法不太一样,它有一定后效性,
只可以在更新部分最优的同时,进行现在贡献的统计;而不是先找出子节点的最优,再找出现在的最优进行相加)
我们再回过头来看一下我们的算法,好像是 O(n3) 的!!!(枚举一遍节点 × 大小为n的子树体积 × 大小为n的其余树体积)。。。但仔细分析一下其实是 O(n2) 的:
1、在实现过程中,其实我们在更新答案时是枚举了整个数的一个子树(设大小为 x)与其余部分(设大小为 y ),那么每一个小子树的复杂度为 O(x*y)
(为什么是 O(x*y) 呢?本题的核心在于将染色的节点数作为背包,那么每一棵子树的大小就是背包的最大体积,即在枚举时我们在 0~x 中又枚举了 0~y):
Σsizj=0 Σsiz2m=0
val=w*( m*(k-m) + (siz[son]-m)*(n-k-(siz[son]-m)) );
dp[x][m+j]=max(dp[x][m+j],dp[x][j]+dp[son][m]+val);
2、相应的大子树同理,那么我们用乘法结合律可得:我们相当于将每一个结点都乘过了除它以外的所有节点——n个节点×n个节点——O(n2)。
O(n2) 中注意要枚举子树的大小,而不是全扫一边,否则就一定会退化为O(n3) 。
Code:
![](https://images.cnblogs.com/OutliningIndicators/ContractedBlock.gif)
1 #include<cstdio>
2 #include<cstring>
3 #define $ 2102
4 #define ll long long
5 using namespace std;
6 int n,first[$],tot,m,siz[$];
7 ll dp[$][$/2];
8 struct tree{ int to,next,w; }a[$*2];
9 inline int min(int x,int y){ return x<y?x:y; }
10 inline ll max(ll x,ll y){ return x>y?x:y; }
11 inline void add(int x,int y,int w){
12 a[++tot]=(tree){ y,first[x],w };
13 first[x]=tot;
14 a[++tot]=(tree){ x,first[y],w };
15 first[y]=tot;
16 }
17 inline void dfs(int x,int fa,int sum=1){
18 siz[x]=1;
19 if(first[x]==0) return ;
20 for(register int i=first[x];i;i=a[i].next){
21 int to=a[i].to;
22 if(to==fa) continue;
23 dfs(to,x); siz[x]+=siz[to];
24 }
25 for(register int i=first[x];i;i=a[i].next){
26 int to=a[i].to;
27 if(to==fa) continue;
28 int vv=min(m,siz[to]);
29 for(register int j=sum;j>=0;--j)
30 for(register int k=vv;k>=0;--k){
31 if(j+k<=m){
32 ll val=1ll*a[i].w*( 1ll*k*(m-k) + 1ll*(siz[to]-k)*(n-m-(siz[to]-k)) );
33 dp[x][k+j]=max(dp[x][k+j],dp[x][j]+dp[to][k]+val);
34 }
35 }
36 sum=min(m,sum+siz[to]);
37 }
38 }
39 signed main(){
40 scanf("%d%d",&n,&m); m=min(m,n-m);
41 for(register int i=1,x,y,w;i<n;++i) scanf("%d%d%d",&x,&y,&w),add(x,y,w);
42 dfs(1,0); printf("%lld
",dp[1][m]);
43 }