解题思路
题目里有两句提示一定要看清楚,要不全买要不全卖,所以dp方程就比较好列,f[i]=max(f[j]*rate[j]*a[i])/(rate[j]*a[j]+b[j])+(f[j]*b[i])/(rate[j]*a[j]+b[j]),意义就是在从前面的某一天买入,这一天卖出,时间复杂度O(n^2),这样只有60分,,考虑优化。设在j这天a买入了x[j]股,则x[j]=(rate[j]*f[j])/(rate[j]*a[j]+b[j]),b买入了y[j]股,则y[j]=rate[j]/(rata[j]*a[j]+b[j]),那么转移方程就可以写成f[i]=x[j]*a[i]+y[j]*b[i],那么变形之后y[j]=x[j]*(a[i]/b[i])+f[i]/b[i],这不正是y=kx+b的形式,现在要求的就是用一个a[i]/b[i]斜率的直线去过x[j],y[j]这些点,使得截距最大,这正是斜率优化。但是发现这个东西只有f具有单调性,不能用单调数据结构维护,看了大佬们的博客发现可以用cdq维护。首先维护的一定是一个斜率递减的凸包,因为斜率一定为负。其次对于一条a[i]/b[i]来说,如果当前点与上一个点的斜率更小,那么向右移动可以使得截距更大,这样就可以用cdq来维护,首先按照k排序,然后cdq分治里x这一维,就可以很玄学的转移了。
#include<iostream> #include<cstdio> #include<cstring> #include<cmath> #include<cstdlib> #include<algorithm> using namespace std; const int MAXN = 100005; const double inf = 1e9; const double eps = 1e-6; int n,stk[MAXN]; double f[MAXN]; struct Query{ int id; double x,y,k,a,b,rate; }q[MAXN],tmp[MAXN]; inline bool cmp(Query A,Query B){ return A.k<B.k; } inline double slope(int A,int B){ if(q[A].x==q[B].x) return inf; return (q[A].y-q[B].y)/(q[A].x-q[B].x); } void cdq(int l,int r){ if(l==r){ f[l]=max(f[l],f[l-1]); q[l].y=f[l]/(q[l].rate*q[l].a+q[l].b); q[l].x=q[l].y*q[l].rate; return; } int mid=l+r>>1;int t1=l-1,t2=mid,top=0; for(register int i=l;i<=r;i++) { if(q[i].id<=mid) tmp[++t1]=q[i]; else tmp[++t2]=q[i]; } for(register int i=l;i<=r;i++) q[i]=tmp[i]; cdq(l,mid); for(register int i=l;i<=mid;i++){ while(top>=2 && slope(stk[top-1],stk[top])<=slope(stk[top],i)+eps) top--; stk[++top]=i; } for(register int i=mid+1;i<=r;i++){ while(top>=2 && slope(stk[top-1],stk[top])<=q[i].k+eps) top--; int j=stk[top]; f[q[i].id]=max(f[q[i].id],q[j].x*q[i].a+q[j].y*q[i].b); } cdq(mid+1,r); int L=l,R=mid+1,o=0; while(L<=mid && R<=r){ if(q[L].x<q[R].x+eps) tmp[++o]=q[L++]; else tmp[++o]=q[R++]; } while(L<=mid) tmp[++o]=q[L++]; while(R<=r) tmp[++o]=q[R++]; for(register int i=l;i<=r;i++) q[i]=tmp[i-l+1]; } int main(){ scanf("%d%lf",&n,&f[0]); for(int i=1;i<=n;i++) { scanf("%lf%lf%lf",&q[i].a,&q[i].b,&q[i].rate); q[i].k=-q[i].a/q[i].b;q[i].id=i; } sort(q+1,q+1+n,cmp);cdq(1,n); printf("%.3lf",f[n]); return 0; }