题意:给出长度为n的序列,问任两个区间的mex运算结果的总和。
解法:直接讲线段树做法:我们注意到mex(1,1),mex(1,2),mex(1,3)...mex(1,i)的结果是单调不减的,那么我们考虑先用线段树维护上诉结果,那么此时以1为左端点的区间mex和就求出来了,重点来了:我们考虑怎么从以1为左端点的区间结果过渡到以2为结点的区间结果呢?我们注意到其实只要以1为端点的区间去掉a[1]这个点的影响就可以得到以2为端点的区间结果,那么我们怎样去除a[1]这个点的影响呢?我们发现去掉a[1]之后会影响到的就是位置1到下一个a[1]出现位置的这一段区间!这一段区间的结果如果mex>a[1],那么因为a[1]的删除它的结果就会变成a[1]。且我们上面提到mex(1,1)到mex(1,n)的结果是单调不减的。那么我们就可以在线段树上二分来找一个mex>a[1]的点,区间修改即可。这样下去一边统计答案一边删除数修改影响,到最后就可以AC了。
这道题的线段树解法还是比较经典的做法的,对于一类问题:询问的是任两个区间的结果总和,而且发现我们能比较快速地通过删除最前面的数使得结果快速过渡到下一个区间的结果,那么我们可以考虑使用这种像前缀线段树(这个简称是蒟蒻瞎掰的qwq)的做法。
细节详见代码:
#include<bits/stdc++.h> using namespace std; typedef long long LL; const int N=2e5+10; int n,a[N],f[N],nxt[N]; bool vis[N]; map<int,int> mp; LL sum[N<<2],tag[N<<2]; void pushup(int rt) { sum[rt]=sum[rt<<1]+sum[rt<<1|1]; } void pushdown(int rt,int l1,int l2) { if (tag[rt]==-1) return; sum[rt<<1]=(LL)tag[rt]*l1; tag[rt<<1]=tag[rt]; sum[rt<<1|1]=(LL)tag[rt]*l2; tag[rt<<1|1]=tag[rt]; tag[rt]=-1; } void build(int rt,int l,int r) { if (l==r) { sum[rt]=f[l]; tag[rt]=-1; return; } sum[rt]=0; tag[rt]=-1; int mid=l+r>>1; build(rt<<1,l,mid); build(rt<<1|1,mid+1,r); pushup(rt); } void update(int rt,int l,int r,int ql,int qr,int v) { if (ql<=l && r<=qr) { sum[rt]=(LL)v*(r-l+1); tag[rt]=v; return; } int mid=l+r>>1; pushdown(rt,mid-l+1,r-mid); if (ql<=mid) update(rt<<1,l,mid,ql,qr,v); if (qr>mid) update(rt<<1|1,mid+1,r,ql,qr,v); pushup(rt); } LL query(int rt,int l,int r,int ql,int qr) { if (ql<=l && r<=qr) return sum[rt]; int mid=l+r>>1; pushdown(rt,mid-l+1,r-mid); LL ret=0; if (ql<=mid) ret+=query(rt<<1,l,mid,ql,qr); if (qr>mid) ret+=query(rt<<1|1,mid+1,r,ql,qr); return ret; } int main() { while (scanf("%d",&n) && n) { for (int i=1;i<=n;i++) scanf("%d",&a[i]); for (int i=0;i<=n;i++) vis[i]=0; for (int i=1;i<=n;i++) { if (a[i]<=n) vis[a[i]]=1; f[i]=f[i-1]; while (vis[f[i]]) f[i]++; } mp.clear(); for (int i=n;i;i--) { if (mp.count(a[i])) nxt[i]=mp[a[i]]; else nxt[i]=n+1; mp[a[i]]=i; } build(1,1,n); LL ans=0; for (int i=1;i<=n;i++) { ans+=query(1,1,n,i,n); int l=i,r=nxt[i]-1,t=a[i]; while (l<r) { int mid=l+r>>1; if (query(1,1,n,mid,mid)>t) r=mid; else l=mid+1; } if (query(1,1,n,r,r)>t) update(1,1,n,r,nxt[i]-1,a[i]); } printf("%lld ",ans); } return 0; }