http://acm.hdu.edu.cn/showproblem.php?pid=4747
设我们输入的数组为 a[],我们需要从 1 到 n 遍历, 假设遍历到 i 时, 遍历的过程中用b[j]表示从 i 到 j 没出现的最小自然数
先从 n 到 1 扫一遍求出从 1 到各个点的b[j]值
然后遍历a[] 实际上就是不断的把当前a[i] 去掉,比如说去掉a[3]时,剩下的b[4]---b[n] 就表示从4到其他后续点形成的区间中没出现的最小自然数
要知道从 i 到 n ,b[]的值始终是单调递增的
我们每去掉当前a[i]会对b[]数组产生影响,
设下一个和a[i]相等的数出现的位置是 r 那么去掉a[i] 对 r 以及 r 以后的b[] 没有影响
在 i 和 r 之间受影响的段b[]是大于等于a[i]的那一段 假设是(l,r), 这个段内的b[]都大于等于a[i]
去掉a[i]的影响就是这个段内的b[] 都要等于 a[i]
找到r可以事先标记,找 l 和更新段 (l,r) 有两种方法
1,二分找到 l ,然后遍历更新段 (l,r) 这样代码比较短,也比较易懂,但比较耗时,不过可以过
2,线段树维护 这样代码量会比较大,不过耗时少,线段树的解法应该比较标准
两种代码:
#include<iostream> #include<cstdio> #include<algorithm> #include<string> #include<cstring> #include<cmath> #include<set> #include<vector> #include<list> #include<stack> #include<queue> #include<map> using namespace std; typedef long long ll; typedef pair<int,int> pp; const int INF=0x3f3f3f3f; const int N=200002; bool exist[N]; int a[N],next[N],f[N]; int b[N]; int bsh(int l,int r,int k) { while(l<=r) { int mid=(l+r)>>1; if(b[mid]<=k) l=mid+1; else r=mid-1; } return r; } int main() { //freopen("data.in","r",stdin); int n; while(scanf("%d",&n)!=EOF) { if(n==0) break; for(int i=1;i<=n;++i) scanf("%d",&a[i]); for(int i=0;i<=n;++i) f[i]=n+1; for(int i=n;i>=1;--i) if(a[i]<n) { next[i]=f[a[i]]; f[a[i]]=i; } ll ans=0; memset(exist,false,sizeof(exist)); ll tmp=0;int l=0; for(int i=1;i<=n;++i) { if(a[i]<n) { exist[a[i]]=true; while(exist[l]) ++l; } b[i]=l; tmp+=b[i]; } ans=tmp; for(int i=1;i<n;++i) { if(a[i]<n) { int r=next[i]; int l=bsh(i,r-1,a[i]); for(int j=l+1;j<r;++j) { tmp-=(b[j]-a[i]); b[j]=a[i]; } } tmp-=b[i]; ans+=tmp; } cout<<ans<<endl; } return 0; } #include<iostream> #include<cstdio> #include<algorithm> #include<string> #include<cstring> #include<cmath> #include<set> #include<vector> #include<list> #include<stack> #include<queue> #include<map> using namespace std; typedef long long ll; typedef pair<int,int> pp; const int INF=0x3f3f3f3f; const int N=200002; bool exist[N]; int a[N],next[N],f[N]; int b[N]; struct node { int l,r,k,least; ll sum; }tr[N*4]; void build(int x,int l,int r) { tr[x].l=l;tr[x].r=r;tr[x].k=-1; if(l==r) { tr[x].least=b[l]; tr[x].sum=b[l]; return ; } int mid=(l+r)>>1; build((x<<1),l,mid); build((x<<1)|1,mid+1,r); tr[x].least=min(tr[x<<1].least,tr[(x<<1)|1].least); tr[x].sum=(tr[x<<1].sum+tr[(x<<1)|1].sum); } void update(int x,int l,int r,int k) { if(l>r) return ; if(tr[x].l==l&&tr[x].r==r) { tr[x].least=k; tr[x].k=k; tr[x].sum=(ll)k*(tr[x].r-tr[x].l+1); return ; } if(tr[x].k!=-1) { tr[x<<1].k=tr[x].k;tr[x<<1].least=tr[x<<1].k; tr[x<<1].sum=(ll)tr[x<<1].k*(tr[x<<1].r-tr[x<<1].l+1); tr[(x<<1)|1].k=tr[x].k;tr[(x<<1)|1].least=tr[(x<<1)|1].k; tr[(x<<1)|1].sum=(ll)tr[(x<<1)|1].k*(tr[(x<<1)|1].r-tr[(x<<1)|1].l+1); tr[x].k=-1; } int mid=(tr[x].l+tr[x].r)>>1; if(r<=mid) update(x<<1,l,r,k); else if(l>mid) update((x<<1)|1,l,r,k); else { update(x<<1,l,mid,k); update((x<<1)|1,mid+1,r,k); } tr[x].least=min(tr[x<<1].least,tr[(x<<1)|1].least); tr[x].sum=(tr[x<<1].sum+tr[(x<<1)|1].sum); tr[x].k=-1; } int get(int x,int l,int r,int w) { if(tr[x].l==tr[x].r) { if(tr[x].least>w) return (l-1); return l; } if(tr[x].k!=-1) { tr[x<<1].k=tr[x].k;tr[x<<1].least=tr[x<<1].k; tr[x<<1].sum=(ll)tr[x<<1].k*(tr[x<<1].r-tr[x<<1].l+1); tr[(x<<1)|1].k=tr[x].k;tr[(x<<1)|1].least=tr[(x<<1)|1].k; tr[(x<<1)|1].sum=(ll)tr[(x<<1)|1].k*(tr[(x<<1)|1].r-tr[(x<<1)|1].l+1); tr[x].k=-1; } int mid=(tr[x].l+tr[x].r)>>1; if(r<=mid) return get(x<<1,l,r,w); else if(l>mid) return get((x<<1)|1,l,r,w); else { if(tr[(x<<1)|1].least<=w) return get((x<<1)|1,mid+1,r,w); else return get(x<<1,l,mid,w); } } ll gsum(int x,int l,int r) { if(l>r) return 0; if(tr[x].l==l&&tr[x].r==r) return tr[x].sum; if(tr[x].k!=-1) { tr[x<<1].k=tr[x].k;tr[x<<1].least=tr[x<<1].k; tr[x<<1].sum=(ll)tr[x<<1].k*(tr[x<<1].r-tr[x<<1].l+1); tr[(x<<1)|1].k=tr[x].k;tr[(x<<1)|1].least=tr[(x<<1)|1].k; tr[(x<<1)|1].sum=(ll)tr[(x<<1)|1].k*(tr[(x<<1)|1].r-tr[(x<<1)|1].l+1); tr[x].k=-1; } int mid=(tr[x].l+tr[x].r)>>1; if(r<=mid) return gsum(x<<1,l,r); else if(l>mid) return gsum((x<<1)|1,l,r); else return gsum(x<<1,l,mid)+gsum((x<<1)|1,mid+1,r); } int main() { int n; while(scanf("%d",&n)!=EOF) { if(n==0) break; for(int i=1;i<=n;++i) scanf("%d",&a[i]); for(int i=0;i<=n;++i) f[i]=n+1; for(int i=n;i>=1;--i) if(a[i]<n) { next[i]=f[a[i]]; f[a[i]]=i; } ll ans=0; memset(exist,false,sizeof(exist)); int l=0; for(int i=1;i<=n;++i) { if(a[i]<n) { exist[a[i]]=true; while(exist[l]) ++l; } b[i]=l; } build(1,1,n); ans+=gsum(1,1,n); for(int i=1;i<n;++i) { if(a[i]<n) { int r=next[i]; int l=get(1,i,r-1,a[i]); update(1,l+1,r-1,a[i]); } ans+=gsum(1,i+1,n); } cout<<ans<<endl; } return 0; }