题目:
题解:
首先我们考虑$K = 1$的情况,由于原图是一棵树, 所以每一条边都要走两次, 即总距离为$2 * ( n - 1)$
如果我们添加一条边, 就会形成一个环, 并且由于必须经过添加的边一次, 所以组成这个环的边都只经过一次。
所以添边后的总距离为$2 * (n - 1) - $环长$ + 2$。
我们只需找出最长的链即树的直径$L$, 并把边添到两端点上, 答案就是 $2 * (n - 1) - L + 1$。
树的直径可以通过bfs 或 dp 求出。
接着考虑$K = 2$的情况
我们将直径及直径的两端点求出, 把答案减去 $(L_1 - 1)$, 并把直径上的边权都改为-1.
然后再求一遍直径$L_2$, 并把答案减去$(L_2 - 1)$ 。
这样两个环重叠的部分就会被第二次加上去,即经过了两次,而环不重合的部分仍然经过一次。
注意点: 第一次求直径需要用bfs或dfs求出端点和长度, 第二次由于有负权, 用bfs或dfs处理会错误,需要用dp做
更新, 发现自己的flag好像倒了??
发现一个博客可以用两次dfs求出树的直径与端点:万能的传送门
代码
1 #include<cstdio> 2 #include<cstring> 3 #include<algorithm> 4 #include<queue> 5 #define rd read() 6 #define rep(i,a,b) for(int i = (a); i <= (b); ++i) 7 #define per(i,a,b) for(int i = (a); i >= (b); --i) 8 #define cl(a) memset(a, 0, sizeof(a)) 9 using namespace std; 10 11 const int N = 2e5; 12 const int inf = -2139062144; 13 14 int dis[N], disp[N], vis[N]; 15 int head[N], tot, ans, d[N]; 16 int n, K, maxn; 17 18 queue<int> q; 19 20 struct edge { 21 int nxt, to, val; 22 }e[N << 1]; 23 24 int read() { 25 int X = 0, p = 1; char c = getchar(); 26 for(; c > '9' || c < '0'; c = getchar()) if(c == '-') p = -1; 27 for(; c >= '0' && c <= '9'; c = getchar()) X = X * 10 + c - '0'; 28 return X * p; 29 } 30 31 void added(int u, int v) { 32 e[++tot].to = v; 33 e[tot].nxt = head[u]; 34 e[tot].val = 1; 35 head[u] = tot; 36 } 37 38 void add(int u, int v) { 39 added(u, v); added(v, u); 40 } 41 42 int ch(int x) { 43 return ((x + 1) ^ 1) - 1; 44 } 45 46 void bfs(int x) { 47 cl(vis); 48 vis[x] = 1; 49 dis[x] = 0; 50 q.push(x); 51 for(int u; !q.empty(); ) { 52 u = q.front(); q.pop(); 53 for(int i = head[u]; i; i = e[i].nxt) { 54 int nt = e[i].to; 55 if(vis[nt]) continue; 56 dis[nt] = e[i].val + dis[u]; 57 vis[nt] = 1; 58 q.push(nt); 59 } 60 } 61 } 62 63 int dfs(int u, int T) { 64 vis[u] = 1; 65 if(u == T) return 1; 66 for(int i = head[u]; i; i = e[i].nxt) { 67 int nt = e[i].to; 68 if(vis[nt]) continue; 69 if(dfs(nt, T)) { 70 e[i].val = -e[i].val; 71 e[ch(i)].val = -e[ch(i)].val; 72 return 1; 73 } 74 } 75 return 0; 76 } 77 78 void dp(int u, int fa) { 79 for(int i = head[u]; i; i = e[i].nxt) { 80 int nt = e[i].to; 81 if(nt == fa) continue; 82 dp(nt, u); 83 maxn = max(maxn, d[u] + d[nt] + e[i].val); 84 d[u] = max(d[u], d[nt] + e[i].val); 85 } 86 if(d[u] == inf) d[u] = 0; 87 } 88 89 int main() 90 { 91 n = rd; K = rd; 92 rep(i, 1, n - 1) { 93 int u = rd, v = rd; 94 add(u, v); 95 } 96 ans = 2 * (n - 1); 97 dis[0] = -1e8; 98 bfs(1); 99 int pos = 0, pos2 = 0; 100 rep(i, 1, n) if(dis[i] >= dis[pos]) pos = i; 101 bfs(pos); 102 rep(i, 1, n) if(dis[i] >= dis[pos2]) pos2 = i; 103 ans -= dis[pos2] - 1; 104 cl(vis); dfs(pos, pos2); 105 if(K == 1) return printf("%d ", ans), 0; 106 memset(d, 128, sizeof(d)); dp(1, 0); 107 ans -= maxn - 1; 108 printf("%d ", ans); 109 }