大概就是二分+树上差分...
题意:给你树上m条路径,你要把一条边权变为0,使最长的路径最短。
要求“最大值最小”，容易想到二分答案（事实上我并没有看出来……）。
然后二分答案 k：对所有长度大于 k 的路径做树上差分，找出被这些路径全部经过的公共边中边权最大的一条，再检查最长路径减去这条边的边权后是否 ≤ k。
(yy的解法②:边按照长度排序,然后二分。删除最长公共边。据logeadd juru说是三分)
代码量3.6k,180行,还是有点长的。
// Given a weighted tree with n vertices and m paths, set the weight of one
// edge to 0 so that the longest path is as short as possible; print that
// minimal longest-path length.
//
// Approach: binary-search the answer k. For the paths longer than k, use
// tree differencing to find the edges lying on *all* of them, and accept k
// if the longest path minus the heaviest such common edge is <= k.
#include <cstdio>
#include <algorithm>
#include <cstring>

const int N = 300010;
const int LOG = 20;  // 2^19 > N, so ancestor levels 0..19 suffice

// Reads a non-negative decimal integer from stdin.
// `c` is an int so that EOF is detectable (a char can never equal EOF,
// which made the original skip loop spin forever on truncated input).
inline void read(int &x) {
    int c = getchar();
    x = 0;
    while (c != EOF && (c > '9' || c < '0')) {
        c = getchar();
    }
    while (c <= '9' && c >= '0') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
}

// Adjacency list (each undirected tree edge is stored twice).
struct Edge {
    int v, nex, len;
} edge[N << 1];
int top;
int head[N];

// Per-vertex data.
int n, m, lm;
int fa[N][LOG];  // fa[x][i]: 2^i-th ancestor of x (0 above the root)
int d[N];        // depth, root = 0
int dist[N];     // weighted distance from the root
int pw[N];       // weight of the edge from x to its parent (0 for the root)
int order[N];    // BFS order from the root; parents precede children

// Per-path data.
int l[N], r[N], mid[N], len[N];  // endpoints, LCA, length

// Binary-search state shared with check().
int num;    // number of paths longer than the current k
int large;  // heaviest edge lying on all of those paths
int R;      // length of the longest path overall
int f[N];   // difference array / subtree sums for tree differencing

inline void add(int x, int y, int z) {
    edge[++top].v = y;
    edge[top].len = z;
    edge[top].nex = head[x];
    head[x] = top;
}

// Roots the tree at vertex 1 and fills fa/d/dist/pw/order. A BFS is used
// instead of recursion so that a chain of ~3e5 vertices cannot overflow the
// call stack. It then builds the binary-lifting table.
inline void getlca() {
    while ((1 << lm) < n) {
        lm++;
    }
    int qh = 0, qt = 0;
    order[qt++] = 1;
    fa[1][0] = 0;
    d[1] = 0;
    dist[1] = 0;
    pw[1] = 0;
    while (qh < qt) {
        int x = order[qh++];
        for (int i = head[x]; i; i = edge[i].nex) {
            int y = edge[i].v;
            if (y == fa[x][0]) {
                continue;
            }
            fa[y][0] = x;
            d[y] = d[x] + 1;
            dist[y] = dist[x] + edge[i].len;
            pw[y] = edge[i].len;
            order[qt++] = y;
        }
    }
    for (int i = 1; i <= lm; i++) {
        for (int x = 1; x <= n; x++) {
            fa[x][i] = fa[fa[x][i - 1]][i - 1];
        }
    }
}

// Lowest common ancestor via binary lifting.
// y is lifted by exactly d[y] - d[x] levels. Comparing against the depth of
// the target ancestor would let y jump to the sentinel vertex 0, which also
// has depth 0, whenever x is the root.
inline int lca(int x, int y) {
    if (d[x] > d[y]) {
        std::swap(x, y);
    }
    for (int t = lm; t >= 0; --t) {
        if (d[y] - (1 << t) >= d[x]) {
            y = fa[y][t];
        }
    }
    if (x == y) {
        return x;
    }
    for (int t = lm; t >= 0; --t) {
        if (fa[x][t] != fa[y][t]) {
            x = fa[x][t];
            y = fa[y][t];
        }
    }
    return fa[x][0];
}

// Returns true if, after zeroing one edge, every path can have length <= k.
inline bool check(int k) {
    num = 0;
    memset(f, 0, sizeof(int) * (n + 1));
    for (int i = 1; i <= m; i++) {
        if (len[i] > k) {
            ++num;
            // Tree differencing: after subtree summation, f[x] is the number
            // of long paths using the edge (x, parent(x)).
            f[l[i]]++;
            f[r[i]]++;
            f[mid[i]] -= 2;
        }
    }
    large = 0;
    // Reverse BFS order visits every vertex after all of its descendants,
    // so f[x] is complete when x is reached. Index 0 (the root) is skipped
    // because it has no parent edge.
    for (int i = n - 1; i >= 1; --i) {
        int x = order[i];
        if (f[x] == num) {
            large = std::max(large, pw[x]);
        }
        f[fa[x][0]] += f[x];
    }
    return R - large <= k;
}

// Heaviest single edge on path i (walks both endpoints up to the LCA).
inline int getlong(int i) {
    int ans = 0;
    for (int x = l[i]; x != mid[i]; x = fa[x][0]) {
        ans = std::max(ans, pw[x]);
    }
    for (int x = r[i]; x != mid[i]; x = fa[x][0]) {
        ans = std::max(ans, pw[x]);
    }
    return ans;
}

int main() {
    if (scanf("%d%d", &n, &m) != 2) {
        return 0;
    }
    for (int i = 1; i < n; i++) {
        int x, y, z;
        read(x);
        read(y);
        read(z);
        add(x, y, z);
        add(y, x, z);
    }
    getlca();

    int dr = 0, longest = 1;
    for (int i = 1; i <= m; i++) {
        read(l[i]);
        read(r[i]);
        mid[i] = lca(l[i], r[i]);
        len[i] = dist[l[i]] + dist[r[i]] - 2 * dist[mid[i]];
        if (len[i] > dr) {
            dr = len[i];
            longest = i;
        }
    }
    R = dr;

    // The answer is at least R minus the heaviest edge on the longest path
    // (zeroing an edge cannot shorten that path by more), and at most R.
    int dl = dr - getlong(longest);
    while (dl < dr) {
        int dm = dl + (dr - dl) / 2;
        if (check(dm)) {
            dr = dm;
        } else {
            dl = dm + 1;
        }
    }
    printf("%d", dr);
    return 0;
}