题目链接:CF439D Devu and his Brother
太困肝了一晚上也没肝出来,被大佬虐爆了。
考虑这个结论:如果需要操作才能实现要求,那这个数一定是(a),(b)中的数。
为什么呢,你可以尝试设此数为(k),写出其函数解析式,则一定是一次函数,单调性明显,一定出现在两端点上。
考虑对答案有贡献:
显然是(a)数组中(<k)和(b)数组中(>k)的数才有贡献,这时我们直接枚举每一个数,然后二分查找边界值,当然要处理出前缀和,为什么呢?
(ps):下面的数组都是从小到大排序意义上的:
比如说我们要把(max_a)和(min_b)都趋近于(k),这时(a)数组中前(x)个数(<k),(b)数组中(b_y)到(b_m>k),那么我们记(sum_{a/b,i})为前缀和,则贡献为:
[kx-sum_{a,x}+sum_{b,m}-sum_{b,y-1}-k(m-y+1)
]
记得开(long;long),这样复杂度就是(O((n+m)log(n+m))),可以通过本题。
令外讨论区还有三分的做法,笔者不会,所以有兴趣就去讨论区看吧(咕咕~
(Code):
#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
const int MAXN=100005;
typedef long long ll;
const ll inf=1ll<<61;
int a[MAXN],b[MAXN];
ll suma[MAXN],sumb[MAXN],ans=inf;
int mina=2147483647,maxb=0;
int n,m;
int mm;
int main(){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) scanf("%d",&a[i]),mina=min(mina,a[i]);
for(int i=1;i<=m;i++) scanf("%d",&b[i]),maxb=max(maxb,b[i]);
if(mina>=maxb){
printf("%d
",0);
return 0;
}
sort(a+1,a+n+1),sort(b+1,b+m+1);
for(int i=1;i<=n;i++) suma[i]=suma[i-1]+(ll)a[i];
for(int i=1;i<=m;i++) sumb[i]=sumb[i-1]+(ll)b[i];
for(int i=1;i<=n;i++){
if(a[i]>=b[m]) break;
int l=1,r=m,mid;
while(1){
mid=(l+r)>>1;
if(l>=r) break;
if(b[mid]<=a[i]) l=mid+1;
else r=mid;
}
ll now=(ll)a[i]*(ll)(i-1)-suma[i-1];
if(mid>0) now=now+sumb[m]-sumb[mid-1]-(ll)a[i]*(ll)(m-mid+1);
if(now<ans) ans=now,mm=a[i];
}
for(int i=1;i<=m;i++){
if(b[i]<a[1]) continue;
int l=1,r=n,mid;
while(1){
mid=(l+r)>>1;
if(l>=r) break;
if(a[mid]<=b[i]) l=mid+1;
else r=mid;
}
mid-=1;
if(b[i]>=a[n]) mid=n;
ll now=sumb[m]-sumb[i]-(ll)b[i]*(ll)(m-i);
now=now+(ll)b[i]*(ll)mid-suma[mid];
if(ans>now) ans=now,mm=b[i];
}
printf("%lld
",ans);
return 0;
}
大佬xxx:你好傻啊,这么麻烦,其实核心代码只有一行,这么(naive)的贪心都不会。
我:???
索性盗了一份代码,谁会证明呢,大佬可以在讨论群秀一秀?
(Code)(神仙贪心):
#include <bits/stdc++.h>
#define gc() getchar()
#define LL long long
#define rep(i, a, b) for (int i = (a); i <= (b); ++i)
#define _rep(i, a, b) for (int i = (a); i >= (b); --i)
using namespace std;
const int N = 1e5 + 5;
int n, m, a[N], b[N];
inline int read() {
int x = 0;
char ch = gc();
while (!isdigit(ch)) ch = gc();
while (isdigit(ch)) x = x * 10 + ch - '0', ch = gc();
return x;
}
inline bool cmp(int x , int y) {
return x > y;
}
int main() {
n = read(), m = read();
rep(i, 1, n) a[i] = read();
rep(i, 1, m) b[i] = read();
sort(b + 1, b + m + 1, cmp);
sort(a + 1, a + 1 + n);
if (n > m) n = m;
LL ans = 0;
rep(i, 1, n) if (a[i] < b[i]) ans += 1LL * (b[i] - a[i]);
printf("%lld
", ans);
return 0;
}
这样的复杂度是(O(nlog n+mlog m+n+m)),比我的快多了,%%%。