链接
题意
一个长度为 (n) 的数组 (a_i),构造 自然数 数组满足:
- (forall i,b_iin[0,a_i]).
- (sum_{i=1}^n b_i=k)
在这个前提下,求令 (sum_{i=1}^n b_i(a_i-b_i^2)) 达到最大值的任意一组 (b_i).
题解
首先很容易想到暴力,设 (b_i) 的值为 (x) ,则 (b_i) 增加 (1) 对答案新增的贡献为 (Delta_i=(x+1)left(a_i-(x-1^2) ight)-x(a_i-x^2)=a_i-3x^2+3x-1)。
那么,初始时设所有 (b_i) 都为 (0) ,用一个堆维护,每次从堆中取出 (Delta_i) 最大的 (b_i) 加 (1),并记录下 (b_i) ,最后得到的 (b_i) 就是题目所求的 (b_i) 。时间复杂度为 (O(k log n)),显然不能承受。
观察上述过程,我们发现在暴力过程中,由于每次都采取当前的最优策略,所以对于任意的 (y),第 (y) 次增加的值是 固定的 ,并且是单调递减的。那么这样的一个增量便具有了单调性。
对于一个 (Delta_i),我们知道如果取这个 (i) 把它对应的 (b_i) 加 (1),意味着其他位置上的 (Delta) 值并不会大于 (Delta_i),也就是说我们可以根据这个 (Delta_i) ,大概确定出所有的 (b_i)。我们称这个 (Delta_i) 为最大增量.
由于满足单调性,最大增量可以被二分出来,( ext{check}) 函数也非常简单,根据最大增量求出每个 (b_i) ,判断 (sum_{i=1}^n b_i) 是否大于 (k) 即可。至于求 (b_i) 的过程可以解方程,也可以直接二分,前者可以近似认为是 (O(n log maxVal)),而后者则是 (O(n log^2 maxVal))。
值得一提的是,由于最大增量 非严格单调递减 ,以及二分后求出的 (b_i) 不一定满足 (sum b_i=k) 的约束,我们规定二分求出的 (b_i) 满足 (sum b_ileq k),并且二分求出的 (Delta) 不是一个值,是一个长度为 (2) 的区间,对边界进行控制,这样就不会遗漏任何情况。
Show the Code
#include<cstdio>
#include<climits>
#define max(a,b) ((a)>(b)? (a):(b))
#define min(a,b) ((a)<(b)? (a):(b))
typedef long long ll;
ll n,m,sum=0;
ll a[100005],b[100005];
inline ll read() {
register ll x=0,f=1;register char s=getchar();
while(s>'9'||s<'0') {if(s=='-') f=-1;s=getchar();}
while(s>='0'&&s<='9') {x=x*10+s-'0';s=getchar();}
return x*f;
}
//-3*y^2+3*y+x-1=-3(y-(1/2))^2+...
inline ll f(ll x,ll y) {return x==y? LLONG_MAX:x-3*y*(y-1)-1;}//对称轴x=1/2
inline ll calc(ll x,ll lim) {//f(a[x],b[x])<=lim
ll l=1,r=a[x],res=a[x];
while(l<=r) {
int mid=l+r>>1;
if(f(a[x],mid)<=lim) {r=mid-1;res=mid;}
else {l=mid+1;}
}
return res;
}
inline bool check(ll x) {
sum=0;
for(register int i=1;i<=n;++i) {
b[i]=calc(i,x);
sum+=b[i];
}
return sum<m;
}
int main() {
n=read();m=read();
ll l=0,r=0;
for(register int i=1;i<=n;++i) {
a[i]=read();
l=min(l,f(a[i],a[i]-1));
r=max(r,f(a[i],0));
}
while(l+1<r) {
ll mid=l+r>>1;
if(check(mid)) r=mid;
else l=mid;
}
if(check(l)) r=l;
check(r);
m-=sum;
for(register int i=1;i<=n;++i) if(m>0&&b[i]<a[i]&&f(a[i],b[i])==r) ++b[i],--m;
for(register int i=1;i<=n;++i) printf("%lld ",b[i]);
return 0;
}