题意
给定一个序列,支持单点修改,查询有多少个子区间满足区间内元素互不相同。
题解
我们记数组$last_i$表示上一个与第$i$个元素相同的位置,所以一定有$last_i<i$。
一个区间$[L,R]$合法当且仅当$last_i<L(iin [L,R])$。
所以对于一个固定的右端点$R$,它对答案的贡献一定是$R-maxspace last_i(ileq R)$。
所以每次询问的答案就是$frac {n(n+1)}{2}-sumlimits_{i=1}^{n} maxspace last_jspace(0<jleq i)$
不难发现$maxspace last_jspace(0<jleq i)$是一个前缀的最大值,我们只需要用线段树维护单调栈即可。
具体做法是每次通过左右儿子来更新当前节点时分类讨论。
首先由于从前向后维护单调栈,那么左侧部分$L$一定会直接对答案产生贡献,而右半部分$R$从中间拆开成$ls,rs$,分两种情况讨论。
1、$ls$最大值小于等于$L$的最大值,那么$ls$一定在单调栈中可以完全被$L$所代替,那么$L$的对最终和的贡献只有$L$的最大值乘以$ls$的区间长度,递归处理$rs$即可。
2、若$ls$的最大值已然大于$L$的最大值了,那么单调栈中所有$rs$的部分一定会完整地贡献给答案,然后递归处理$ls$即可。注意,这里计算$Ans_{rs}$的时候不能直接使用$Ans_{rs}$,而是要使用$Ans_{R}-Ans_{ls}$来更新,原因是,$rs$的答案并非直接贡献给了$R$,具体的可以看下面这张图红色部分提到的一种可能性。
在修改时,对于每一种颜色再开一个$set$,方便每次修改时直接求颜色内该位置的前驱后继,找出哪些地方的$last_i$改变了更新即可。
复杂度为$O(nlog^2n)$。
#include<algorithm> #include<iostream> #include<cstring> #include<cstdio> #include<cmath> #include<set> #define LL long long #define M 200020 #define mid ((l+r)>>1) using namespace std; int read(){ int nm=0,fh=1; char cw=getchar(); for(;!isdigit(cw);cw=getchar()) if(cw=='-') fh=-fh; for(;isdigit(cw);cw=getchar()) nm=nm*10+(cw-'0'); return nm*fh; } set<int>col[M]; int n,m,val[M],p[M<<2],c[M],v[M]; LL sum[M<<2],tot; LL calc(int x,int l,int r,int maxn){ if(p[x]<=maxn) return (LL)(r-l+1)*(LL)maxn; if(l==r) return sum[x]; if(p[x<<1]>maxn) return calc(x<<1,l,mid,maxn)+sum[x]-sum[x<<1]; else return (LL)(mid-l+1)*(LL)maxn+calc(x<<1|1,mid+1,r,maxn); } void pushup(int x,int l,int r){ p[x]=max(p[x<<1],p[x<<1|1]); sum[x]=sum[x<<1]+calc(x<<1|1,mid+1,r,p[x<<1]); } void build(int x,int l,int r){ if(l==r){sum[x]=p[x]=v[l];return;} build(x<<1,l,mid),build(x<<1|1,mid+1,r),pushup(x,l,r); } void change(int x,int l,int r,int pos,int num){ if(l==r){sum[x]=p[x]=num;return;} if(pos<=mid) change(x<<1,l,mid,pos,num); else change(x<<1|1,mid+1,r,pos,num); pushup(x,l,r); } int main(){ n=read(),tot=(LL)n*(LL)(n+1),tot>>=1; set<int>::iterator it,pre,suf; for(int i=1;i<=n;i++) col[i].insert(0); for(int i=1;i<=n;i++){ c[i]=read(),pre=col[c[i]].end(); pre--,v[i]=*pre,col[c[i]].insert(i); } build(1,1,n); for(int T=read();T;T--){ if(!read()){printf("%lld ",tot-sum[1]);continue;} int pos=read(),num=read(); suf=pre=col[c[pos]].find(pos),pre--,suf++; if(suf!=col[c[pos]].end()) change(1,1,n,*suf,*pre); col[c[pos]].erase(pos),col[c[pos]=num].insert(pos); suf=pre=col[num].find(pos),pre--,suf++,change(1,1,n,pos,*pre); if(suf!=col[num].end()) change(1,1,n,*suf,pos); } return 0; }