先是线段树
可以知道mex(i,i),mex(i,i+1)到mex(i,n)是递增的。
首先很容易求得mex(1,1),mex(1,2)......mex(1,n)
因为上述n个数是递增的。
然后使用线段树维护,需要不断删除前面的数。
比如删掉第一个数a[1]. 那么在下一个a[1]出现前的 大于a[1]的mex值都要变成a[1]
因为是单调递增的,所以找到第一个 mex > a[1]的位置,到下一个a[1]出现位置,这个区间的值变成a[1].
然后需要线段树实现区间修改和区间求和
#include <stdio.h> #include <string.h> #include <iostream> #include <algorithm> #include <vector> #include <queue> #include <set> #include <map> #include <string> #include <math.h> #include <stdlib.h> #include <time.h> using namespace std; const int MAXN = 200010; struct Node { int l, r; long long sum;//区间和 int mx;//最大值 int lazy;//懒惰标记,表示赋值为相同的 }segTree[MAXN * 3]; void push_up(int i) { if (segTree[i].l == segTree[i].r) { return; } segTree[i].sum = segTree[i << 1].sum + segTree[(i << 1) | 1].sum; segTree[i].mx = max(segTree[i << 1].mx, segTree[(i << 1) | 1].mx); } void Update_Same(int i, int v) { segTree[i].sum = (long long)v * (segTree[i].r - segTree[i].l + 1); segTree[i].mx = v; segTree[i].lazy = 1; } void push_down(int i) { if (segTree[i].l == segTree[i].r) { return; } if (segTree[i].lazy) { Update_Same(i << 1, segTree[i].mx); Update_Same((i << 1) | 1, segTree[i].mx); segTree[i].lazy = 0; } } int mex[MAXN]; void Build(int i, int l, int r) { segTree[i].l = l; segTree[i].r = r; segTree[i].lazy = 0; if (l == r) { segTree[i].mx = mex[l]; segTree[i].sum = mex[l]; return; } int mid = (l + r) >> 1; Build(i << 1, l, mid); Build((i << 1) | 1, mid + 1, r); push_up(i); } //将区间[l,r]的数都修改为v void Update(int i, int l, int r, int v) { if (segTree[i].l == l && segTree[i].r == r) { Update_Same(i, v); return; } push_down(i); int mid = (segTree[i].l + segTree[i].r) >> 1; if (r <= mid) { Update(i << 1, l, r, v); } else if (l > mid) { Update((i << 1) | 1, l, r, v); } else { Update(i << 1, l, mid, v); Update((i << 1) | 1, mid + 1, r, v); } push_up(i); } //得到值>= v的最左边位置!!!!!!!!!!!!!!!!!!!重要 int Get(int i, int v) { if (segTree[i].l == segTree[i].r) { return segTree[i].l; } push_down(i); if (segTree[i << 1].mx > v) { return Get(i << 1, v); } else { return Get((i << 1) | 1, v); } } int a[MAXN]; map<int, int>mp; int nextt[MAXN]; int main() { //freopen("in.txt","r",stdin); //freopen("out.txt","w",stdout); int n; while (~scanf("%d", &n) && n) { for (int i = 1; i <= n; i++) { scanf("%d", &a[i]); } mp.clear(); int tmp = 0; for (int i = 1; i <= n; i++) //先扫一遍得出1-N的MEX 因为是递增的所以tmp初始化一次就行 { mp[a[i]] = 1; while (mp.find(tmp) != mp.end()) { tmp++; } mex[i] = tmp; cout << tmp << " "; } cout<<endl; mp.clear(); for (int i = n; i >= 1; i--) { if (mp.find(a[i]) == mp.end()) //如果找不到后面存在过的 { nextt[i] = n + 1; } else { nextt[i] = mp[a[i]]; } mp[a[i]] = i; } for(int i=1;i<=n;i++) cout<<nextt[i]<<" "; cout<<endl; Build(1, 1, n); long long sum = 0; for (int i = 1; i <= n; i++) { sum += segTree[1].sum; if (segTree[1].mx > a[i]) { int l = Get(1, a[i]); int r = nextt[i]; if (l < r) { Update(1, l, r - 1, a[i]);//根据分析 l~r-1(下个a[i]出现之前)都要变成a[i]; } } Update(1, i, i, 0); } printf("%I64d ", sum); } return 0; }
然后是DP!!
首先要明白,以i结束的所有区间的值的和记为f[i]肯定不超过以i+1结束的所有区间的值的和记为f[i+1]。
所以可以根据f[i]间接推出f[i+1],记第i个数为sa[i],显然只用考虑大于等于sa[i]的数j对f[i]=f[i-1]+?的影响,。如果j出现在1~i-1区间中,比较j最晚出现的位置与覆盖完全的1~j-1的最小位置的较小位置k,那么区间j的前一次出现的位置到k位置这个区间内所有点到i位置的值都+1.
这样逐次累加,直到不影响为止。
#include<iostream> #include<cmath> #include<cstdio> #include<cstdlib> #include<string> #include<cstring> #include<algorithm> #include<vector> #include<map> #include<set> #include<stack> #include<list> #include<queue> #include<ctime> #define eps 1e-6 #define INF 0x3fffffff #define PI acos(-1.0) #define ll __int64 #define lson l,m,(rt<<1) #define rson m+1,r,(rt<<1)|1 #pragma comment(linker, "/STACK:1024000000,1024000000") using namespace std; #define Maxn 210000 int sa[Maxn],pos[Maxn],full[Maxn]; int main() { //freopen("in.txt","r",stdin); //freopen("out.txt","w",stdout); int n; while(scanf("%d",&n)&&n) { for(int i=1;i<=n;i++) scanf("%d",&sa[i]); memset(pos,0,sizeof(pos)); memset(full,0,sizeof(full)); int last; ll tt=0,ans=0; for(int i=1;i<=n;i++) { if(sa[i]<n)// { last=pos[sa[i]];//前一个sa[i]的最晚位置 pos[sa[i]]=i; //最晚位置 for(int j=sa[i];j<n;j++) { if(j) //考虑j对前面区间的影响 full[j]=min(full[j-1],pos[j]); // else full[j]=i; if(full[j]>last) tt+=full[j]-last; //last+1到full[j]区间内所有点到i的值+1,逐次累加 else break; } } printf("i:%d %I64d ",i,tt); ans+=tt; } printf("%I64d ",ans); } return 0; }