通过打表证明发现答案就是把序列划分成若干段,每段的b都是这一段a的平均数。50分做法比较显然,就是单调栈维护,每次将新元素当成一个区间插入末尾,若b值不满足单调不降,则将这个区间与单调栈前一个区间合并。
由于题目要求每次只修改一个数,所以可以前后缀拼起来,单调栈要改变,然后发现这个显然满足二分的性质,二分完位置左端点后再二分右端点,写一个可持久化单调栈维护一下就可以了。
还有一种主席树做法,后序可能会补上。
#include<bits/stdc++.h> using namespace std; typedef pair<int,int>pii; typedef long long ll; const int N=1e5+7,mod=998244353; int n,m,sum,tp1,tp2,a[N],f[N],g[N],inv[N],ans[N],st1[N],st2[N]; ll s[N]; vector<int>G[N]; vector<pii>q[N]; bool cmp(int l,int r,int L,int R,int x,int y) {return (s[r]-s[l-1]+x)*(R-L+1)>(s[R]-s[L-1]+y)*(r-l+1);} int calc(int l,int r,int x) {return mod-(s[r]-s[l-1]+x)%mod*((s[r]-s[l-1]+x)%mod)%mod*inv[r-l+1]%mod;} int main() { scanf("%d%d",&n,&m); inv[1]=1;for(int i=2;i<=n;i++)inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod; for(int i=1;i<=n;i++)scanf("%d",&a[i]),s[i]=s[i-1]+a[i],sum=(sum+1ll*a[i]*a[i])%mod; q[1].push_back(pii(a[1],0)),ans[0]=sum; for(int i=1,x,y;i<=m;i++) scanf("%d%d",&x,&y),q[x].push_back(pii(y,i)),ans[i]=(sum+1ll*(mod-a[x])*a[x]+1ll*y*y)%mod; for(int i=1;i<=n;i++) { while(tp1&&cmp(st1[tp1-1]+1,st1[tp1],st1[tp1]+1,i,0,0))G[i].push_back(st1[tp1--]); st1[++tp1]=i,f[tp1]=(f[tp1-1]+calc(st1[tp1-1]+1,i,0))%mod; } st2[0]=n+1; for(int i=n;i;i--) { tp1--; reverse(G[i].begin(),G[i].end()); for(int j=0;j<G[i].size();j++) st1[++tp1]=G[i][j],f[tp1]=(f[tp1-1]+calc(st1[tp1-1]+1,G[i][j],0))%mod; if(i<n) { while(tp2&&cmp(i+1,st2[tp2]-1,st2[tp2],st2[tp2-1]-1,0,0))tp2--; st2[++tp2]=i+1,g[tp2]=(g[tp2-1]+calc(i+1,st2[tp2-1]-1,0))%mod; } for(int j=0;j<q[i].size();j++) { int x=q[i][j].first,y=q[i][j].second,l=1,r=tp1,mid,now=0,d=x-a[i]; while(l<=r) { mid=l+r>>1; if(cmp(st1[mid-1]+1,st1[mid],st1[mid]+1,i,0,d))r=mid-1; else l=mid+1,now=mid; } if(!tp2||!cmp(st1[now]+1,i,i+1,st2[tp2-1]-1,d,0)) ans[y]=(1ll*ans[y]+calc(st1[now]+1,i,d)+f[now]+g[tp2])%mod; else{ l=0,r=tp2-1; int ret=0,cur=0; while(l<=r) { mid=l+r>>1; int L=1,R=now,Mid,pos=0; while(L<=R) { Mid=L+R>>1; if(cmp(st1[Mid-1]+1,st1[Mid],st1[Mid]+1,st2[mid]-1,0,d))R=Mid-1; else L=Mid+1,pos=Mid; } if(mid&&cmp(st1[pos]+1,st2[mid]-1,st2[mid],st2[mid-1]-1,d,0))r=mid-1; else l=mid+1,ret=mid,cur=pos; } ans[y]=(1ll*ans[y]+calc(st1[cur]+1,st2[ret]-1,d)+f[cur]+g[ret])%mod; } } } for(int i=0;i<=m;i++)printf("%d ",ans[i]); }