好久没更新博客了,更一篇吧(qwq)
题目链接
思路
要求我们让路径的最大时间最小
这很二分答案
可以二分答案(mid),然后想办法(O(n))去检查答案是否合法
可以记录出路径长度大于(mid)的路径,尽量在这些路径的交集部分建造黑洞(显而易见),我们可以用边差分(diff[i])来记录这条边被几个大于(mid)的路径包含,假设一共有(qwq)个大于(mid)的路径,那么交集部分就是(diff[i]==qwq)的
/*
@ author:pyyyyyy
-----思路------
二分+树上差分
-----debug-------
*/
#include<bits/stdc++.h>
using namespace std;
const int N=300010;
struct node
{
int v,Next,w;
}e[N<<1];
int cnt,head[N];
void add(int u,int v,int w)
{
e[++cnt].Next=head[u];
head[u]=cnt;
e[cnt].v=v;
e[cnt].w=w;
}
int dis[N],val[N];
//-------------------------
int fa[N],son[N],size[N],top[N],dep[N];
void dfs1(int u,int Fa)
{
size[u]=1;fa[u]=Fa;
dep[u]=dep[Fa]+1;
for(int i=head[u];i;i=e[i].Next)
{
int to=e[i].v;
if(to==Fa) continue;
dis[to]=dis[u]+e[i].w;//应该放在这里
val[to]=e[i].w;
dfs1(to,u);
// dis[to]=dis[u]+e[i].w;
// val[to]=e[i].w;
size[u]+=size[to];
if(size[son[u]]<size[to]) son[u]=to;
}
}
void dfs2(int u,int Top)
{
top[u]=Top;
if(son[u]) dfs2(son[u],Top);
else return ;
for(int i=head[u];i;i=e[i].Next)
{
int to=e[i].v;
if(to!=fa[u]&&to!=son[u]) dfs2(to,to);
}
}
int lca(int x,int y)
{
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
x=fa[top[x]];
}
if(dep[x]<dep[y]) return x;
else return y;
}
//--------------------------
int n,m,l,r;
int ans,diff[N];
struct road
{
int u,v,lca,dis;
}p[N];
int cmp(road xx,road yy)
{
return xx.dis>yy.dis;
}
void dfs3(int u,int fa)
{
for(int i=head[u];i;i=e[i].Next){
int to=e[i].v;
if(to==fa) continue;
dfs3(to,u);
diff[u]+=diff[to];
}
}
int check(int x)
{
int js=0,maxn=0;
memset(diff,0,sizeof(diff));
for(int i=1;i<=m;++i)
{
if(p[i].dis<=x) break;
diff[p[i].u]++;
diff[p[i].v]++;
diff[p[i].lca]-=2;
js++;
}
dfs3(1,0);
for(int i=1;i<=n;++i)
{
if(diff[i]==js)
maxn=max(maxn,val[i]);
}
return p[1].dis-maxn<=x;
}
signed main()
{
// freopen(".in","r",stdin);
// freopen(".out","w",stdout);
cin>>n>>m;
for(int i=1;i<n;++i)
{
int u,v,w;
scanf("%d %d %d",&u,&v,&w);
l=max(l,w);
add(u,v,w);add(v,u,w);
}
dfs1(1,0);dfs2(1,1);
// for(int i=1;i<=m;++i) cout<<dis[i]<<' ';
for(int i=1;i<=m;++i)
{
int u,v;
scanf("%d %d",&u,&v);
p[i].u=u;p[i].v=v;
p[i].lca=lca(u,v);
// cout<<p[i].lca<<'
';
p[i].dis=dis[u]+dis[v]-(dis[p[i].lca]<<1);
r=max(r,p[i].dis);
}
// for(int i=1;i<=m;++i) cout<<p[i].dis<<' ';
sort(p+1,p+1+m,cmp);
l=r-l;
while(l<=r)
{
int mid=(l+r)>>1;
if(check(mid))
{
ans=mid;
r=mid-1;
}
else l=mid+1;
}
cout<<ans;
return 0;
}