先考虑没有区间,即对于长为$n$的序列${a_{1},a_{2},...,a_{n}}$(以下记$a_{0}=a_{n+1}=0$),求$F(a,k)$
问题即构造序列$b_{i}$,满足$forall 0le ile n,b_{i}equiv a_{i}-a_{i+1}(mod k)$且$sum_{i=0}^{n}b_{i}=0$,并最小化$frac{sum_{i=0}^{n}|b_{i}|}{2}$
最优的$b_{i}$满足$|b_{i}|<k$,否则不妨假设$b_{i}ge k$($b_{i}le -k$类似),将其减小$k$,并将一个$b_{j}<0$增加$k$(总存在,否则不满足$sum_{i=0}^{n}b_{i}=0$),显然$frac{sum_{i=0}^{n}|b_{i}|}{2}$严格减小
又因为$b_{i}equiv a_{i}-a_{i+1}(mod k)$,显然其最后的取值仅有两种,不妨都先取较小的一种,再选择$-frac{sum_{i=0}^{n}b_{i}}{k}$个位置增加$k$(取另一种),显然贪心选择收益最高(即$b_{i}$最小)的位置即可
(特别的,当$a_{i}-a_{i+1}equiv 0(mod k)$,令较小的取值为0,另一种取值为$k$,显然不会取到)
下面,问题变为一个区间,我们来维护上面的过程——
更准确的来说,由于$a_{i}in [0,k)$,这个较小的取值即$egin{cases}a_{i}-a_{i+1}&(a_{i}le a_{i+1})\a_{i}-a_{i+1}-k&(a_{i}>a_{i+1})end{cases}$
求出区间中$S=-sum_{i=l-1}^{r}b_{i}$(其中$b_{l-1}$和$b_{r}$要特判),首先答案以$S$为基础上修改,其次要选择$frac{S}{k}$个位置
关于如何找到这个位置,来二分这个$b_{i}$最小,并对两类分别统计(同样特判$b_{l-1}$和$b_{r}$),用可持久化线段树维护,以及再求出区间和即可
(直接在可持久化线段树上二分似乎并不太行,因为有两段)
总复杂度为$o(qlog^{2}n)$,可以通过
1 #include<bits/stdc++.h> 2 using namespace std; 3 #define N 200005 4 #define ll long long 5 #define mid (l+r>>1) 6 #define pil pair<int,ll> 7 #define fi first 8 #define se second 9 pil f[N*60]; 10 int V,n,m,q,x,y,z,a[N],b[N],rt1[N],rt2[N],ls[N*60],rs[N*60]; 11 int New(int k){ 12 f[++V]=f[k]; 13 ls[V]=ls[k]; 14 rs[V]=rs[k]; 15 return V; 16 } 17 pil add(pil x,pil y){ 18 return make_pair(x.fi+y.fi,x.se+y.se); 19 } 20 pil dec(pil x,pil y){ 21 return make_pair(x.fi-y.fi,x.se-y.se); 22 } 23 void update(int &k,int l,int r,int x){ 24 k=New(k); 25 f[k].fi++,f[k].se+=x; 26 if (l==r)return; 27 if (x<=mid)update(ls[k],l,mid,x); 28 else update(rs[k],mid+1,r,x); 29 } 30 pil query(int k,int l,int r,int x,int y){ 31 if ((!k)||(l>y)||(x>r))return make_pair(0,0); 32 if ((x<=l)&&(r<=y))return f[k]; 33 return add(query(ls[k],l,mid,x,y),query(rs[k],mid+1,r,x,y)); 34 } 35 pil calc(int x,int y,int l,int k){ 36 pil o1=dec(query(rt1[y-1],0,m,1,k-l),query(rt1[x-1],0,m,1,k-l)); 37 pil o2=dec(query(rt2[y-1],0,m,l,k-1),query(rt2[x-1],0,m,l,k-1)); 38 pil o=make_pair(o1.fi+o2.fi,1LL*k*o1.fi-o1.se+o2.se); 39 if ((a[y])&&(a[y]<=k-l)){ 40 o.fi++; 41 o.se+=k-a[y]; 42 } 43 if (l<=a[x]){ 44 o.fi++; 45 o.se+=a[x]; 46 } 47 return o; 48 } 49 int main(){ 50 m=(1<<30)-1; 51 scanf("%d%d",&n,&q); 52 for(int i=1;i<=n;i++)scanf("%d",&a[i]); 53 for(int i=0;i<=n;i++)b[i]=a[i]-a[i+1]; 54 for(int i=1;i<n;i++){ 55 rt1[i]=rt1[i-1]; 56 if (b[i]>0)update(rt1[i],0,m,b[i]); 57 } 58 for(int i=1;i<n;i++){ 59 rt2[i]=rt2[i-1]; 60 if (b[i]<=0)update(rt2[i],0,m,-b[i]); 61 } 62 for(int i=1;i<=q;i++){ 63 scanf("%d%d%d",&x,&y,&z); 64 pil o1=dec(f[rt1[y-1]],f[rt1[x-1]]); 65 pil o2=dec(f[rt2[y-1]],f[rt2[x-1]]); 66 if (!a[y])o2.fi++; 67 else{ 68 o1.fi++; 69 o1.se+=a[y]; 70 } 71 o2.fi++,o2.se+=a[x]; 72 ll ans=1LL*z*o1.fi-o1.se+o2.se; 73 int l=1,r=z-1; 74 while (l<r){ 75 int midd=(l+r+1>>1); 76 if (calc(x,y,midd,z).fi>=ans/z)l=midd; 77 else r=midd-1; 78 } 79 pil o=calc(x,y,l,z); 80 o.se-=(o.fi-ans/z)*l,o.fi=ans/z; 81 ans+=1LL*z*o.fi-2*o.se; 82 printf("%lld ",ans/2); 83 } 84 }