准确的说应该叫树上分组背包?并不知道我写的这个叫啥
设计状态f[u][j]为在以点u为根的子树中有j个黑点,转移的时候另开一个数组,不能在原数组更新(因为会用到没更新时候的状态),方程式为g[j+k]=max(g[j+k],f[u][j]+f[e[i].to][k]+(k*(m-k)+(si[e[i].to]-k)*(n-m-(si[e[i].to]-k)))*e[i].va);,其中m'为题目描述中的k。
关于这个方程的由来,考虑一条边对答案的贡献,显然是这条边一边的黑点数量*另一边的黑点数量+一边的白点数量*另一边的白点数量,再乘上边权。
注意:size要在dp中加,dp完一棵子树再加上这棵子树的size。
#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
const int N=2005;
long long n,m,h[N],cnt,f[N][N],si[N],g[N];
struct qwe
{
long long ne,to,va;
}e[N<<1];
long long read()
{
long long r=0,f=1;
char p=getchar();
while(p>'9'||p<'0')
{
if(p=='-')
f=-1;
p=getchar();
}
while(p>='0'&&p<='9')
{
r=r*10+p-48;
p=getchar();
}
return r*f;
}
void add(long long u,long long v,long long w)
{
cnt++;
e[cnt].ne=h[u];
e[cnt].to=v;
e[cnt].va=w;
h[u]=cnt;
}
// void dfs(long long u,long long fa)
// {
// si[u]=1;
// for(long long i=h[u];i;i=e[i].ne)
// if(e[i].to!=fa)
// {
// dfs(e[i].to,u);
// si[u]+=si[e[i].to];
// }
// }
void dp(long long u,long long fa)
{
si[u]=1;
for(long long i=h[u];i;i=e[i].ne)
if(e[i].to!=fa)
{
dp(e[i].to,u);
memset(g,0,sizeof(g));
for(long long j=0;j<=min(m,si[u]);j++)
for(long long k=0;k<=min(m,si[e[i].to]);k++)
if(j+k<=m)
g[j+k]=max(g[j+k],f[u][j]+f[e[i].to][k]+(k*(m-k)+(si[e[i].to]-k)*(n-m-(si[e[i].to]-k)))*e[i].va);
for(long long j=0;j<=m;j++)
f[u][j]=g[j];
si[u]+=si[e[i].to];
}
}
int main()
{
n=read(),m=read();
for(long long i=1;i<n;i++)
{
long long x=read(),y=read(),z=read();
add(x,y,z);
add(y,x,z);
}
//dfs(1,0);
dp(1,0);
printf("%lld
",f[1][m]);
return 0;
}