题意
题目链接:https://ac.nowcoder.com/acm/contest/6885/F
分析
代码
#include <bits/stdc++.h>
#define pb push_back
using namespace std;
typedef long long ll;
typedef pair<int,int>pii;
const int N=3e5+5;
const int mak=25;
vector<pii>pic[N];
int depth[N],f[N<<1][mak],rnk[N];
ll dis[N];//到根的距离
pii st[N][mak];
int n,cnt;
//用倍增求lca会多一个log
void read(int &x)
{
x=0;
int f=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-') f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=10*x+ch-'0';
ch=getchar();
}
x*=f;
}
void dfs(int u,int p,int d)//必须要用dfs序
{
depth[u]=d;
rnk[u]=++cnt;
f[cnt][0]=u;
for(int i=0;i<pic[u].size();i++)
{
int v=pic[u][i].second;
if(v==p) continue;
dis[v]=dis[u]+pic[u][i].first;
dfs(v,u,d+1);
f[++cnt][0]=u;
}
}
void init()
{
dfs(1,0,0);
int len=(int)log2(1.0*cnt);
for(int k=1;k<=len;k++)
{
for(int i=1;i+(1<<k)-1<=cnt;i++)
{
int a=f[i][k-1];
int b=f[i+(1<<(k-1))][k-1];
if(depth[a]<depth[b]) f[i][k]=a;
else f[i][k]=b;
}
}
}
int RMQ(int l,int r)//l,r为dfs序
{
int len=(int)log2(1.0*(r-l+1));
int a=f[l][len],b=f[r-(1<<len)+1][len];
if(depth[a]<depth[b]) return a;
else return b;
}
int lca(int u,int v)
{
int l=rnk[u],r=rnk[v];
if(l>r) swap(l,r);
return RMQ(l,r);
}
ll get_dis(int u,int v)
{
int t=lca(u,v);
return dis[u]+dis[v]-2*dis[t];
}
void solve()
{
for(int i=1;i<=n;i++)//点的编号
st[i][0]=make_pair(i,i);
int len=(int)log2(1.0*n);
for(int k=1;k<=len;k++)
{
for(int i=1;i+(1<<k)-1<=n;i++)
{
int a=st[i][k-1].first,b=st[i][k-1].second;
int c=st[i+(1<<(k-1))][k-1].first,d=st[i+(1<<(k-1))][k-1].second;
ll d1=get_dis(a,b);
ll d2=get_dis(c,d);
ll d3=get_dis(a,c);
ll d4=get_dis(a,d);
ll d5=get_dis(b,c);
ll d6=get_dis(b,d);
ll maxn=max(max(d1,d2),max(d3,d4));
maxn=max(max(d5,d6),maxn);
if(maxn==d1)
st[i][k]=make_pair(a,b);
else if(maxn==d2)
st[i][k]=make_pair(c,d);
else if(maxn==d3)
st[i][k]=make_pair(a,c);
else if(maxn==d4)
st[i][k]=make_pair(a,d);
else if(maxn==d5)
st[i][k]=make_pair(b,c);
else st[i][k]=make_pair(b,d);
}
}
}
ll cal(int l,int r)
{
int len=(int)log2(1.0*(r-l+1));
pii a=st[l][len],b=st[r-(1<<len)+1][len];
ll maxn=0;
maxn=max(maxn,get_dis(a.first,b.first));
maxn=max(maxn,get_dis(a.first,a.second));
maxn=max(maxn,get_dis(a.first,b.second));
maxn=max(maxn,get_dis(b.first,b.second));
maxn=max(maxn,get_dis(a.second,b.first));
maxn=max(maxn,get_dis(a.second,b.second));
return maxn;
}
int main()
{
int q;
read(n),read(q);
int u,v,w,l,r;
for(int i=1;i<n;i++)
{
read(u),read(v),read(w);
pic[u].pb(make_pair(w,v));
pic[v].pb(make_pair(w,u));
}
init();
solve();
while(q--)
{
read(l),read(r);
printf("%lld
",cal(l,r));
}
return 0;
}