题目
题目链接:https://www.luogu.com.cn/problem/P4383
小 L 最近沉迷于塞尔达传说:荒野之息(The Legend of Zelda: Breath of The Wild)无法自拔,他尤其喜欢游戏中的迷你挑战。
游戏中有一个叫做 LCT 的挑战,它的规则是这样子的:现在有一个 (N) 个点的树,每条边有一个整数边权 (v_i),若 (v_i geq 0),表示走这条边会获得 (v_i) 的收益;若 (v_i lt 0) ,则表示走这条边需要支付 (-v_i) 的过路费。小 L 需要控制主角 Link 切掉(Cut)树上的恰好 (K) 条边,然后再连接 (K) 条边权为 0 的边,得到一棵新的树。接着,他会选择树上的两个点 (p,q),并沿着树上连接这两点的简单路径从 (p) 走到 (q),并为经过的每条边支付过路费/ 获取相应收益。
海拉鲁大陆之神 TemporaryDO 想考验一下 Link。他告诉 Link,如果 Link 能切掉合适的边、选择合适的路径从而使 总收益 - 总过路费 最大化的话,就把传说中的大师之剑送给他。
小 L 想得到大师之剑,于是他找到了你来帮忙,请你告诉他,Link 能得到的 总收益 - 总过路费 最大是多少。
(n,kleq 3 imes 10^5),(|v_i|leq 10^6)。
思路
问题等价于在树上选择 (k) 条不交路径,使得它们的权值之和最大。
设 (f[x][i][0/1/2]) 表示点 (x) 的子树内,选择了 (i) 条路径,其中点 (x) 相连的边选择了 (0/1/2) 条的最大权值和。
转移的话,就考虑新加入 (x) 的一个儿子 (y),讨论一下这条边选不选就行了。时间复杂度 (O(nk^2))。
这样显然过不了。题解愉快的告诉我们,随着 (k) 不断增大,(f) 是一个凸函数,直接上 WQS 二分。二元组 (f[x][0/1/2]) 表示点 (x) 的子树内,选择若干条路径,其中点 (x) 相连的边选择了 (0/1/2) 条的最大权值和,以及此时选择的路径最多是多少条。
至于为什么 (f) 是一个凸函数我也不清楚,反正出题人说是就是吧(什
时间复杂度 (O(nlog (nV)))。(V) 是值域。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=300010;
const ll Inf=1e18;
int n,m,tot,head[N];
ll ans;
struct edge
{
int next,to,dis;
}e[N*2];
struct node
{
ll a,b;
friend node operator +(node x,node y)
{
return (node){x.a+y.a,x.b+y.b};
}
}f[N][3];
node max(node x,node y)
{
if (x.a>y.a) return x;
if (x.a<y.a) return y;
return (node){x.a,max(x.b,y.b)};
}
void add(int from,int to,int dis)
{
e[++tot]=(edge){head[from],to,dis};
head[from]=tot;
}
void dfs(int x,int fa,ll mid)
{
f[x][0]=(node){0,0}; f[x][1]=(node){-Inf,0}; f[x][2]=(node){-mid,1};
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=fa)
{
dfs(v,x,mid);
f[x][2]=max(f[x][2],f[x][2]+max(f[v][0],max(f[v][1],f[v][2])));
f[x][2]=max(f[x][2],f[x][1]+max(f[v][0]+(node){e[i].dis,0},f[v][1]+(node){e[i].dis+mid,-1}));
f[x][1]=max(f[x][1],f[x][1]+max(f[v][0],max(f[v][1],f[v][2])));
f[x][1]=max(f[x][1],f[x][0]+max(f[v][1]+(node){e[i].dis,0},f[v][0]+(node){e[i].dis-mid,1}));
f[x][0]=max(f[x][0],f[x][0]+max(f[v][0],max(f[v][1],f[v][2])));
}
}
}
int main()
{
memset(head,-1,sizeof(head));
scanf("%d%d",&n,&m);
m++;
for (int i=1,x,y,z;i<n;i++)
{
scanf("%d%d%d",&x,&y,&z);
add(x,y,z); add(y,x,z);
}
ll l=-3e11,r=3e11,mid;
while (l<=r)
{
mid=(l+r)>>1;
dfs(1,0,mid);
node res=max(f[1][0],max(f[1][1],f[1][2]));
if (res.b>=m) l=mid+1,ans=res.a+res.b*mid;
else r=mid-1;
}
cout<<ans;
return 0;
}