[题目链接]
https://www.lydsy.com/JudgeOnline/problem.php?id=4326
[算法]
首先,此题的答案是具有单调性的,因此可以二分答案mid
检验答案时,我们判断每条路径的长度是否大于mid,若大于mid,则说明至少要将这条路径上的一条边变为“虫洞”
因此,我们可以对所有长度大于mid的路径做树上差分,若一条边差分后的值 = 大于mid的路径总数,那么判断最长路径 - 这条边的长度 <= mid
时间复杂度 : O(N log LEN)( LEN为每条边的长度之和 )
[代码]
此份代码在BZOJ上通过了所有测试点,但在UOJ上由于常数原因只能拿到95分 , 此题需要一些常数优化
#include<bits/stdc++.h> using namespace std; #define MAXN 300010 #define MAXLOG 18 struct edge { int to,w,nxt; } e[MAXN << 1]; int i,n,m,tot,l,r,mid,ans,cnt,maxlen,num; int u[MAXN],v[MAXN],a[MAXN],b[MAXN],t[MAXN],head[MAXN],depth[MAXN],sum[MAXN],s[MAXN],len[MAXN],dis[MAXN]; int anc[MAXN][MAXLOG]; namespace IO { template <typename T> inline void read(T &x) { int f = 1; x = 0; char c = getchar(); for (; !isdigit(c); c = getchar()) { if (c == '-') f = -f; } for (; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + c - '0'; x *= f; } template <typename T> inline void write(T x) { if (x < 0) { putchar('-'); x = -x; } if (x > 9) write(x / 10); putchar(x % 10 + '0'); } template <typename T> inline void writeln(T x) { write(x); puts(""); } } ; inline void addedge(int u,int v,int w) { tot++; e[tot] = (edge){v,w,head[u]}; head[u] = tot; } inline void dfs(int u) { int i,v,w; for (i = 1; i < MAXLOG; i++) { if (depth[u] < (1 << i)) break; anc[u][i] = anc[anc[u][i - 1]][i - 1]; } for (i = head[u]; i; i = e[i].nxt) { v = e[i].to; w = e[i].w; if (v != anc[u][0]) { len[v] = w; anc[v][0] = u; depth[v] = depth[u] + 1; sum[v] = sum[u] + w; dfs(v); } } } inline void calc(int u) { int i,v; for (i = head[u]; i; i = e[i].nxt) { v = e[i].to; if (v != anc[u][0]) { calc(v); s[u] += s[v]; } } if (s[u] == num && len[u] > maxlen) maxlen = len[u]; } inline int lca(int x,int y) { int i,t; if (depth[x] > depth[y]) swap(x,y); t = depth[y] - depth[x]; for (i = 0; i < MAXLOG; i++) { if (t & (1 << i)) y = anc[y][i]; } if (x == y) return x; for (i = MAXLOG - 1; i >= 0; i--) { if (anc[x][i] != anc[y][i]) { x = anc[x][i]; y = anc[y][i]; } } return anc[x][0]; } inline int dist(int x,int y) { return sum[x] + sum[y] - 2 * sum[lca(x,y)]; } inline bool check(int mid) { int i,j,p,q,w,mx = 0; num = 0; for (i = 1; i <= n; i++) s[i] = 0; for (i = 1; i <= m; i++) { if (dis[i] > mid) { mx = max(mx,dis[i]); num++; s[u[i]]++; s[v[i]]++; s[lca(u[i],v[i])] -= 2; } } if (cnt == 0) return true; maxlen = 0; calc(1); return mx - maxlen <= mid; } int main() { IO :: read(n); IO :: read(m); for (i = 1; i < n; i++) { IO :: read(a[i]); IO :: read(b[i]); IO :: read(t[i]); addedge(a[i],b[i],t[i]); addedge(b[i],a[i],t[i]); cnt += t[i]; } dfs(1); for (i = 1; i <= m; i++) { IO :: read(u[i]); IO :: read(v[i]); dis[i] = dist(u[i],v[i]); } l = 0; r = cnt; while (l <= r) { mid = (l + r) >> 1; if (check(mid)) { r = mid - 1; ans = mid; } else l = mid + 1; } IO :: writeln(ans); return 0; }