// Code
#include <bits/stdc++.h>
using namespace std;
// 64-bit signed integer alias used for every sum and product in this file.
using ll = long long;
// Capacity of each global array; arrays are indexed from 1, so valid
// indices are 0..N-1.
constexpr int N = 101000;
// 2-D point / vector with double coordinates.
// The default constructor intentionally leaves x and y unset; the global
// array below is zero-initialised because it has static storage.
struct Node {
    double x;
    double y;
    Node() {}
    Node(double px, double py) : x(px), y(py) {}
};
// Storage for the hull points built in main(), indexed from 1.
Node nod[N];
// Component-wise sum of two vectors.
Node operator+(const Node& lhs, const Node& rhs) {
    const double sx = lhs.x + rhs.x;
    const double sy = lhs.y + rhs.y;
    return Node(sx, sy);
}
// Component-wise difference lhs - rhs (the vector from rhs to lhs).
Node operator-(const Node& lhs, const Node& rhs) {
    const double dx = lhs.x - rhs.x;
    const double dy = lhs.y - rhs.y;
    return Node(dx, dy);
}
double dot(const Node &a,const Node &b){
return a.x*b.x+a.y*b.y;
}
double cross(const Node &a,const Node &b){
return a.x*b.y-a.y*b.x;
}
// Global state, zero-initialised by static storage:
//   ans       - accumulated result printed by main()
//   n, m      - lengths of a[] and b[]
//   a, b      - input arrays, 1-indexed
//   pre[i]    - a[1] + ... + a[i]
//   num[k]    - original index i of the k-th hull point in nod[]
ll ans, n, m;
ll a[N], b[N], pre[N], num[N];
// check(slot, i): does hull slot+1 score strictly higher than slot for item i?
bool check(int, int);
// check(p, q, r): should q be popped from the hull before appending r?
bool check(Node&, Node&, Node&);
// Reads n, m, then a[1..n] and b[1..m], and prints
//   sum_{i=2..m} max_{1<=k<=n} (pre[k]*b[i-1] - pre[k-1]*b[i])  +  pre[n]*b[m]
// where pre[] holds prefix sums of a[]. Each inner maximum is taken over the
// lower convex hull of the points (pre[k], pre[k-1]) using a binary search.
int main(){
    if (scanf("%lld%lld",&n,&m) != 2) return 0;
    for (int i = 1;i <= n;i++) scanf("%lld",&a[i]);
    for (int i = 1;i <= m;i++) scanf("%lld",&b[i]);
    for (int i = 1;i <= n;i++) pre[i] = pre[i-1]+a[i];

    // Build the hull in nod[t..w]; num[slot] remembers the original index k.
    // Seed with point 1 only: the previous code also seeded point 2 from
    // pre[2]/a[2] unconditionally, which inserts a bogus point when n == 1.
    // For n >= 2 the resulting hull is identical.
    int t = 1,w = 1;
    nod[1] = Node(pre[1],pre[1]-a[1]);num[1] = 1;
    for (int i = 2;i <= n;i++){
        Node c = Node(pre[i],pre[i]-a[i]);
        // Pop the last point while it does not make a strict left turn.
        while (t < w && check(nod[w-1],nod[w],c)) w--;
        nod[++w] = c;num[w] = i;
    }

    for (int i = 2;i <= m;i++){
        // Binary search for the first hull slot whose successor does not
        // score strictly higher for this pair (b[i-1], b[i]).
        int l = t,r = w;
        while (l < r){
            const int mid = (l + r) / 2;
            if (check(mid,i)) l = mid + 1;else r = mid;
        }
        ans += pre[num[l]]*b[i-1]-pre[num[l]-1]*b[i];
    }
    ans += pre[n]*b[m];
    // The format string was previously split across two source lines,
    // which is an unterminated string literal and did not compile.
    printf("%lld\n",ans);
    return 0;
}
// Hull-slot comparison for item y, with B1 = b[y-1] and B2 = b[y].
// In exact arithmetic it returns true when slot x+1 scores strictly higher
// than slot x under f(P) = P.x*B1 - P.y*B2, i.e. when dx*B1 - dy*B2 > 0.
// It is written as a product comparison rather than a slope ratio, so no
// division is performed. The arithmetic is done in double.
bool check(int x,int y){
    const double dx = nod[x+1].x - nod[x].x;
    const double dy = nod[x+1].y - nod[x].y;
    return dy * b[y] < dx * b[y-1];
}
// Hull pop test used in main(): returns true when the turn x -> y -> z is
// not strictly counter-clockwise (cross product <= 0). With x-coordinates
// increasing, this means y is not strictly below segment x-z and can be
// removed from the lower hull.
bool check(Node& x,Node& y,Node& z){
    const Node toMiddle = y - x;
    const Node toLast = z - x;
    return cross(toMiddle, toLast) <= 0;
}