【HAOI2015】树上染色
【题目描述】
有一棵点数为N的树,树边有边权。给你一个在0~N之内的正整数K,你要在这棵树中选择K个点,将其染成黑色,并将其他的N-K个点染成白色。将所有点染色后,你会获得黑点两两之间的距离加上白点两两之间距离的和的收益。问收益最大值是多少。
【输入格式】
第一行两个整数N,K。
接下来N-1行每行三个正整数fr,to,dis,表示该树中存在一条长度为dis的边(fr,to)。输入保证所有点之间是联通的。
【输出格式】
输出一个正整数,表示收益的最大值。
【输入样例1】
3 1
1 2 1
1 3 2
【输出样例1】
3
【输入样例2】
5 2
1 2 3
1 5 1
2 3 1
2 4 2
【输出样例2】
17
【样例解释】
在第二个样例中,将点1,2染黑就能获得最大收益。
【数据范围】
对于30%的数据,N<=20
对于50%的数据,N<=100
对于100%的数据,N<=2000,0<=K<=N。
题解:
这题一看就是dp,单个点无法产生贡献,只能是两个黑点或两个白点才能产生贡献
如果我们dp围绕点来进行,首先你无法定义一种状态,其次你不方便转移
所以我们考虑一条边能产生的贡献
首先一条边能产生的贡献为:
Wi=dis×((这条边左边的黑点×这条边右边的黑点)+(这条边左边的白点×这条边右边的白点))
我们考虑一对点产生的贡献,这一条边L左侧的黑点要想连接L右侧的黑点,必定经过L,所以L会有贡献
有几对这样的黑点,边L就被经过几次,白点同理,所以得到了上面的方程
我们设f[i][j]表示以i为根节点的子树中染j个黑点的最大收益,
size[x]表示以x为根的子树大小,目标:f[1][k];
设当前搜索到以x为根的子树,设son为x的一个儿子,则:
$f[x][j]=max(f[x][p]+f[son][j-p]+Wi),$
$j<=min(size[x],k),p<=min(size[son],k)$
Wi表示x与son所连边产生的贡献,在上面已经解释过了
更新f的时候,如果正着推,就需要把f存到一个临时数组里;循环完后再赋值给f,如果倒着枚举直接更新就好了
当然你还可以减少枚举的数量,即j<=min(size[x],k),p<=min(size[son],k)。
下面给出两份代码:
#include<iostream> #include<cstdio> #include<cstring> #define ll long long #define MAXN 2005 using namespace std; ll n,k; ll fr[MAXN<<1],to[MAXN<<1],nxt[MAXN<<1],pre[MAXN],cnt=0,dis[MAXN<<1]; void add(ll u,ll v,ll w){ cnt++,fr[cnt]=u,to[cnt]=v,nxt[cnt]=pre[u],pre[u]=cnt,dis[cnt]=w; } ll f[MAXN][MAXN],temp[MAXN],size[MAXN]; void dfs(ll x,ll fa){ size[x]=1; for(ll i=pre[x];i;i=nxt[i]){ ll y=to[i]; if(y==fa) continue; dfs(y,x); ll num_b1=min(k,size[x]),num_b2=min(k,size[y]),l; memset(temp,0,sizeof(temp)); for(ll j=0;j<=num_b1;j++){ for(ll p=0;p<=num_b2&&p+j<=k;p++){ l=dis[i]*((k-p)*p+(size[y]-p)*(n-k-size[y]+p)); temp[j+p]=max(temp[j+p],f[x][j]+f[y][p]+l); } } ll m=min(num_b1+num_b2,k); for(ll j=0;j<=m;j++) f[x][j]=temp[j]; size[x]+=size[y]; } } int main(){ scanf("%lld%lld",&n,&k); for(ll i=1,u,v,d;i<n;i++){ scanf("%lld%lld%lld",&u,&v,&d); add(u,v,d),add(v,u,d); } dfs(1,0); printf("%lld ",f[1][k]); return 0; }
#include <iostream> #include <cstdio> #include <cstring> #define ll long long #define MAXN 2005 using namespace std; ll n, k; ll to[MAXN << 1], nxt[MAXN << 1], pre[MAXN], cnt = 0, dis[MAXN << 1]; void add(ll u, ll v, ll w) { cnt++, to[cnt] = v, nxt[cnt] = pre[u], pre[u] = cnt, dis[cnt] = w; } ll f[MAXN][MAXN], temp[MAXN][MAXN], size[MAXN]; void dfs(ll x, ll fa) { size[x] = 1; for (ll i = pre[x]; i; i = nxt[i]) { ll y = to[i]; if (y == fa) continue; dfs(y, x); ll num_b1 = min(k, size[x]), l; for (ll j = num_b1; j >= 0; j--) { ll num_b2 = min(k - j, size[y]); for (ll p = num_b2; p >= 0; p--) { l = dis[i] * ((k - p) * p + (size[y] - p) * (n - k - size[y] + p)); f[x][j + p] = max(f[x][j + p], f[x][j] + f[y][p] + l); } } size[x] += size[y]; } } int main() { scanf("%lld%lld", &n, &k); for (ll i = 1, u, v, d; i < n; i++) { scanf("%lld%lld%lld", &u, &v, &d); add(u, v, d), add(v, u, d); } dfs(1, 0); printf("%lld ", f[1][k]); return 0; }