有 (n) 个桥墩,高 (h_i) 重 (w_i)。连接 (i) 和 (j) 消耗代价 ((h_i-h_j)^2),用不到的桥墩被拆除,代价为 (w_i)。求使 (1) 与 (n) 联通的最小代价。
数据范围:(2le nle 10^5),(0le h_i,|w_i|le 10^6)。
非常经典的李超线段树维护 ( exttt{dp}) 的题目,小蒟蒻来分享一下。
很明显 (w_i) 是大片大片消耗的,所以记 (s_i=sum_{j=1}^i w_j)。
令 (f_i) 表示连接到第 (i) 个桥墩的最小代价。可以野蛮推式:
[egin{split}
f_i=&min{f_j+(h_i-h_j)^2+s_{i-1}-s_j}\
=&min{f_j+h_i^2-2h_ih_j+h_j^2+s_{i-1}-s_j}\
=&h_i^2+s_{i-1}+min{f_j-2h_ih_j+h_j^2-s_j}\
end{split}\
]
这貌似是个斜率优化式子,但蒟蒻不管,用李超线段树怎么做呢?
考虑李超线段树的作用:多条线段(直线),求单点最值。
发现这个 (j) 有很多,而 (i) 就只有当前一个:所以可以 (i) 对应单点,(j) 对应线。换句话说,可以把每个 (f_i) 求出来后添加一条直线。
[ extrm{let line } g(x)=(-2h_j)x+(f_j+h_j^2-s_j)\
extrm{let } x=h_i extrm{ to get }min{f_j-2h_ih_j+h_j^2-s_j} extrm{.}\
]
这题有几个坑,本来是应该由你来快乐地调试的,但是既然写了题解,蒟蒻就放出来了:
- 因为要计算 (h_j^2),所以要开 ( exttt{long long}) 或用 (1ll) 乘之。
- 这个李超线段树是权值线段树,下标要开 (10^6) 个,节点个数要开 (4cdot 10^6) 个。
这个做法貌似有点辜负了这题的难度,但是蒟蒻只会这么做。蒟蒻讲不清楚,还是放个蒻蒻的代码吧:
#include <bits/stdc++.h>
using namespace std;
//Start
#define lng long long
#define db double
#define mk make_pair
#define pb push_back
#define fi first
#define se second
#define rz resize
const int inf=0x3f3f3f3f;
const lng INF=0x3f3f3f3f3f3f3f3f;
//Data
const int N=1e5,M=1e6;
int n,h[N+7];
lng w[N+7],f[N+7];
//Lichaotree
typedef pair<lng,lng> line;
lng g(line&li,int x){return li.fi*x+li.se;}
int inter(line&la,line&lb){return db(lb.se-la.se)/(la.fi-lb.fi);}
line v[(M<<2)+7];
void add(line li,int k=1,int l=0,int r=M){
int mid((l+r)>>1);
lng ly1=g(li,l),ry1=g(li,r),ly=g(v[k],l),ry=g(v[k],r);
if(ly1>=ly&&ry1>=ry);
else if(ly1<=ly&&ry1<=ry) v[k]=li;
else {
int in=inter(li,v[k]);
if(ly1<=ly){
if(in<=mid) add(li,k<<1,l,mid);
else add(v[k],k<<1|1,mid+1,r),v[k]=li;
} else {
if(in>mid) add(li,k<<1|1,mid+1,r);
else add(v[k],k<<1,l,mid),v[k]=li;
}
}
}
lng get(int x,int k=1,int l=0,int r=M){
lng res(g(v[k],x));
if(l==r) return res;
int mid((l+r)>>1);
if(mid>=x) res=min(res,get(x,k<<1,l,mid));
else res=min(res,get(x,k<<1|1,mid+1,r));
return res;
}
//Main
int main(){
scanf("%d",&n);
for(int i=1;i<=n;i++) scanf("%d",&h[i]);
for(int i=1;i<=n;i++) scanf("%lld",&w[i]),w[i]+=w[i-1];
fill(v+1,v+(M<<2)+1,mk(0,INF));
f[1]=0,add(mk(-2ll*h[1],1ll*h[1]*h[1]-w[1]));
for(int i=2;i<=n;i++){
f[i]=1ll*h[i]*h[i]+w[i-1]+get(h[i]);
add(mk(-2ll*h[i],f[i]+1ll*h[i]*h[i]-w[i]));
}
printf("%lld
",f[n]);
return 0;
}
祝大家学习愉快!