题意:
给你一棵n个结点的树,有m个运输计划,每个计划表示从一个点x到一个点y的路径长度,你可以将一条边的长度赋为0,问完成所有计划的最短时间。
题解:
再写一遍了;
要你扣掉一条边,直接扣掉再计算答案至少要(O(nm))的复杂度,再加上这题常数比较大,最多50分吧;
那么二分答案转化问题,二分完成所有任务的时间,若扣掉一条边后用时最长的任务的时间小于mid,则返回true;
对于那些大于mid的任务,考虑扣掉一条边,应为我们二分了完成所有任务的时间,所以扣掉的这条边一定是要被所有的任务路径经过,并且越大越好,于是就要用到树上差分找到这条边,一遍dfs即可;
复杂度(O((n+m)log1000))
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<cmath>
#define ll long long
#define N 300010
using namespace std;
int n,m,L,e_num,mx,mw,num,ans=1<<30;
int nxt[N<<1],to[N<<1],w[N<<1],h[N];
int fa[N],dep[N],top[N],siz[N],son[N],dist[N],cnt[N];
struct Node {int x,y,z;}task[N];
inline int gi() {
int x=0,o=1; char ch=getchar();
while(ch!='-' && (ch<'0' || ch>'9')) ch=getchar();
if(ch=='-') o=-1,ch=getchar();
while(ch>='0' && ch<='9') x=x*10+ch-'0',ch=getchar();
return o*x;
}
inline void add(int x, int y, int z) {
nxt[++e_num]=h[x],to[e_num]=y,w[e_num]=z,h[x]=e_num;
}
inline void dfs1(int u) {
siz[u]=1;
for(int i=h[u]; i; i=nxt[i]) {
int v=to[i];
if(v==fa[u]) continue;
fa[v]=u,dep[v]=dep[u]+1,dist[v]=dist[u]+w[i];
dfs1(v);
if(siz[v]>siz[son[u]]) son[u]=v;
siz[u]+=siz[v];
}
}
inline void dfs2(int u) {
if(son[u]) top[son[u]]=top[u],dfs2(son[u]);
for(int i=h[u]; i; i=nxt[i]) {
int v=to[i];
if(v==fa[u] || v==son[u]) continue;
top[v]=v,dfs2(v);
}
}
inline void dfs3(int u) {
for(int i=h[u]; i; i=nxt[i]) {
int v=to[i];
if(v==fa[u]) continue;
dfs3(v);
if(cnt[v]==num) mw=max(mw,w[i]);
cnt[u]+=cnt[v];
}
}
inline int lca(int x, int y) {
while(top[x]!=top[y]) {
if(dep[top[x]]>dep[top[y]]) x=fa[top[x]];
else y=fa[top[y]];
}
if(dep[x]<dep[y]) return x;
else return y;
}
inline bool check(int mid) {
mx=mw=num=0;
for(int i=1; i<=n; i++) cnt[i]=0;
for(int i=1; i<=m; i++) {
if(task[i].z>mid) {
cnt[task[i].x]++,cnt[task[i].y]++,cnt[lca(task[i].x,task[i].y)]-=2;
mx=max(mx,task[i].z),num++;
}
}
dfs3(1);
return mx-mw<=mid;
}
int main() {
n=gi(),m=gi();
for(int i=1; i<n; i++) {
int x=gi(),y=gi(),z=gi();
add(x,y,z),add(y,x,z);
}
fa[1]=1,dep[1]=1,top[1]=1;
dfs1(1),dfs2(1);
lca(4,5);
for(int i=1; i<=m; i++) {
int x=gi(),y=gi(),z=dist[x]+dist[y]-dist[lca(x,y)]*2;
task[i]=(Node){x,y,z};
L=max(L,z);
}
int r=L,l=r-1000,mid;
while(l<=r) {
mid=(l+r)>>1;
if(check(mid)) ans=min(ans,mid),r=mid-1;
else l=mid+1;
}
printf("%d", ans);
return 0;
}