斜率优化训练记录
前言
斜率优化一般用于优化dp的转移,借着训练斜率优化的相关问题来提升一些DP思维。选择老学长留下的专题场来练手,由于该场题数较多,以及个人不太愿意长时间进行单一专题训练,因此开此文来记录断续的训练结果和心得。
记录
0x01
由一道简单入门题玩具装箱开头,题意和思路比较简单就不讲了。
代码
#include<bits/stdc++.h>
#define dd(x) cout<<#x<<" = "<<x<<" "
#define de(x) cout<<#x<<" = "<<x<<endl
#define sz(x) int(x.size())
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define All(x) x.begin(),x.end()
using namespace std;
typedef long long ll;
typedef pair<int,int> P;
typedef priority_queue<int> BQ;
typedef priority_queue<int,vector<int>,greater<int> > SQ;
const int maxn=1e5+10,INF=0x3f3f3f3f,mod=1e9+7;
ll f[maxn],sum[maxn],a[maxn],q[maxn],h,t;
inline double K(int i,int j)
{
double dy=a[j]*a[j]+f[j]-a[i]*a[i]-f[i],dx=a[j]-a[i];
return dy/dx;
}
int main()
{
int n,l;
cin>>n>>l;
for (int i=1;i<=n;++i)
{
int x;
scanf("%d",&x);
sum[i]=sum[i-1]+x;
a[i]=i+1+sum[i];
}
a[0]=1;
h=t=1;
for (int i=1;i<=n;++i)
{
while (h<t&&K(q[h],q[h+1])<2*(sum[i]+i-l))
h++;
int j=q[h];
ll tmp=i-j-1+sum[i]-sum[j]-l;
f[i]=f[j]+tmp*tmp;
while (h<t&&K(q[t-1],q[t])>K(q[t-1],i))
t--;
q[++t]=i;
}
cout<<f[n];
return 0;
}
0x02
小A与最大字段和还是入门题。题意见原题面,比较简单。思路的话,维护一个普通前缀和与一个梯形前缀和,然后与上题一样通过变形式子写成直线截距式。唯一的不同是,该题Ai可能是负数,因此直线的斜率不能保证单调变化,因此选取最优点时需要二分队列找到首个往后比直线斜率小的点。(上一题由于直线斜率单调增加,因此每次选最优点只要把队首比直线斜率小的点都出队即可)。然后这题要最大值,所以入队时要维护一个上凸壳。
代码
#include<bits/stdc++.h>
#define dd(x) cout<<#x<<" = "<<x<<" "
#define de(x) cout<<#x<<" = "<<x<<endl
#define sz(x) int(x.size())
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define All(x) x.begin(),x.end()
using namespace std;
typedef long long ll;
typedef pair<int,int> P;
typedef priority_queue<int> BQ;
typedef priority_queue<int,vector<int>,greater<int> > SQ;
const int maxn=2e5+10,INF=0x3f3f3f3f,mod=1e9+7;
ll s1[maxn],s2[maxn],q[maxn],h,t;
inline double K(int i,int j)
{
double dy=j*s1[j]-s2[j]-i*s1[i]+s2[i],dx=j-i;
return dy/dx;
}
int find(double k)
{
int l=h,r=t,ans;
while (l<=r)
{
int m=(l+r)>>1;
if (m==t)
{
ans=t;
break;
}
if (K(q[m],q[m+1])<=k)
ans=m,r=m-1;
else
l=m+1;
}
return ans;
}
int main()
{
int n;
cin>>n;
for (int i=1;i<=n;++i)
{
int x;
scanf("%d",&x);
s1[i]=s1[i-1]+x;
s2[i]=s2[i-1]+i*x;
}
ll ans=-1e18;
h=t=1;
for (int i=1;i<=n;++i)
{
int j=q[find(s1[i])];
ans=max(ans,s2[i]-s2[j]-j*(s1[i]-s1[j]));
while (h<t&&K(i,q[t-1])>K(q[t],q[t-1]))
--t;
q[++t]=i;
}
cout<<ans;
return 0;
}
0x03
HDU2993这题思路不是很难,但是不知道为什么卡读入,相当恶心,比较low的IO优化还过不去,我T了十几发之后用了学长的fread读入的板子才过的。因此极其不建议大家做,脑内AC就可以了。
题意是给个长为n的数列,找一个长度不小于k且平均值最大的子段。换句话说就是找所有点(i,sum[i])中斜率最大的两个点的斜率。
不难想到,我们应该维护一个下凸壳(因为其上方的点肯定无法与之后的点构成更优的解),然后一般可以二分找最优点,但是由这题的性质可以发现,因为sum[i]是递增的,因此每次找到更优的点时,其前面的点可以直接舍弃(它们不可能与后面新加入的点构成更大的斜率了)。因此复杂度可以做到O(N)。
代码(再次强调,不建议做)
#include<bits/stdc++.h>
#define dd(x) cout<<#x<<" = "<<x<<" "
#define de(x) cout<<#x<<" = "<<x<<endl
#define sz(x) int(x.size())
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define All(x) x.begin(),x.end()
using namespace std;
typedef long long ll;
typedef pair<int,int> P;
typedef priority_queue<int> BQ;
typedef priority_queue<int,vector<int>,greater<int> > SQ;
const int maxn=1e5+10,INF=0x3f3f3f3f,mod=1e9+7;
struct FastIO {
static const int S = 1310720;
int wpos;
char wbuf[S];
FastIO() : wpos(0) { }
inline int xchar() {
static char buf[S];
static int len = 0, pos = 0;
if (pos == len) pos = 0, len = fread(buf, 1, S, stdin);
if (pos == len) return -1;
return buf[pos++];
}
inline int xint() {
int c = xchar(), x = 0, s = 1;
if (c==-1)
return -1;
while (c <= 32) c = xchar();
if (c == '-') s = -1, c = xchar();
for (; '0' <= c && c <= '9'; c = xchar()) x = x * 10 + c - '0';
return x * s;
}
~FastIO() {
if (wpos) fwrite(wbuf, 1, wpos, stdout), wpos = 0;
}
} io;
ll sum[maxn];
int q[maxn],h,t;
inline double K(int i,int j)
{
return double(sum[i]-sum[j])/(i-j);
}
int main()
{
int n,k;
while (n=io.xint(),k=io.xint())
{
if (n==-1)
break;
for (int i=1;i<=n;++i)
{
int x=io.xint();
sum[i]=sum[i-1]+x;
}
h=0;
t=-1;
double ans=-1;
for (int i=k;i<=n;++i)
{
while (h<t&&K(q[t],q[t-1])>K(i-k,q[t-1]))
--t;
q[++t]=i-k;
while (h<t&&K(q[h],i)<K(q[h+1],i))
++h;
ans=max(ans,K(q[h],i));
}
printf("%.2lf
",ans);
}
return 0;
}
0x04
HDU 3045,注意一下转移的合法性,其余的就是常规的斜率优化DP。特别注意斜率的比较上不要写错(手贱做差时取了绝对值,WA了十几发),为了避免精度误差,建议写成乘法形式来比较。
代码
#include<bits/stdc++.h>
#define dd(x) cout<<#x<<" = "<<x<<" "
#define de(x) cout<<#x<<" = "<<x<<"
"
#define sz(x) int(x.size())
#define All(x) x.begin(),x.end()
#define pb push_back
#define mp make_pair
#define fi first
#define se second
using namespace std;
typedef long long ll;
typedef long double ld;
typedef pair<int,int> P;
typedef priority_queue<int> BQ;
typedef priority_queue<int,vector<int>,greater<int> > SQ;
const int maxn=5e5+10,mod=1e9+7,INF=0x3f3f3f3f;
ll a[maxn],sum[maxn],f[maxn];
int q[maxn],h,t;
ll dy(int i,int j)
{
return f[i]-sum[i]+i*a[i+1]-f[j]+sum[j]-j*a[j+1];
}
ll dx(int i,int j)
{
return a[i+1]-a[j+1];
}
ll getf(int i,int j)
{
return f[i]+sum[j]-sum[i]-a[i+1]*(j-i);
}
int main()
{
int n,k;
while (scanf("%d%d",&n,&k)!=EOF)
{
for (int i=1;i<=n;++i)
scanf("%I64d",&a[i]);
sort(a+1,a+1+n);
for (int i=1;i<=n;++i)
sum[i]=sum[i-1]+a[i];
h=t=0;
for (int i=k;i<=n;++i)
{
int j=i-k;
if (j>=k)
{
while (h<t&&dy(j,q[t-1])*dx(q[t],q[t-1])<=dy(q[t],q[t-1])*dx(j,q[t-1]))
--t;
q[++t]=j;
}
while (h<t&&dy(q[h+1],q[h])<=i*dx(q[h+1],q[h]))
++h;
f[i]=getf(q[h],i);
}
printf("%I64d
",f[n]);
}
return 0;
}
0x05
POJ 1180,挺好的一题。题意比较繁琐,建议直接看原题面。
这题用斜率优化的部分也很常规,比较有技巧的是如何把枚举分组数的这一维优化掉,使得复杂度降到O(N)。可以发现,要知道当前这个组的结束时间,与之前分了几个组有关,这样的话转移时就必须枚举了。但是换个角度,我们可以考虑每个分组对后面分组代价的影响,提前计算对当前组造成的s对全局答案的贡献,这样转移时就不需要考虑前面分组对当前的影响了(因为已经在前面计算过了)。因此转移方程可以写成f(j)=min{f(i)+(tsum[j]+s)*(fsum[j]-fsum[i])+s*(fsum[n]-fsum[j])}(0<=i<j),有了这个转移式,剩下的就是通过斜率优化,把i的枚举优化掉,达到线性复杂度。
#include<iostream>
#include<cstdio>
#include<algorithm>
#define dd(x) cout<<#x<<" = "<<x<<" "
#define de(x) cout<<#x<<" = "<<x<<"
"
#define sz(x) int(x.size())
#define All(x) x.begin(),x.end()
#define pb push_back
#define mp make_pair
#define fi first
#define se second
using namespace std;
typedef long long ll;
typedef long double ld;
typedef pair<int,int> P;
const int maxn=1e5+10,mod=1e9+7,INF=0x3f3f3f3f;
ll fsum[maxn],tsum[maxn],dp[maxn];
int q[maxn],h,t,s,n;
inline ll dy(int i,int j)
{
return dp[i]-dp[j];
}
inline ll dx(int i,int j)
{
return fsum[i]-fsum[j];
}
inline ll getv(int i,int j)
{
return dp[i]+(tsum[j]+s)*(fsum[j]-fsum[i])+s*(fsum[n]-fsum[j]);
}
int main()
{
while (cin>>n)
{
cin>>s;
for (int i=1;i<=n;++i)
{
scanf("%lld%lld",&tsum[i],&fsum [i]);
tsum[i]+=tsum[i-1],fsum[i]+=fsum[i-1];
}
t=h=0;
for (int j=1;j<=n;++j)
{
while (h<t&&dy(q[h+1],q[h])<=(tsum[j]+s)*dx(q[h+1],q[h]))
++h;
dp[j]=getv(q[h],j);
while (h<t&&dy(j,q[t-1])*dx(q[t],q[t-1])<=dy(q[t],q[t-1])*dx(j,q[t-1]))
--t;
q[++t]=j;
}
cout<<dp[n]<<endl;
}
return 0;
}
0x06
HDU 3480,相比于第一题玩具装箱,多了分组数的限制,其他是一模一样的。出于空间的限制,我们先枚举分组数,这样每次队列就可以清空复用。注意分组数为k时的答案都是由分组数k-1时的答案转移而来,注意边界点,细节见代码
#include<bits/stdc++.h>
#define dd(x) cout<<#x<<" = "<<x<<" "
#define de(x) cout<<#x<<" = "<<x<<"
"
#define sz(x) int(x.size())
#define All(x) x.begin(),x.end()
#define pb push_back
#define mp make_pair
#define fi first
#define se second
using namespace std;
typedef long long ll;
typedef long double ld;
typedef pair<int,int> P;
typedef priority_queue<int> BQ;
typedef priority_queue<int,vector<int>,greater<int> > SQ;
const int maxn=1e4+10,mod=1e9+7,INF=0x3f3f3f3f;
ll a[maxn],f[maxn>>1][maxn];
int q[maxn],h,t;
inline ll dy(int i,int j,int k)
{
return f[k][i]+a[i+1]*a[i+1]-f[k][j]-a[j+1]*a[j+1];
}
inline ll dx(int i,int j)
{
return a[i+1]-a[j+1];
}
inline ll getf(int i,int j,int k)
{
return f[k-1][i]+(a[j]-a[i+1])*(a[j]-a[i+1]);
}
int main()
{
int T;
cin>>T;
for (int cas=1;cas<=T;++cas)
{
int n,m;
scanf("%d%d",&n,&m);
for (int i=1;i<=n;++i)
scanf("%lld",&a[i]);
sort(a+1,a+1+n);
for (int j=1;j<=n;++j)
f[1][j]=(a[j]-a[1])*(a[j]-a[1]);
for (int k=2;k<=m;++k)
{
h=t=0;
q[0]=k-1;
for (int j=k;j<=n;++j)
{
while (h<t&&dy(q[h+1],q[h],k-1)<=2*a[j]*dx(q[h+1],q[h]))
++h;
f[k][j]=getf(q[h],j,k);
while (h<t&&dy(j,q[t-1],k-1)*dx(q[t],q[t-1])<=dy(q[t],q[t-1],k-1)*dx(j,q[t-1]))
--t;
q[++t]=j;
}
}
printf("Case %d: %lld
",cas,m>n?0ll:f[m][n]);
}
return 0;
}
0x07
HDU 2829,与上题一个类型,这题定义一个区间的价值为其中任意两个数乘积的和。在区间价值的表达式上想了比较久,区间[i,j]的价值可以表示为C[j]-C[i]-sum[i]*(sum[j]-sum[i]),C[i]表示前缀i的价值,想出来之后就是跟上题一样的处理,没什么坑点,注意边界就行。
代码
#include<bits/stdc++.h>
#define dd(x) cout<<#x<<" = "<<x<<" "
#define de(x) cout<<#x<<" = "<<x<<"
"
#define sz(x) int(x.size())
#define All(x) x.begin(),x.end()
#define pb push_back
#define mp make_pair
#define fi first
#define se second
using namespace std;
typedef long long ll;
typedef long double ld;
typedef pair<int,int> P;
typedef priority_queue<int> BQ;
typedef priority_queue<int,vector<int>,greater<int> > SQ;
const int maxn=1e3+10,mod=1e9+7,INF=0x3f3f3f3f;
ll f[maxn][maxn],a[maxn],sum[maxn],c[maxn];
int q[maxn],h,t;
inline ll dy(int i,int j,int k)
{
return f[k][i]+sum[i]*sum[i]-c[i]-f[k][j]-sum[j]*sum[j]+c[j];
}
inline ll dx(int i,int j)
{
return sum[i]-sum[j];
}
inline ll getf(int i,int j,int k)
{
return f[k-1][i]+c[j]-c[i]-(sum[j]-sum[i])*sum[i];
}
int main()
{
int n,m;
while (scanf("%d%d",&n,&m)&&n)
{
++m;
for (int i=1;i<=n;++i)
{
scanf("%lld",&a[i]);
sum[i]=sum[i-1]+a[i];
}
for (int i=2;i<=n;++i)
c[i]=c[i-1]+sum[i-1]*a[i];
for (int j=1;j<=n;++j)
f[1][j]=c[j];
for (int k=2;k<=m;++k)
{
h=t=0;
q[0]=k-1;
for (int j=k;j<=n;++j)
{
while (h<t&&dy(q[h+1],q[h],k-1)<=sum[j]*dx(q[h+1],q[h]))
++h;
f[k][j]=getf(q[h],j,k);
while (h<t&&dy(j,q[t-1],k-1)*dx(q[t],q[t-1])<=dy(q[t],q[t-1],k-1)*dx(j,q[t-1]))
--t;
q[++t]=j;
}
}
printf("%lld
",f[m][n]);
}
return 0;
}
0x08
HDU 3507,简化版的玩具装箱,没有长度限制也没有分组数限制,入门级别
代码
#include<bits/stdc++.h>
#define dd(x) cout<<#x<<" = "<<x<<" "
#define de(x) cout<<#x<<" = "<<x<<"
"
#define sz(x) int(x.size())
#define All(x) x.begin(),x.end()
#define pb push_back
#define mp make_pair
#define fi first
#define se second
using namespace std;
typedef long long ll;
typedef long double ld;
typedef pair<int,int> P;
typedef priority_queue<int> BQ;
typedef priority_queue<int,vector<int>,greater<int> > SQ;
const int maxn=5e5+10,mod=1e9+7,INF=0x3f3f3f3f;
ll f[maxn],a[maxn],sum[maxn];
int q[maxn],h,t,M;
inline ll dy(int i,int j)
{
return f[i]+sum[i]*sum[i]-f[j]-sum[j]*sum[j];
}
inline ll dx(int i,int j)
{
return sum[i]-sum[j];
}
inline ll getf(int i,int j)
{
return f[i]+(sum[j]-sum[i])*(sum[j]-sum[i])+M;
}
int main()
{
int n;
while (scanf("%d%d",&n,&M)!=EOF)
{
for (int i=1;i<=n;++i)
{
scanf("%lld",&a[i]);
sum[i]=sum[i-1]+a[i];
}
h=t=0;
for (int i=1;i<=n;++i)
{
while (h<t&&dy(q[h+1],q[h])<=2*sum[i]*dx(q[h+1],q[h]))
++h;
f[i]=getf(q[h],i);
while (h<t&&dy(i,q[t-1])*dx(q[t],q[t-1])<=dy(q[t],q[t-1])*dx(i,q[t-1]))
--t;
q[++t]=i;
}
printf("%lld
",f[n]);
}
return 0;
}
0x09
POJ 3709,有长度限制,无分组数限制,常规解法
#include<iostream>
#include<cstdio>
#include<algorithm>
#define dd(x) cout<<#x<<" = "<<x<<" "
#define de(x) cout<<#x<<" = "<<x<<"
"
#define sz(x) int(x.size())
#define All(x) x.begin(),x.end()
#define pb push_back
#define mp make_pair
#define fi first
#define se second
using namespace std;
typedef long long ll;
typedef long double ld;
typedef pair<int,int> P;
const int maxn=5e5+10,mod=1e9+7,INF=0x3f3f3f3f;
ll f[maxn],a[maxn],sum[maxn];
int q[maxn],h,t,M;
inline ll dy(int i,int j)
{
return f[i]-sum[i]+i*a[i+1]-f[j]+sum[j]-j*a[j+1];
}
inline ll dx(int i,int j)
{
return a[i+1]-a[j+1];
}
inline ll getf(int i,int j)
{
return f[i]+sum[j]-sum[i]-(j-i)*a[i+1];
}
int main()
{
int T;
cin>>T;
while (T--)
{
int n,k;
scanf("%d%d",&n,&k);
for (int i=1;i<=n;++i)
{
scanf("%lld",&a[i]);
sum[i]=sum[i-1]+a[i];
}
h=t=0;
for (int i=k;i<=n;++i)
{
int j=i-k;
if (j>=k)
{
while (h<t&&dy(j,q[t-1])*dx(q[t],q[t-1])<=dy(q[t],q[t-1])*dx(j,q[t-1]))
--t;
q[++t]=j;
}
while (h<t&&dy(q[h+1],q[h])<=i*dx(q[h+1],q[h]))
++h;
f[i]=getf(q[h],i);
}
printf("%lld
",f[n]);
}
return 0;
}
0x0A
HDU 3669,每个人有高h和宽w,人可以过门仅当高和宽均不大于门的高H和宽W。要造不超过k道门,每道门确定一个W和H,每道门的造价是W*H,使得所有人都能通过至少一道门,求最少造价。为了方便计算,我们进行一定的预处理,我们按人的其中一个属性升序排列,得到的序列再按另一个属性降序排列(即剔除那些h和w均小于某人的人,显然他们不会影响答案),这时得到的序列,满足第一维递增,第二维递减。显然进同一道门的人在序列里肯定是连续的。跑一个DP来进行分组,斜率优化一下转移就做完了。
预处理细节见代码
#include<bits/stdc++.h>
#define fi first
#define se second
#define dd(x) cout<<#x<<" = "<<x<<" "
#define de(x) cout<<#x<<" = "<<x<<endl
using namespace std;
typedef long long ll;
typedef pair<ll, ll> P;
const int maxn=5e4+10;
P a[maxn];
int q[maxn],h,t;
ll f[200][maxn];
inline ll dy(int i,int j,int k)
{
return f[k][i]-f[k][j];
}
inline ll dx(int i,int j)
{
return -a[i+1].se+a[j+1].se;
}
inline ll getf(int i,int j,int k)
{
return f[k-1][i]+a[i+1].se*a[j].fi;
}
int main()
{
int n,m;
while (scanf("%d%d",&n,&m)!=EOF)
{
for (int i=1;i<=n;++i)
scanf("%lld%lld",&a[i].fi,&a[i].se);
sort(a+1,a+1+n);
int cnt=0;
for (int i=1;i<=n;++i)
{
while (cnt&&a[i].se>=a[cnt].se)
--cnt;
a[++cnt]=a[i];
}
m=min(m,cnt);
for (int i=1;i<=cnt;++i)
f[1][i]=a[i].fi*a[1].se;
for (int k=2;k<=m;++k)
{
h=t=0;
q[0]=k-1;
for (int j=k;j<=cnt;++j)
{
while (h<t&&dy(q[h+1],q[h],k-1)<=a[j].fi*dx(q[h+1],q[h]))
++h;
f[k][j]=getf(q[h],j,k);
while (h<t&&dy(j,q[t-1],k-1)*dx(q[t],q[t-1])<=dy(q[t],q[t-1],k-1)*dx(j,q[t-1]))
--t;
q[++t]=j;
}
}
ll ans=1e18;
for (int i=1;i<=m;++i)
ans=min(ans,f[i][cnt]);
printf("%lld
",ans);
}
}
0x0B
311B - Cats Transport,挺好的一题,要先处理出接每只猫的最早出发时间ai,对这个时间排序,贪心思想可知每个饲养员肯定要负责接连续一段的猫。于是对这个数组a进行dp分组,斜率优化。
#include<bits/stdc++.h>
#define dd(x) cout<<#x<<" = "<<x<<" "
#define de(x) cout<<#x<<" = "<<x<<"
"
#define sz(x) int(x.size())
#define All(x) x.begin(),x.end()
#define pb push_back
#define mp make_pair
#define fi first
#define se second
using namespace std;
typedef long long ll;
typedef long double ld;
typedef pair<int,int> P;
typedef priority_queue<int> BQ;
typedef priority_queue<int,vector<int>,greater<int> > SQ;
const int maxn=1e5+10,mod=1e9+7,INF=0x3f3f3f3f;
ll a[maxn],f[105][maxn],sumd[maxn],suma[maxn];
int q[maxn],h,t;
ll dy(int i,int j,int k)
{
return f[k][i]+suma[i]-f[k][j]-suma[j];
}
ll dx(int i,int j)
{
return i-j;
}
ll getf(int i,int j,int k)
{
return f[k-1][i]+a[j]*(j-i)-suma[j]+suma[i];
}
int main()
{
int n,m,p;
cin>>n>>m>>p;
for (int i=2;i<=n;++i)
{
scanf("%I64d",&sumd[i]);
sumd[i]+=sumd[i-1];
}
for (int i=1;i<=m;++i)
{
ll h,t;
scanf("%I64d%I64d",&h,&t);
a[i]=t-sumd[h];
}
sort(a+1,a+1+m);
for (int i=1;i<=m;++i)
{
suma[i]=suma[i-1]+a[i];
f[1][i]=i*a[i]-suma[i];
}
for (int k=2;k<=p;++k)
{
h=t=0;
q[0]=k-1;
for (int j=k;j<=m;++j)
{
while (h<t&&dy(q[h+1],q[h],k-1)<=a[j]*dx(q[h+1],q[h]))
++h;
f[k][j]=getf(q[h],j,k);
while (h<t&&dy(j,q[t-1],k-1)*dx(q[t],q[t-1])<=dy(q[t],q[t-1],k-1)*dx(j,q[t-1]))
--t;
q[++t]=j;
}
}
cout<<f[p][m];
return 0;
}
0x0C
货币兑换Cash ,由贪心的思想可知,每次买卖一定会把钱花完或把金券卖完。因此f(i)表示到第i天手里最多有多少钱,f(i)=max(A(j)*a[i]+B(j)*b[i]),其中A(j)表示第j天把钱全部花完可以获得的A金券数量,B(j)同理。稍作变形得-B(j)=a[i]/b[i]*A(j)-f[i]/b[i]。可以发现横坐标A(j)和斜率a[i]/b[i]都不单调。
这种情况可以用平衡树维护动态凸包或者cdq分治来解决。我这里选用了CDQ来解决这个问题。
大概思想是,对于每一个f(i),它的决策点一定在它左边,换句话说每个点只会影响它右边的点。对于求解区间[l,r]内的f值,可以先递归求解其左区间,完成后左区间的f则处理完毕,然后对于左区间的点重新排序维护一个凸包,对于右区间的点可以二分这个凸包找决策点进行更新,然后再递归求解右区间内部的影响,完成后则右区间的f值也处理完毕。此时这个大区间的f值处理完毕。
#include<bits/stdc++.h>
#define dd(x) cout<<#x<<" = "<<x<<" "
#define de(x) cout<<#x<<" = "<<x<<endl
#define sz(x) int(x.size())
#define All(x) x.begin(),x.end()
#define fi first
#define se second
#define mp make_pair
#define pb push_back
using namespace std;
typedef long long ll;
typedef long double ld;
typedef pair<int,int> P;
typedef priority_queue<int> pq;
const int maxn=1e5+10,mod=1e9+7,INF=0x3f3f3f3f;
const double eps=1e-8;
double a[maxn],b[maxn],r[maxn],f[maxn];
int tmp[maxn],q[maxn],t;
inline int sign(double x){
if (fabs(x)<eps)
return 0;
return x<0?-1:1;
}
inline double A(int i){
return f[i]*r[i]/(r[i]*a[i]+b[i]);
}
inline double B(int i){
return -f[i]/(r[i]*a[i]+b[i]);
}
inline double dy(int i,int j){
return B(i)-B(j);
}
inline double dx(int i,int j){
return A(i)-A(j);
}
bool cmp(int x,int y){
return sign(A(x)-A(y))?sign(A(x)-A(y))<0:sign(B(x)-B(y))<0;
}
void push(int i)
{
while (t>1&&sign(dy(i,q[t-1])*dx(q[t],q[t-1])-dy(q[t],q[t-1])*dx(i,q[t-1]))<=0)
--t;
q[++t]=i;
}
double find(int i)
{
double k=a[i]/b[i];
int l=1,r=t,p;
while (l<=r)
{
int m=(l+r)>>1;
if (m==t)
{
p=t;
break;
}
if (sign(dy(q[m+1],q[m])-k*dx(q[m+1],q[m]))>=0)
p=m,r=m-1;
else
l=m+1;
}
return A(q[p])*a[i]-B(q[p])*b[i];
}
void cdq(int l,int r)
{
if (l==r)
{
f[l]=max(f[l],f[l-1]);
return;
}
int m=(l+r)>>1;
cdq(l,m);
int tn=0;
for (int i=l;i<=m;++i)
tmp[tn++]=i;
sort(tmp,tmp+tn,cmp);
t=0;
for (int i=0;i<tn;++i)
push(tmp[i]);
for (int i=m+1;i<=r;++i)
f[i]=max(f[i],find(i));
cdq(m+1,r);
}
int main()
{
int n;
double s;
cin>>n>>s;
for (int i=1;i<=n;++i)
scanf("%lf%lf%lf",&a[i],&b[i],&r[i]);
f[1]=s;
cdq(1,n);
printf("%.3lf",f[n]);
return 0;
}