这道题是真的巧妙,像我这样的估计永远想不出来咋做了。
这道题要求的是严格递增,我们先考虑比较简单一点的,改成严格不下降的。
这样的话,对于原序列a,如果它自身某一段是不下降序列的话,那么我们直接让b与之相等即可,那如果是下降的序列呢?
我们可以说明,这段下降序列的答案一定是序列中所有数的中位数。我们先考虑只有两个数的情况,这个时候一定是中位数(其实只要是两个数中间任意一个值都行),之后我们把这个规律拓展,差不多就能明白答案一定是中位数了。(其实好像取到一个区间之内的值都行,但是中位数肯定是在里面的)
所以说,我们现在就知道了对于每一段区间,区间内答案相同的情况,那我们如何合并呢?
假设我们现在已经有一个答案序列b满足不递减,如果新插入的一项仍然满足不递减的话,那么咱就不用管了,如果递减的话,那我们就用刚才的操作,我们找到前一段答案相同的区间,把两者合并,新的答案变为区间中位数,直到满足答案不递减为止。
如何维护中位数?只要维护一个大根堆,一旦堆中元素个数超过了区间长度一半就开始把大元素全给删了,直到符合条件为止。
你可能和我一样有疑问,直接删除?那会不会在后来你删除的一个数成为了中位数?
并不会。
首先,一个被删除的数如果想成为中位数,那么后面必须有更大的数加入,但是如果是更大的数的话,我们根本不会进行合并,而如果加入的是更小的数的话,那么中位数肯定是在越变越小,那就更不会用到这个被删除的数了。(对于一段被固定的区间,中位数不会变大!)
或许你还有疑问,就是如果我先插入一个较大的数,然后再插入一个较小的数,两者合并之后,这样会不会影响到中位数?
答案还是不会,这个就是bin哥的理解了(%%%prophetB),如果你想要通过这种方法影响中位数的话,那么你必然是插入两个数,一个较小,一个较大,但是这样的话其实并不会对中位数的位置产生影响,所以该被删除的还是会被删除,没什么影响。用bin哥的原话说,就是因为每次只插入一个数,对中位数只有一个数的影响,所以不会改变什么。
所以我们就可以愉快的使用左偏树来维护这个合并和删除的操作啦!
哦,说了一大堆,我们这讨论的都是不递减的情况,题目要求递增,怎么办呢?这个很容易,我们把输入的每个数分别减去它的下标,之后我们只要照着不递减的情况做就可以了。
看一下代码。
// luogu-judger-enable-o2 #include<cstdio> #include<algorithm> #include<cstring> #include<iostream> #include<cmath> #include<set> #include<queue> #define rep(i,a,n) for(int i = a;i <= n;i++) #define per(i,n,a) for(int i = n;i >= a;i--) #define enter putchar(' ') #define lowbit(x) x & (-x) using namespace std; typedef long long ll; const int M = 2000005; const int INF = 1000000009; ll read() { ll ans = 0,op = 1; char ch = getchar(); while(ch < '0' || ch > '9') { if(ch == '-') op = -1; ch = getchar(); } while(ch >= '0' && ch <= '9') { ans *= 10; ans += ch - '0'; ch = getchar(); } return ans * op; } ll n,a[M],ans,v[M]; int l[M],r[M],lc[M],rc[M],size[M],tot,cnt,dis[M],root[M]; void pushup(int x) { size[x] = size[lc[x]] + size[rc[x]] + 1; } int merge(int x,int y) { if(!x || !y) return x | y; if(v[x] < v[y]) swap(x,y); rc[x] = merge(rc[x],y); if(dis[rc[x]] > dis[lc[x]]) swap(lc[x],rc[x]); dis[x] = dis[rc[x]] + 1; pushup(x); return x; } int del(int x) { return merge(lc[x],rc[x]); } int insert(int x) { v[++tot] = x,size[tot] = 1; lc[tot] = rc[tot] = dis[tot] = 0; return tot; } int main() { n = read(); rep(i,1,n) a[i] = read(),a[i] -= i; rep(i,1,n) { root[++cnt] = insert(a[i]),l[cnt] = r[cnt] = i; while(cnt > 1 && v[root[cnt]] < v[root[cnt-1]]) { cnt--,root[cnt] = merge(root[cnt],root[cnt+1]); r[cnt] = r[cnt+1]; while((size[root[cnt]] << 1) >= r[cnt] - l[cnt] + 3) root[cnt] = del(root[cnt]); } } rep(i,1,cnt) rep(j,l[i],r[i]) ans += abs(v[root[i]] - a[j]); printf("%lld ",ans); rep(i,1,cnt) rep(j,l[i],r[i]) printf("%lld ",v[root[i]] + j); return 0; }