题目描述
题目
需要的最短时间,明显二分
判断某个答案 k 是否可行:把所有长度超过 k 的路径都记下来,找一条被这些路径全部经过的边,尝试删掉它(即把它的边权变为 0);如果最长路径的长度减去这条边的边权后不超过 k,那么 k 就是可行解
至于统计被所有超长路径都经过的边,用树上差分即可:每条路径在两个端点处各 +1、在 LCA 处 -2,再自底向上累加子树和,就得到每个点与其父亲之间的那条边被经过的次数
经过running的折磨,感觉transport突然变简单了
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<algorithm>
using namespace std;
// Capacity of every per-vertex / per-query array in this file.
const int N = 300010;

// n: number of vertices, m: number of query paths (both filled in by init()).
int n, m;

// Query i connects a[i] and b[i]; lca[i] receives their lowest common ancestor.
int a[N], b[N], lca[N];

// Tree adjacency lists: he[u] is the index of the first edge record of u,
// ne is the number of edge records allocated so far.
int he[N], ne;

// Query adjacency lists, same layout as he/ne but for query records.
int hq[N], nq;

// Vertex used as the root of the tree traversals.
int rt;
// One direction of an undirected tree edge, kept in a singly linked
// adjacency list threaded through the e[] array.
struct E
{
    int to;    // vertex this record points at
    int next;  // index of the next record in the same list (-1 ends the list)
    int w;     // weight of the edge
};
E e[N << 1];   // two records are stored per undirected edge
// Inserts the undirected edge (u, v) with weight w: one record is pushed to
// the front of u's list and one to the front of v's list.
void build (int u, int v, int w)
{
    // record u -> v
    e[ne].to = v;
    e[ne].next = he[u];
    e[ne].w = w;
    he[u] = ne;
    ++ne;

    // record v -> u
    e[ne].to = u;
    e[ne].next = he[v];
    e[ne].w = w;
    he[v] = ne;
    ++ne;
}
// One direction of an LCA query, kept in a linked list per endpoint.
struct Q
{
    int to;    // the other endpoint of the query
    int next;  // index of the next query record in the same list (-1 ends the list)
    int flag;  // non-zero once the query has been answered
    int idx;   // number of the query this record belongs to
};
Q q[N << 1];   // two records are stored per query
// Registers query number m between u and v. The two records are stored at
// consecutive indices, so record i and record i^1 describe the same query.
void add(int u, int v, int m)
{
    // record attached to u
    q[nq].to = v;
    q[nq].next = hq[u];
    q[nq].flag = 0;
    q[nq].idx = m;
    hq[u] = nq;
    ++nq;

    // record attached to v
    q[nq].to = u;
    q[nq].next = hq[v];
    q[nq].flag = 0;
    q[nq].idx = m;
    hq[v] = nq;
    ++nq;
}
int f[N];    // disjoint-set parent pointers used by find()
int vis[N];  // vis[u] != 0 once tarjan() has entered u
int dep[N];  // depth assigned by tarjan(): parent's depth plus one
int dis[N];  // sum of edge weights from the traversal root down to the vertex
int pre[N];  // weight of the edge joining the vertex to its parent
// Returns the representative of v's set in f[], compressing the path on the way.
int find(int v)
{
    if (f[v] == v)
        return v;
    f[v] = find(f[v]);
    return f[v];
}
// Depth-first traversal from u (entered from parent fa) that
//   * fills dep[], dis[] and pre[] for every vertex below u, and
//   * answers the registered LCA queries offline: after a child subtree is
//     finished it is merged into u in the disjoint set f[], so for an already
//     visited endpoint v the representative find(v) is the LCA of u and v.
// NOTE(review): the recursion depth equals the height of the tree.
void tarjan (int u, int fa)
{
    vis[u] = 1;
    dep[u] = dep[fa] + 1;
    f[u] = u;

    for (int i = he[u]; i != -1; i = e[i].next)
    {
        int v = e[i].to;
        if (v == fa) continue;          // do not walk back up the tree
        dis[v] = dis[u] + e[i].w;
        pre[v] = e[i].w;
        tarjan(v, u);
        f[v] = u;                       // merge the finished subtree into u
    }

    for (int i = hq[u]; i != -1; i = q[i].next)
    {
        int v = q[i].to;
        if (!vis[v] || q[i].flag) continue;   // other end not seen yet, or already answered
        q[i].flag = q[i ^ 1].flag = 1;        // mark both records of this query
        lca[q[i].idx] = find(v);
    }
}
int len[N];   // len[i]: length of query path i
int mark[N];  // difference counters rebuilt by every call to check()
int maxm;     // largest value found in len[]
// Post-order accumulation: after the call, mark[u] holds the sum of the
// initial mark[] values over the whole subtree of u (fa is u's parent).
void pushup(int u, int fa)
{
    for (int idx = he[u]; idx != -1; idx = e[idx].next)
    {
        const int child = e[idx].to;
        if (child != fa)
        {
            pushup(child, u);
            mark[u] += mark[child];
        }
    }
}
// Feasibility test for the binary search in solve().
// Returns 1 when there is a vertex whose parent edge lies on every path
// longer than k and whose removal brings the longest path down to at most k;
// returns 0 otherwise.
int check(int k)
{
    memset(mark, 0, sizeof(mark));

    // Edge difference: +1 at both endpoints, -2 at the LCA, for each long path.
    int longPaths = 0;
    for (int p = 1; p <= m; ++p)
    {
        if (len[p] <= k) continue;
        ++longPaths;
        ++mark[a[p]];
        ++mark[b[p]];
        mark[lca[p]] -= 2;
    }

    // Subtree sums turn the differences into per-edge usage counts.
    pushup(rt, 0);

    for (int v = 1; v <= n; ++v)
    {
        if (mark[v] != longPaths) continue;
        if (maxm - pre[v] <= k) return 1;
    }
    return 0;
}
// Computes every query path length, then binary-searches the smallest k for
// which check(k) succeeds and prints it followed by a newline.
void solve()
{
    tarjan(rt, 0);

    int l = 0, r = 0;
    for (int i = 1; i <= m; i++)
    {
        // distance(a, b) = dis[a] + dis[b] - 2 * dis[lca(a, b)]
        len[i] = dis[a[i]] + dis[b[i]] - (dis[lca[i]] << 1);
        if (len[i] > r) maxm = r = len[i];
    }

    // Start from the upper bound so ans is never read uninitialised.
    int ans = r;
    while (l <= r)
    {
        int mid = (l + r) >> 1;
        if (check(mid)) r = mid - 1, ans = mid;
        else l = mid + 1;
    }
    printf("%d\n", ans);
}
// Reads the next non-negative decimal integer from standard input, skipping
// any leading non-digit characters. Returns 0 if end of input is reached
// before a digit is found.
int read(){
    // getchar() returns int; keeping it as int lets EOF be told apart from data.
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) c = getchar();

    int out = 0;
    while (c >= '0' && c <= '9')
    {
        out = out * 10 + (c - '0');
        c = getchar();
    }
    return out;
}
int siz[N];    // siz[u]: number of vertices in the subtree of u as seen by dfs()
int mind = N;  // smallest imbalance score seen so far by dfs()
// Computes subtree sizes and picks the traversal root rt: the non-leaf vertex
// whose surrounding components (child subtrees plus the part of the tree above
// it) differ least in size, i.e. the one minimising (largest - smallest).
// rt is only used as the starting vertex of later traversals.
void dfs(int u, int fa)
{
    siz[u] = 1;
    int minn = N, maxn = -N;
    for (int i = he[u]; i != -1; i = e[i].next)
    {
        int v = e[i].to;
        if (v == fa) continue;
        dfs(v, u);
        siz[u] += siz[v];
        if (siz[v] < minn) minn = siz[v];
        // The largest child subtree must count as well; without this the
        // score below only ever looked at the component above u.
        if (siz[v] > maxn) maxn = siz[v];
    }
    if (minn == N) return;            // no children: leaves are never chosen

    int above = n - siz[u];           // size of the component on the parent side
    if (fa && above < minn) minn = above;   // the start vertex has nothing above it
    if (above > maxn) maxn = above;

    if (maxn - minn < mind)
    {
        rt = u;
        mind = maxn - minn;
    }
}
void init()
{
memset(he,-1,sizeof(he));memset(hq,-1,sizeof(hq));
n=read(),m=read();int u,v,w;
for(int i=1; i < n; i++) u=read(),v=read(),w=read(),build(u,v,w);
for(int i=1; i <= m; i++) a[i]=read(),b[i]=read(),add(a[i],b[i],i);
dfs(1,0);
}
// Program entry point: load the instance, then compute and print the answer.
int main()
{
    init();
    solve();
    return 0;
}