思路:
斜率优化。
设 f[i] 表示将前 i 个数分组所能得到的最优值（f[0]=0），s[i] 为前缀和，则有转移方程:
f[i]=max{ f[j]+a*(s[i]-s[j])^2+b*(s[i]-s[j])+c }，其中 0<=j<i
经过化简得到:f[i]=max{ (f[j]+a*s[j]^2-b*s[j])-2*a*s[i]*s[j] } + a*s[i]^2+b*s[i]+c
单调队列维护上凸包即可。
y[j] = (f[j]+a*s[j]^2-b*s[j])
x[j] = s[j]
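max p = y[j]-2*a*s[i]*x[j]：相当于用斜率 k=2*a*s[i] 的直线 y=k*x+p 去切点集 (x[j], y[j])，要求截距 p 最大，所以维护上凸包。又因为 a 是负的且 s[i] 单调递增，斜率 k（为负）单调递减，所以队首被淘汰的点以后不会再成为最优，可以用单调队列直接从队首弹出。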
now.x = s[i]
now.y = y[i] = (f[i]+a*s[i]^2-b*s[i]) = {(f[j]+a*s[j]^2-b*s[j])-2*a*s[i]*s[j]+ a*s[i]^2+b*s[i]+c} + a*s[i]^2-b*s[i] = {(q[L].y)-2*a*s[i]*q[L].x} + 2*a*s[i]^2+c （其中最优决策 j 取队首 q[L]）
由 y[n]=f[n]+a*s[n]^2-b*s[n] 得 f[n]=y[n]-a*s[n]^2+b*s[n]；最后插入队尾的点正是 i=n 对应的点，所以答案就是 q[R].y-a*s[n]*s[n]+b*s[n]
代码一:
// Slope-optimised DP, version 1: hull points are stored explicitly as (x, y)
// and the upper convex hull is maintained with a cross product.
//   f[i] = max_{0<=j<i} f[j] + a*(s[i]-s[j])^2 + b*(s[i]-s[j]) + c
// Input: n, then a b c, then n integers. Output: f[n].
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

// Reads one (possibly negative) decimal integer from stdin.
// `ch` is an int so that EOF is distinguishable from a valid character; on
// truncated input this returns instead of looping forever.
inline ll read(){
    ll x=0,f=1;
    int ch=getchar();
    while((ch<'0'||ch>'9') && ch!=EOF){ if(ch=='-') f=-1; ch=getchar(); }
    while(ch>='0'&&ch<='9'){ x=x*10+ch-'0'; ch=getchar(); }
    return x*f;
}
//////////////////////////////////////////////////////////////////////////
const int maxn = 1e6+10;

// Hull point for decision j: x = s[j], y = f[j] + a*s[j]^2 - b*s[j].
struct node{
    ll x,y;
}now,q[maxn];

int n;
ll a,b,c,s[maxn];

// Cross product (B-A) x (C-A); positive means C lies strictly to the left of
// the directed line A->B. Computed in __int128 because a y-difference (which
// grows like a*s^2) times an x-difference can exceed the range of 64 bits.
__int128 cross(const node& A,const node& B,const node& C){
    return (__int128)(B.x-A.x)*(C.y-A.y) - (__int128)(C.x-A.x)*(B.y-A.y);
}

int main(){
    n=read();
    a=read(),b=read(),c=read();
    s[0] = 0;
    for(int i=1; i<=n; i++) s[i] = s[i-1]+read();

    // q[L..R] is the upper hull; q[0] = {0,0} is decision j = 0 (f[0] = 0).
    int L=0,R=0;
    for(int i=1; i<=n; i++){
        // Query slope 2*a*s[i]; with a < 0 and s non-decreasing it never increases,
        // so a point dropped from the front is never optimal again.
        const ll k = 2*a*s[i];
        // Advance the head while the next point gives a value at least as large.
        while(L<R && q[L+1].y-k*q[L+1].x >= q[L].y-k*q[L].x) L++;
        now.x = s[i];
        now.y = q[L].y-k*q[L].x+2*a*s[i]*s[i]+c;   // y[i] = f[i]+a*s[i]^2-b*s[i]
        // `now` is the middle argument: cross(q[R-1], now, q[R]) <= 0 means q[R]
        // is on or below segment q[R-1] -> now, so it cannot be on the upper hull.
        while(L<R && cross(q[R-1],now,q[R]) <= 0) R--;
        q[++R] = now;
    }

    // q[R] is the point just inserted for i = n; recover f[n] from y[n].
    cout << q[R].y-a*s[n]*s[n]+b*s[n] << '\n';

    return 0;
}
代码二:
// Slope-optimised DP, version 2: the deque stores decision indices and the
// slope comparisons are cross-multiplied via getup/getdown.
//   f[i] = max_{0<=j<i} f[j] + a*(s[i]-s[j])^2 + b*(s[i]-s[j]) + c
// Input: n, then a b c, then n integers. Output: f[n].
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

// Reads one (possibly negative) decimal integer from stdin.
// `ch` is an int so that EOF is distinguishable from a valid character; on
// truncated input this returns instead of looping forever.
inline ll read(){
    ll x=0,f=1;
    int ch=getchar();
    while((ch<'0'||ch>'9') && ch!=EOF){ if(ch=='-') f=-1; ch=getchar(); }
    while(ch>='0'&&ch<='9'){ x=x*10+ch-'0'; ch=getchar(); }
    return x*f;
}
//////////////////////////////////////////////////////////////////////////
const int maxn = 1e6+10;

int n,q[maxn];
ll a,b,c,s[maxn],f[maxn];

// Y[j] - Y[k], where Y[t] = f[t] + a*s[t]^2 - b*s[t].
ll getup(int j,int k){
    return f[j]-f[k]+a*(s[j]*s[j]-s[k]*s[k])+b*(s[k]-s[j]);
}

// X[j] - X[k], where X[t] = s[t].
ll getdown(int j,int k){
    return (s[j]-s[k]);
}

int main(){
    n=read();
    a=read(),b=read(),c=read();
    s[0] = 0;
    for(int i=1; i<=n; i++) s[i] = s[i-1]+read();

    // q[L..R] holds decision indices on the upper hull; q[0] = 0 (f[0] = 0).
    // The slope comparisons below are cross-multiplied, which assumes getdown > 0,
    // i.e. s strictly increasing (all inputs positive) -- TODO confirm vs. constraints.
    int L=0,R=0;
    for(int i=1; i<=n; i++){
        // Pop the head while slope(q[L], q[L+1]) >= 2*a*s[i] (a is negative),
        // i.e. while q[L+1] is at least as good as q[L] for this i.
        while(L<R && getup(q[L+1],q[L]) >= s[i]*2*a*getdown(q[L+1],q[L])) L++;
        int j = q[L];
        f[i] = f[j] + a*(s[i]-s[j])*(s[i]-s[j]) + b*(s[i]-s[j]) + c;
        // Pop the tail while slope(q[R], i) >= slope(q[R-1], q[R]): the upper hull
        // needs strictly decreasing slopes. The products are taken in __int128 because
        // a Y-difference times an X-difference can exceed 64 bits.
        while(L<R && (__int128)getup(i,q[R])*getdown(q[R],q[R-1])
                     >= (__int128)getup(q[R],q[R-1])*getdown(i,q[R])) R--;
        q[++R] = i;
    }

    cout << f[n] << '\n';

    return 0;
}