题意
有一个n
个点构成的树,每条边有一个边权d
。求最多有一个点度数超过k
的联通子图的边权和最大值。
分析
-
首先
k=0
时答案为0
-
dp[0][u]
代表以u
为根的子树中所有点度数都小于等于k
时的边权和最大值,且u
与它的父节点有连边。dp[1][u]
代表以u
为根的子树中存在一个点的度数大于k
时的边权和最大值,且u
与它的父节点有连边。 -
(dp[0][u]=max_{v_1,v_2,...,v_{k-1}}(sum_{i=v_1,v_2,...,v_{k-1}}dp[0][i])+d),其中
v
为u
的儿子节点,d
为u
的父亲节点到u
的边权值 -
(dp[1][u]=max(sum_vdp[0][v],max_{v_1,v_2,...,v_{k-1}}(sum_{i=v_1,v_2,...,v_{k-2}}dp[0][i]+dp[1][v_{k-1}]))+d)
-
(ans=max_u(dp[1][u],max_{v_1,v_2,...,v_{k}}(sum_{i=v_1,v_2,...,v_{k-1}}dp[0][i]+dp[1][v_k])))
-
其中(max_{v_1,v_2,...,v_{k}}(sum_{i=v_1,v_2,...,v_{k-1}}dp[0][i]+dp[1][v_k]))可以对
dp[0][v]
由大到小排序,然后对前k
个计算(sum_{i=1}^kdp[1][i]+dp[1][v]-dp[0][v]),对后cnt-k
个计算(sum_{i=1}^kdp[1][i]+dp[1][k]-dp[0][v]),复杂度为O(nlogn)
代码
#include<bits/stdc++.h>
using namespace std;
const int maxn=2e5+5;
typedef long long ll;
int n,k;
struct Node{int to,next;ll d;}edge[maxn*2];
int head[maxn],ecnt;
int cnt[maxn];
ll Ans,ans[2][maxn];
int son[maxn];
void init()
{
memset(head,-1,sizeof(head[0])*(n+5));
memset(cnt,0,sizeof(cnt[0])*(n+5));
ecnt=0;
Ans=0;
}
void addedge(int u,int v,ll d)
{
edge[ecnt]={v,head[u],d};
head[u]=ecnt++;
edge[ecnt]={u,head[v],d};
head[v]=ecnt++;
cnt[u]++;cnt[v]++;
}
void dfs(int u,int fa,ll d)
{
ans[1][u]=ans[0][u]=d;
vector<ll>v0;
for(int i=head[u];i!=-1;i=edge[i].next)
{
int v=edge[i].to;
if(v==fa)continue;
dfs(v,u,edge[i].d);
ans[1][u]+=ans[0][v];
v0.push_back(ans[0][v]);
}
sort(v0.begin(),v0.end(),greater<ll>());
for(int i=0;i<min((ll)v0.size(),(ll)k-1);i++)
ans[0][u]+=v0[i];
for(int i=head[u],j=1;i!=-1;i=edge[i].next)
if(edge[i].to!=fa)
son[j]=edge[i].to,j++;
sort(son+1,son+cnt[u]+1,[](int i,int j){
return ans[0][i]>ans[0][j];
});
//return
int nn=min(k-1,cnt[u]);
ll ans1=d;
if(k>=2)
{
for(int i=1;i<=nn;i++)
ans1+=ans[0][son[i]];
for(int i=1;i<=nn;i++)
ans[1][u]=max(ans[1][u],ans1-ans[0][son[i]]+ans[1][son[i]]);
ans1-=ans[0][son[nn]];
for(int i=nn+1;i<=cnt[u];i++)
ans[1][u]=max(ans[1][u],ans1+ans[1][son[i]]);
}
else
ans[1][u]=max(ans[1][u],ans1);
//dp
nn=min(k,cnt[u]);
ll ans2=0;
for(int i=1;i<=nn;i++)
ans2+=ans[0][son[i]];
for(int i=1;i<=nn;i++)
Ans=max(Ans,ans2-ans[0][son[i]]+ans[1][son[i]]);
ans2-=ans[0][son[nn]];
for(int i=nn+1;i<=cnt[u];i++)
Ans=max(Ans,ans2+ans[1][son[i]]);
Ans=max(Ans,ans[1][u]);
}
int main()
{
int t;
scanf("%d",&t);
while(t--)
{
int u,v;ll d;
scanf("%d%d",&n,&k);
init();
for(int i=1;i<=n-1;i++)
{
scanf("%d%d%lld",&u,&v,&d);
addedge(u,v,d);
}
if(k==0)
{
printf("0
");
continue;
}
for(int i=2;i<=n;i++)
cnt[i]--;
dfs(1,0,0);
printf("%lld
",Ans);
}
return 0;
}