题目让我们维护这么一个东西:
y
i
+
q
j
x
i
+
p
j
dfrac{y_i+q_j}{x_i+p_j}
xi+pjyi+qj 的最大值。
容易想到分数规划,二分枚举答案 m i d mid mid,则有: y i + q j x i + p j = m i d dfrac{y_i+q_j}{x_i+p_j}=mid xi+pjyi+qj=mid
化简: y i + q j = m i d × ( x i + p j ) y_i+q_j=mid imes(x_i+p_j) yi+qj=mid×(xi+pj)
移项得: ( y i − m i d × x i ) + ( q j − m i d × p j ) = 0 (y_i-mid imes x_i)+(q_j-mid imes p_j)=0 (yi−mid×xi)+(qj−mid×pj)=0
现在 i i i 和 j j j 已经没有关系了,所以我们只需要分别找到 y i − m i d × x i y_i-mid imes x_i yi−mid×xi 和 q j − m i d × p j q_j-mid imes p_j qj−mid×pj 的最大值,然后看加起来是否大于等于 0 0 0 就好了。
但是这个东西怎么维护呢?我们就以 y i − m i d × x i y_i-mid imes x_i yi−mid×xi 为例。发现唯一在变的东西就是 m i d mid mid,然后 x i x_i xi、 y i y_i yi 是题目已经给出的。不难想到用斜率优化来维护这个东西。
也就是说,设函数 f ( a ) = − x i × a + y i f(a)=-x_i imes a+y_i f(a)=−xi×a+yi,把 a a a 当为自变量,那么每一个树上的点都对应着一条直线。那么我们要维护的就是在询问点 u u u、 v v v 间的路径上的所有点对应的函数中,当 a = m i d a=mid a=mid 时的最大值。
比如说:
(注:由于所有直线的斜率都是
−
x
i
-x_i
−xi,所以这些直线都是单调下降的)
假如当前的 m i d = 2 mid=2 mid=2,然后在 u → v u ightarrow v u→v 这条路径上,所有点所对应的直线如图所示。那么对于这堆直线,当 a = 2 a=2 a=2 时,显然取 A A A 点时最大。
考虑怎么用树剖维护这个东西。
我们发现,对于一些直线,我们只用维护它们形成的最高的那一条折线就好了。文字不好解释,看图:
对于上面的那一幅图,这一幅图中的蓝线就是我所说的最高的那条折线。
看到这副图就能联想到用半平面交维护,因为蓝色折线所圈出来的就是这些直线的半平面交。
那么做法就显而易见了:
我们对线段树的每一个节点建一个 vector<Line>line
,存储构成这条蓝色折线的所有线段。并且满足这些线段从左到右排序。当然,也易证这些线段也满足按斜率排序的要求(等会要用)。
然后考虑合并。设线段树的当前节点为
u
u
u,左右儿子分别为
l
c
lc
lc、
r
c
rc
rc。我们要做的就是把 line[lc]
和 line[rc]
扔进一个数组里面,然后做半平面交,再更新 line[u]
。
但是半平面交的时间复杂度是 O ( n log n ) O(nlog n) O(nlogn) 的,显然不符合要求。
关键是在于一开始对直线的斜率(极角)排序,只要能把排序的过程缩为 O ( n ) O(n) O(n),半平面交的时间复杂度也会降至 O ( n ) O(n) O(n) 了。
考虑到 line[lc]
、line[rc]
里面的线段都是排好序的,所以如果用归并排序的话,就能将时间复杂度降到
O
(
n
)
O(n)
O(n)。
当然,对于线段树的每一层来说,无论有多少个节点,它们都是一共维护了 n n n 条直线,所以可以看成一层若干个点都做一次半平面交的总时间复杂度是 O ( n ) O(n) O(n) 的。
那么建树的过程就是 O ( n log n ) O(nlog n) O(nlogn) 的了。
对于询问,我们先把 u → v u ightarrow v u→v 路径上的点所对应的线段树上的若干个节点用树剖找出来,然后找出当 a = m i d a=mid a=mid 时,每个线段树节点所对应的折线的值(用二分找,具体实现看代码),然后再取最大值,就是 y i − m i d × x i y_i-mid imes x_i yi−mid×xi 的最大值了。
分数规划 O ( log n ) O(log n) O(logn),树剖 O ( log n ) O(log n) O(logn),每次线段树询问 O ( log 2 n ) O(log^2 n) O(log2n)(因为线段树询问中有一个二分),所以询问的总时间复杂度为 O ( m log 4 n ) O(mlog^4 n) O(mlog4n)。
总时间复杂度是 O ( n log 4 n ) O(n log^4 n) O(nlog4n)( n n n、 m m m同级),因为 4 4 4 个 log log log 都跑不满,所以能过。
代码如下:
#include<bits/stdc++.h>
#define N 100010
#define eps 1e-5
#define lc (k<<1)
#define rc (k<<1|1)
using namespace std;
int compare(double a,double b)
{
if(fabs(a-b)<eps) return 0;
return a<b?-1:1;
}
int n,m,num;
int cnt,head[N],to[N<<1],nxt[N<<1];
int tot,size[N],fa[N],dep[N],son[N],top[N],id[N],rk[N];
double x[2][N],y[2][N];
struct Point
{
double x,y;
Point(){};
Point(double a,double b){x=a,y=b;}
};
Point operator + (Point a,Point b){return Point(a.x+b.x,a.y+b.y);}
Point operator - (Point a,Point b){return Point(a.x-b.x,a.y-b.y);}
Point operator * (Point a,double b){return Point(a.x*b,a.y*b);}
double operator * (Point a,Point b){return a.x*b.y-b.x*a.y;}
double degree(Point a)
{
return atan2(a.y,a.x);
}
struct Line
{
Point a,b;
double k,l;//k斜率,l截距
Line(){};
Line(Point aa,Point bb)
{
a=aa,b=bb;
k=(a.y-b.y)/(a.x-b.x);
l=a.y-k*a.x;
}
double get(double x)
{
return k*x+l;
}
}tmp[N],q[N],e[N];
Point intersection(Line a,Line b)//两直线求交
{
Point u=a.a-b.a;
Point v=a.b-a.a;
Point w=b.b-b.a;
double t=(w*u)/(v*w);
return a.a+v*t;
}
bool bigger(Line a,Line b)
{
if(!compare(a.k,b.k)) return compare(a.l,b.l)!=1;
return a.k<b.k;
}
bool onright(Point a,Line b)
{
return compare((b.b-b.a)*(a-b.a),0)>0?0:1;
}
struct Segment_Tree
{
vector<Line>line[N<<1];
void half(int k)//半平面交
{
int cnt1=0;
for(int i=1;i<num;i++)
{
if(!compare(tmp[i+1].k,tmp[i].k)) continue;
tmp[++cnt1]=tmp[i];
}
tmp[++cnt1]=tmp[num];
num=cnt1;
int head=0;
q[++head]=tmp[1],q[++head]=tmp[2];
for(int i=3;i<=num;i++)
{
while(head&&onright(intersection(q[head],q[head-1]),tmp[i])) head--;
q[++head]=tmp[i];
}
if(head==1)
{
line[k].push_back(q[1]);
return;
}
//下面这部分是一个小细节:由于mid的取值范围是[0,1e9],所以当交点横坐标小于0或大于1e9时,我们就不要它。
Point lp,rp;
rp=intersection(q[1],q[2]);
if(rp.x>=1e9) line[k].push_back(Line(Point(0,q[1].get(0)),Point(1e9,q[1].get(1e9))));
else if(rp.x>=0) line[k].push_back(Line(Point(0,q[1].get(0)),rp));
for(int i=2;i<=head-1;i++)
{
lp=intersection(q[i-1],q[i]),rp=intersection(q[i],q[i+1]);
if(lp.x>=0)
{
if(rp.x<=1e9) line[k].push_back(Line(lp,rp));
else if(lp.x<=1e9) line[k].push_back(Line(lp,Point(1e9,q[i].get(1e9))));
}
else if(rp.x>=0)
{
if(rp.x<=1e9) line[k].push_back(Line(Point(0,q[i].get(0)),rp));
else line[k].push_back(Line(Point(0,q[i].get(0)),Point(1e9,q[i].get(1e9))));
}
}
lp=intersection(q[head-1],q[head]);
if(lp.x<=0) line[k].push_back(Line(Point(0,q[head].get(0)),Point(1e9,q[head].get(1e9))));
else if(lp.x<=1e9) line[k].push_back(Line(lp,Point(1e9,q[head].get(1e9))));
}
void build(int k,int l,int r,bool flag)//建树
{
if(l==r)
{
int u=rk[l];
line[k].push_back(Line(Point(0,y[flag][u]),Point(1,-x[flag][u]+y[flag][u])));//随便取直线上的两个点当做线段的起点和终点(主要是我不会写只用记录直线的斜率和截距的半平面交)
return;
}
int mid=(l+r)>>1;
build(lc,l,mid,flag);
build(rc,mid+1,r,flag);
num=0;
//归并排序:
int lsize=line[lc].size(),rsize=line[rc].size(),i=0,j=0;
while(i<lsize&&j<rsize)
{
if(bigger(line[lc][i],line[rc][j])) tmp[++num]=line[lc][i],i++;
else tmp[++num]=line[rc][j],j++;
}
while(i<lsize) tmp[++num]=line[lc][i],i++;
while(j<rsize) tmp[++num]=line[rc][j],j++;
half(k);
}
double query(int k,int L,int R,int ql,int qr,double val)//询问
{
if(ql<=L&&R<=qr)
{
int l=0,r=line[k].size()-1,ans;
if(!r) return line[k][0].get(val);
while(l<=r)//二分找出mid(这里的val)在折线上的哪条线段上(当然你也可以用lower_bound)
{
int mid=(l+r)>>1;
if((compare(line[k][mid].a.x,val)!=1)&&(compare(val,line[k][mid].b.x)!=1))
{
ans=mid;
break;
}
if(compare(val,line[k][mid].a.x)==-1) r=mid-1;
else l=mid+1;
}
return line[k][ans].get(val);
}
int mid=(L+R)>>1;
double ans=-1e11;
if(ql<=mid) ans=max(ans,query(lc,L,mid,ql,qr,val));
if(qr>mid) ans=max(ans,query(rc,mid+1,R,ql,qr,val));
return ans;
}
}t1,t2;
void adde(int u,int v)
{
to[++cnt]=v;
nxt[cnt]=head[u];
head[u]=cnt;
}
void dfs(int u)
{
size[u]=1;
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
if(v!=fa[u])
{
fa[v]=u;
dep[v]=dep[u]+1;
dfs(v);
size[u]+=size[v];
if(size[v]>size[son[u]]) son[u]=v;
}
}
}
void dfs1(int u,int tp)
{
top[u]=tp;
id[u]=++tot;
rk[tot]=u;
if(son[u]) dfs1(son[u],tp);
for(int i=head[u];i;i=nxt[i])
if(to[i]!=fa[u]&&to[i]!=son[u])
dfs1(to[i],to[i]);
}
bool check(double mid,int a,int b)
{
double max1=-1e11,max2=-1e11;
while(top[a]!=top[b])
{
if(dep[top[a]]<dep[top[b]]) swap(a,b);
max1=max(max1,t1.query(1,1,n,id[top[a]],id[a],mid));
max2=max(max2,t2.query(1,1,n,id[top[a]],id[a],mid));
a=fa[top[a]];
}
if(dep[a]>dep[b]) swap(a,b);
max1=max(max1,t1.query(1,1,n,id[a],id[b],mid));
max2=max(max2,t2.query(1,1,n,id[a],id[b],mid));
return compare(max1+max2,0)!=-1;//当max1+max2>=0时,返回true
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++) scanf("%lf",&x[0][i]);
for(int i=1;i<=n;i++) scanf("%lf",&y[0][i]);
for(int i=1;i<=n;i++) scanf("%lf",&x[1][i]);
for(int i=1;i<=n;i++) scanf("%lf",&y[1][i]);
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
adde(u,v),adde(v,u);
}
dfs(1),dfs1(1,1);
t1.build(1,1,n,0);
t2.build(1,1,n,1);
scanf("%d",&m);
while(m--)
{
int a,b;
scanf("%d%d",&a,&b);
double l=0,r=1e10,ans;
while(l<=r)//分数规划,二分
{
double mid=(l+r)/2.0;
if(check(mid,a,b)) l=mid+eps,ans=mid;
else r=mid-eps;
}
printf("%.5lf
",ans);
}
return 0;
}
代码挺长,但是思路很清晰,所以也不是太难调