又是一道看了题解的题
题意:给出一棵边权为 1 的树,要求一次走完所有的边并回到出发点,规定你可以选择连 k 条边权为 1 的边,并且新连的边必须且只能走一次,求最短路程
显然 k = 1 的时候答案为 (2n - 树的直径 - 1)
现在考虑 k = 2 的时候如何连第二条边
由于每条原来的边至少走一次,连边的策略需要尽可能保证所连接的两点及其 lca 之间的边经过次数 > 1,并选择最长的路径
但显然不是所有情况都有 满足条件的路径 能使得答案最优的
那接下来考虑如何在决策选点对时考虑上之前连的边的影响
以下来自题解
若新连出的环与之前连出的环不重叠,则答案会继续减小
若重叠,两环重叠的部分就不会被巡逻到,又因题目要求至少经过一次,我们不得不在合适的时候经过这条边并返回
最终的后果是两环重叠的部分由之前的“只需要经过一次”变成了“需要经过两次”
再想一下,对于原图,每条边需要经过两次,相当于这条边变回去了
我们这时把第一次求出的直径上的边边权取相反数
再求一遍直径,需要注意的是这一次一定要用 dp 求树的直径,因为有负边权了, bfs 第一次找最深的节点的性质不满足,会搞炸,可以手画样例理解一下
设两次求出的直径分别为 l1, l2
则答案为 (2(n - 1) - (l1 - 1) - (l2 - 1)) = 2n - l1 - l2
以下 copy 自题解
如果 直径l2 包含 直径l1 的部分
当减掉 (l1 - 1) 后,重叠的部分变为“只需要经过一次”,减掉 (l2 - 1) 后,重叠的部分就变回了“需要经过两次”,这正是我们想要的
代码:
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<cstdio>
#include<queue>
using namespace std;
const int MAXN = 100005;
struct EDGE{
int nxt, to, val;
EDGE(int NXT = 0, int TO = 0, int VAL = 0) {nxt = NXT; to = TO; val = VAL;}
}edge[MAXN << 1];
int n, k, totedge = 1, mxpt, p, q, l1, l2 = 0xcfcfcfcf;
int head[MAXN], dst[MAXN], fa[MAXN], frm[MAXN];
bool vis[MAXN];
inline void add(int x, int y, int v) {
edge[++totedge] = EDGE(head[x], y, v);
head[x] = totedge;
return;
}
void dfs(int x) {
vis[x] = true;
for(int i = head[x]; i; i = edge[i].nxt) if(!vis[edge[i].to]) {
int y = edge[i].to;
dfs(y);
l2 = max(l2, dst[x] + dst[y] + edge[i].val);
dst[x] = max(dst[x], dst[y] + edge[i].val);
}
return;
}
inline void bfs(int bgn) {
for(int i = 1; i <= n; ++i) dst[i] = 0x3f3f3f3f, fa[i] = 0, frm[i] = 0;
queue<int> que;
dst[bgn] = 0;
frm[bgn] = 0;
que.push(bgn);
while(!que.empty()) {
int x = que.front(); que.pop();
if(dst[x] > dst[mxpt]) mxpt = x;
for(int i = head[x]; i; i = edge[i].nxt) if(dst[edge[i].to] == 0x3f3f3f3f) {
int y = edge[i].to;
dst[y] = dst[x] + edge[i].val;
frm[y] = i;
fa[y] = x;
que.push(y);
}
}
return;
}
int main() {
scanf("%d%d", &n, &k);
register int xx, yy;
for(int i = 1; i < n; ++i) {
scanf("%d%d", &xx, &yy);
add(xx, yy, 1); add(yy, xx, 1);
}
dst[0] = 0xcfcfcfcf;
bfs(1); p = mxpt; mxpt = 0;
bfs(p); q = mxpt;
while(mxpt != p) {
++l1;
edge[frm[mxpt]].val = edge[frm[mxpt] ^ 1].val = -1;
mxpt = fa[mxpt];
}
if(k == 1) {
printf("%d
", (n << 1) - l1 - 1);
return 0;
}
if(k == 2) {
for(int i = 1; i <= n; ++i) dst[i] = 0;
dfs(1);
if(l2 < 0) l2 = 0;
printf("%d
", (n << 1) - l1 - l2);
}
return 0;
}