没错这个就是上次跟 wjz 不会做的那个 2500。
有两个截然不同的方法。
法一
首先你读错题了,考虑每个数的个数都是 (3) 的倍数怎么做。那显然一个区间合法当且仅当每个数的 (cntmod 3=0) 都。于是容易对 (cntmod 3) 这个数组哈希,然后在前面找相同的。一个 map
即可实现,然后哈希表可以做到线性。
然后考虑每个数的 (cnt) 恰好是 (0/3)。那就是要去掉那些 (>3) 且被 (3) 整除的。那么显然满足每个数 (cntleq3) 的左端点构成后缀,而且随右端点上升而上升。所以就可以实时维护,然后 two-pointers,减掉即可。
#include<bits/stdc++.h>
using namespace std;
#define mp make_pair
#define X first
#define Y second
#define pb push_back
void read(int &x){
x=0;char c=getchar();
while(!isdigit(c))c=getchar();
while(isdigit(c))x=(x<<1)+(x<<3)+(c^48),c=getchar();
}
const int N=500000;
int n;
int a[N+1];
int cnt[N+1];
int las1[N+1],las2[N+1],las3[N+1];
const int hbase1=131,hmod1=998244353,hbase2=13331,hmod2=1000000007,hshmod=19260817;
int pw1[N+1],pw2[N+1];
struct hashlist{
vector<pair<pair<int,int>,int> > v[hshmod];
#define g v[x.X%hshmod]
int &operator[](pair<int,int> x){
for(int i=0;i<g.size();i++)if(g[i].X==x)return g[i].Y;
g.pb(mp(x,0));
return g.back().Y;
}
}st;
int hsh1[N+1],hsh2[N+1];
int main(){
read(n);
for(int i=1;i<=n;i++)read(a[i]);
pw1[0]=pw2[0]=1;
for(int i=1;i<=n;i++)pw1[i]=1ll*pw1[i-1]*hbase1%hmod1,pw2[i]=1ll*pw2[i-1]*hbase2%hmod2;
int now=0;
st[mp(0,0)]++;
long long ans=0;
memset(las1,-1,sizeof(las1));memset(las2,-1,sizeof(las2));memset(las3,-1,sizeof(las3));
for(int i=1;i<=n;i++){
hsh1[i]=hsh1[i-1],hsh2[i]=hsh2[i-1];
(hsh1[i]-=1ll*cnt[a[i]]*pw1[a[i]]%hmod1)%=hmod1,(hsh2[i]-=1ll*cnt[a[i]]*pw2[a[i]]%hmod2)%=hmod2;
(hsh1[i]+=1ll*(cnt[a[i]]=(cnt[a[i]]+1)%3)*pw1[a[i]]%hmod1)%=hmod1,(hsh2[i]+=1ll*cnt[a[i]]*pw2[a[i]]%hmod2)%=hmod2;
(hsh1[i]+=hmod1)%=hmod1,(hsh2[i]+=hmod2)%=hmod2;
int old_now=now;
now=max(now,las3[a[i]]);
for(int j=old_now;j<now;j++)st[mp(hsh1[j],hsh2[j])]--;
las3[a[i]]=las2[a[i]],las2[a[i]]=las1[a[i]],las1[a[i]]=i;
ans+=st[mp(hsh1[i],hsh2[i])];
st[mp(hsh1[i],hsh2[i])]++;
}
cout<<ans;
return 0;
}
法二
换个角度思考。考虑将左端点看成第一维,右端点看成第二维,这样就是一个矩形了。然后考虑对于每个数都要满足出现次数为 (0/3),那就对于每个数都把符合要求的格子们都标记一下,最后看多少格子被标记了 (n) 遍。
显然每个数的标记的格子们,是 (mathrm O(c)) 个子矩阵,其中 (c) 是当前数的总个数。那么这样就可以随便线性二次对数标记。但是显然会炸。
考虑反面,数不满足至少一个数的格子数。那显然还是 (mathrm O(c)) 个子矩阵。但是刚刚是 (n) 个内部的并求交,现在是求并,要简单很多,直接离线下来扫描线求矩形面积并即可。
代码不想写了(
UPD 晚上:代码不写感觉良心不安,所以就写了(
大概挺好写的。就要标记这些的格子:
- 右端点小于左端点的;
- 只包含 (1) 个当前数的;
- 只包含 (2) 个当前数的;
- 包含 (geq4) 个当前数的。
然后线段树的话要支持区间 (pm1),整体查询 (0) 数量。采用的是 rng 的那种懒标记就可以维护的小清新方法。
#include<bits/stdc++.h>
using namespace std;
#define pb push_back
#define mp make_pair
#define X first
#define Y second
const int N=500000;
int n;
vector<int> pos[N+1];
vector<pair<pair<int,int>,pair<int,int> > > rect;
struct segtree{
struct node{int l,r,mn,cnt,lz;}nd[N<<2];
#define l(p) nd[p].l
#define r(p) nd[p].r
#define mn(p) nd[p].mn
#define cnt(p) nd[p].cnt
#define lz(p) nd[p].lz
void bld(int l=1,int r=n,int p=1){
l(p)=l;r(p)=r;mn(p)=lz(p)=0;cnt(p)=r(p)-l(p)+1;
if(l==r)return;
int mid=l+r>>1;
bld(l,mid,p<<1),bld(mid+1,r,p<<1|1);
}
void init(){bld();}
void sprup(int p){
if(mn(p<<1)==mn(p<<1|1))mn(p)=mn(p<<1),cnt(p)=cnt(p<<1)+cnt(p<<1|1);
else if(mn(p<<1)<mn(p<<1|1))mn(p)=mn(p<<1),cnt(p)=cnt(p<<1);
else mn(p)=mn(p<<1|1),cnt(p)=cnt(p<<1|1);
}
void sprdwn(int p){
if(lz(p)){
mn(p<<1)+=lz(p),lz(p<<1)+=lz(p);
mn(p<<1|1)+=lz(p),lz(p<<1|1)+=lz(p);
lz(p)=0;
}
}
void add(int l,int r,int v,int p=1){
if(l<=l(p)&&r>=r(p))return mn(p)+=v,lz(p)+=v,void();
sprdwn(p);
int mid=l(p)+r(p)>>1;
if(l<=mid)add(l,r,v,p<<1);
if(r>mid)add(l,r,v,p<<1|1);
sprup(p);
}
int _cnt(){return mn(1)==0?cnt(1):0;}
}segt;
vector<pair<int,int> > add[N+2],del[N+2];
int main(){
cin>>n;
for(int i=1;i<=n;i++){
int x;
scanf("%d",&x);
pos[x].pb(i);
}
for(int i=1;i<=n;i++)rect.pb(mp(mp(i,i),mp(1,i-1)));
for(int i=1;i<=n;i++){
vector<int> &v=pos[i];
for(int j=0;j<v.size();j++)rect.pb(mp(mp(j?v[j-1]+1:1,v[j]),mp(v[j],j+1==v.size()?n:v[j+1]-1)));
for(int j=0;j+1<v.size();j++)rect.pb(mp(mp(j?v[j-1]+1:1,v[j]),mp(v[j+1],j+2==v.size()?n:v[j+2]-1)));
for(int j=0;j+3<v.size();j++)rect.pb(mp(mp(j?v[j-1]+1:1,v[j]),mp(v[j+3],n)));
}
for(int i=0;i<rect.size();i++){
add[rect[i].X.X].pb(rect[i].Y);
del[rect[i].X.Y+1].pb(rect[i].Y);
}
long long ans=0;
segt.init();
for(int i=1;i<=n;i++){
for(int j=0;j<add[i].size();j++)segt.add(add[i][j].X,add[i][j].Y,1);
for(int j=0;j<del[i].size();j++)segt.add(del[i][j].X,del[i][j].Y,-1);
ans+=segt._cnt();
}
cout<<ans;
return 0;
}