有关算法:
二分答案;
但是你只二分答案是不够的,因为你check会炸,所以还要考虑前缀和;
首先假装我们的check已经写好了,main函数:
int main() { n=read(); m=read(); S=read(); ll maxn=0; for(ll i=1; i<=n; i++) w[i]=read(),v[i]=read(),maxn=max(maxn,w[i]); for(ll i=1; i<=m; i++) _l[i]=read(),_r[i]=read(); ll l=0,r=maxn,ans1=1e17+7,ans2=1e17+7; while(l<=r) { ll mid=l+r>>1; ll ls=check(mid); if(ls<S) { ans1=min(ans1,S-ls); r=mid-1; } if(ls==S) { printf("0"); return 0; } if(ls>S) { ans2=min(ans2,ls-S); l=mid+1; } } printf("%lld",min(ans1,ans2)); return 0; }
输入没有什么可以说的,然后是二分答案,二分答案的话,从0~最大的wi;
二分的标准套路,先计算mid,用check函数判应该往左区间二分还是右区间二分,比较不好想的就是怎么判断往左区间还是右区间二分,这里可以想到,当我们求出的中间值的Y之后,如果发现它比S小,那么如果要找更小的差距,应该让Y的值更大才有可能,那么如果让Y的值更大,我们应该选入更多的矿产,所以我们应该使二分的答案减小,因此r=mid-1;然后这里记录两个答案,ans1,ans2,分别记录的是求得的值小于S的最小差,求得值大于S的最小差(显然等于S时就直接输出不需要再继续循环了);
然后如果没有找到使差为0的W,我们就输出ans1和ans2中较小的一个;
好了讲完了;
并没有讲完啊,我们还莫得讲check函数;
最简单的方法,暴力扫描:
ll check(ll x) { ll cnt=0,sum=0,Y=0; for(ll i=1; i<=m; i++) { cnt=0; sum=0; for(ll j=_l[i]; j<=_r[i]; j++) { if(w[j]>=x) cnt++,sum+=v[j]; } Y+=(cnt*sum); } return Y; }
然后你会发现你T成这样:
然后经过大佬ych的提醒,我们想到了前缀和:
ll check(ll x) { ll Y=0; for(int i=1;i<=n;i++){ if(w[i]>=x) sum[i]=sum[i-1]+v[i],cnt[i]=cnt[i-1]+1; else sum[i]=sum[i-1],cnt[i]=cnt[i-1]; } for(int i=1;i<=m;i++){ Y+=_abs(sum[_r[i]]-sum[_l[i]-1])*_abs(cnt[_r[i]]-cnt[_l[i]-1]); } return Y; }
sum[i]表示1~i所有点中所有wi>=二分答案的的矿产的v之和,cnt[i]表示1~i以内所有点中所有wi>=二分答案的矿产个数;
然后处理应该很好理解,不再赘述;
然后再一次for循环,对于每个区间,利用维护的前缀和计算sum*cnt,然后相加即为答案;
#include<bits/stdc++.h> #define ll long long using namespace std; inline ll read() { ll ans=0; char last=' ',ch=getchar(); while(ch<'0'||ch>'9') last=ch,ch=getchar(); while(ch>='0'&&ch<='9') ans=(ans<<1)+(ans<<3)+ch-'0',ch=getchar(); if(last=='-') ans=-ans; return ans; } ll n,m,w[200001],v[200001],S,_l[200001],_r[200001],sum[200001],cnt[200001]; ll _abs(ll x) { if(x<0) x=-x; return x; } ll check(ll x) { ll Y=0; for(int i=1;i<=n;i++){ if(w[i]>=x) sum[i]=sum[i-1]+v[i],cnt[i]=cnt[i-1]+1; else sum[i]=sum[i-1],cnt[i]=cnt[i-1]; } for(int i=1;i<=m;i++){ Y+=_abs(sum[_r[i]]-sum[_l[i]-1])*_abs(cnt[_r[i]]-cnt[_l[i]-1]); } return Y; } int main() { n=read(); m=read(); S=read(); ll maxn=0; for(ll i=1; i<=n; i++) w[i]=read(),v[i]=read(),maxn=max(maxn,w[i]); for(ll i=1; i<=m; i++) _l[i]=read(),_r[i]=read(); ll l=0,r=maxn,ans1=1e17+7,ans2=1e17+7; while(l<=r) { ll mid=l+r>>1; ll ls=check(mid); if(ls<S) { ans1=min(S-ls,ans1); r=mid-1; } if(ls==S) { printf("0"); return 0; } if(ls>S) { ans2=min(ans2,ls-S); l=mid+1; } } printf("%lld",min(ans1,ans2)); return 0; }
end-