问题描述
给定一棵n个点的树,以及m条路径,每次询问第L条到第R条路径的交集部分的长度(如果一条边同时出现在2条路径上,那么它属于路径的交集)。
输入格式
第一行一个数n(n<=500,000)
接下来n-1行,每行三个数x,y,z,表示一条从x到y并且长度为z的边
第n+1行一个数m(m<=500,000)
接下来m行,每行两个数u,v,表示一条从u到v的路径
接下来一行一个数Q,表示询问次数(Q<=500,000)
接下来Q行,每行两个数L和R
输出格式
Q行,每行一个数表示答案
样例输入
4
1 2 5
2 3 2
1 4 3
2
1 2
3 4
1
1 2
样例输出
5
解析
按照常规思想,这也许可以通过一些树上路径相关的数据结构来完成。但是实现起来有很大的困难。
看到询问的方式,与线段树区间询问的类型很相似。不妨也用类似的方法,用线段树维护路径的交集。具体维护起来可以通过维护交集路径的两个端点,加上一些分类讨论完成。为了保证复杂度是(O(nlogn))的,需要用ST表实现最近公共祖先。
代码
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#define N 500002
using namespace std;
const int inf=1<<30;
struct path{
int u,v,dis;
}a[N],t[N*4];
int head[N],ver[N*2],nxt[N*2],edge[N*2],l;
int n,m,q,i,st[N][30],dep[N],dis[N],in[N],s[N],top;
int read()
{
char c=getchar();
int w=0;
while(c<'0'||c>'9') c=getchar();
while(c<='9'&&c>='0'){
w=w*10+c-'0';
c=getchar();
}
return w;
}
void insert(int x,int y,int z)
{
l++;
ver[l]=y;
edge[l]=z;
nxt[l]=head[x];
head[x]=l;
}
void dfs(int x,int pre)
{
s[++top]=x;
dep[top]=dep[in[pre]]+1;
in[x]=top;
for(int i=head[x];i;i=nxt[i]){
int y=ver[i];
if(y!=pre){
dis[y]=dis[x]+edge[i];
dfs(y,x);
s[++top]=x,dep[top]=dep[in[x]];
}
}
}
int get(int x,int y)
{
if(dep[x]>dep[y]) return y;
return x;
}
void init()
{
dfs(1,0);
for(int i=1;i<=top;i++) st[i][0]=i;
for(int j=0;(1<<(j+1))<=top;j++){
for(int i=1;i+(1<<(j+1))-1<=top;i++) st[i][j+1]=get(st[i][j],st[i+(1<<j)][j]);
}
}
int LCA(int u,int v)
{
int l=in[u],r=in[v];
if(l>r) swap(l,r);
int k=log2(1.0*(r-l+1));
return s[get(st[l][k],st[r-(1<<k)+1][k])];
}
int dist(int u,int v)
{
return dis[u]+dis[v]-2*dis[LCA(u,v)];
}
int my_comp(const int &x,const int &y)
{
return dep[in[x]]>dep[in[y]];
}
path update(path x,path y)
{
if(x.dis==inf) return y;
if(y.dis==inf) return x;
if(x.u==y.u&&x.v==y.v) return x;
if(x.dis==0||y.dis==0) return (path){0,0,0};
int a[4];
a[0]=LCA(x.u,y.u);
a[1]=LCA(x.u,y.v);
a[2]=LCA(x.v,y.u);
a[3]=LCA(x.v,y.v);
sort(a,a+4,my_comp);
int p1=a[0],p2=a[1];
if(p1==p2&&((dep[in[p1]]<dep[in[LCA(x.u,x.v)]])||(dep[in[p2]]<dep[in[LCA(y.u,y.v)]]))) return (path){0,0,0};
return (path){p1,p2,dist(p1,p2)};
}
void build(int p,int l,int r)
{
if(l==r){
t[p].u=a[l].u,t[p].v=a[l].v;
t[p].dis=dist(t[p].u,t[p].v);
return;
}
int mid=(l+r)/2;
build(p*2,l,mid);
build(p*2+1,mid+1,r);
t[p]=update(t[p*2],t[p*2+1]);
}
path ask(int p,int l,int r,int ql,int qr)
{
if(ql<=l&&r<=qr) return t[p];
int mid=(l+r)/2;
path ansl=(path){0,0,inf},ansr=(path){0,0,inf};
if(ql<=mid) ansl=ask(p*2,l,mid,ql,qr);
if(qr>mid) ansr=ask(p*2+1,mid+1,r,ql,qr);
return update(ansl,ansr);
}
int main()
{
n=read();
for(i=1;i<n;i++){
int u=read(),v=read(),w=read();
insert(u,v,w);
insert(v,u,w);
}
init();
m=read();
for(i=1;i<=m;i++) a[i].u=read(),a[i].v=read();
build(1,1,m);
q=read();
for(i=1;i<=q;i++){
int l=read(),r=read();
path ans=ask(1,1,m,l,r);
printf("%d
",ans.dis);
}
return 0;
}