「NOI2018」冒泡排序
考虑冒泡排序中一个位置上的数向左移动的步数 (Lstep) 为左边比它大的数的个数,向右移动的步数 (Rstep) 为右边比它大的数的个数,如果 (Lstep,Rstep) 中有一个不为 (0) ,那么显然不会取到下界,因为产生了浪费的步数,题面给的提示在这里非常有用,如果至少有一个为 (0) ,那么显然没有产生浪费操作,取到下界,所以一个合法排列的充要条件就是对于所有位置 (Lstep imes Rstep=0) ,即该排列的最长下降子序列长度 (leq 2) 。
先不考虑字典序的限制,只考虑求出一个合法的排列,记 (dp_{i,j}) 为前 (i) 个数,后面数中有 (j) 个比前 (i) 个数的最大值要小,此时前 (i) 位是一个合法排列的方案数,那么考虑这一步如果选一个小于最大值的数,一定要选最小的数,否则就会出现长度 (>2) 的最长下降子序列,否则可以随便选,那么 (dp_{i,j}) 可以转移到 (dp_{i+1,k},j-1leq kleq n-i-1) 。考虑加上字典序的限制,相当于对每一次转移到的 (k) 做一个下界限制,稍微改一改就得到了一个 (mathcal O(n^2)) 的 80分做法,这么简单的套路去年考的时候居然没想到。
其实每次 (k) 的取值是 (geq -1) 的任何数,因为如果 (k > n -i+1) 的话,就再也转移不回 (dp_{n,0}) 了,对答案没有影响,然后把每次取的 (k) 都加 (1) ,问题就转化为 ((0,0)) 到 ((n,n)) 不能低于 (y=-1) 的一个格路计数问题了,此时不加上字典序的限制就是卡特兰数,加上字典序的限制就枚举再哪里超过了字典序的限制,然后的方案数也是可以 (O(1)) 算的,类似于卡特兰数的推导。
code
/*program by mangoyang*/
#pragma GCC optimize("Ofast", "inline")
#include <bits/stdc++.h>
#define inf (0x7f7f7f7f)
#define Max(a, b) ((a) > (b) ? (a) : (b))
#define Min(a, b) ((a) < (b) ? (a) : (b))
typedef long long ll;
using namespace std;
template <class T>
inline void read(T &x){
int ch = 0, f = 0; x = 0;
for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = 1;
for(; isdigit(ch); ch = getchar()) x = x * 10 + ch - 48;
if(f) x = -x;
}
const int N = 1200005, mod = 998244353;
int a[N], mx[N], mn[N], js[N], lim[N], inv[N], n;
inline void up(int &x, int y){
x = x + y >= mod ? x + y - mod : x + y;
}
inline int Pow(int a, int b){
int ans = 1;
for(; b; b >>= 1, a = 1ll * a * a % mod)
if(b & 1) ans = 1ll * ans * a % mod;
return ans;
}
inline int C(int x, int y){
if(x < y || x < 0 || y < 0) return 0;
return 1ll * js[x] * inv[y] % mod * inv[x-y] % mod;
}
inline int calc(int x, int y){
int res = 0;
up(res, C(n - x + n - y, n - x));
up(res, mod - C(n - x + n - y, n - y - 1));
return res;
}
namespace Bit{
int s[N];
inline void add(int x){
for(int i = x; i <= n; i += i & -i) s[i]++;
}
inline int query(int x){
int res = 0;
for(int i = x; i; i -= i & -i) res += s[i];
return res;
}
}
inline void solve(){
read(n);
for(int i = 1; i <= n; i++) read(a[i]);
mn[n] = a[n];
for(int i = n - 1; i >= 1; i--) mn[i] = min(a[i], mn[i+1]);
mx[1] = a[1];
for(int i = 2; i <= n; i++) mx[i] = max(mx[i-1], a[i]);
for(int i = n; i >= 1; i--)
lim[i] = Bit::query(mx[i]), Bit::add(a[i]);
for(int i = 1; i <= n; i++) lim[i] += i;
int res = 0;
for(int i = 1; i <= n; i++){
if(lim[i] < n) up(res, calc(i - 1, lim[i] + 1));
if(lim[i-1] > lim[i]) break;
if(lim[i-1] == lim[i] && mn[i] < a[i]) break;
}
cout << res << endl;
for(int i = 0; i <= n; i++) Bit::s[i] = 0;
}
int main(){
js[0] = inv[0] = 1;
for(int i = 1; i < N; i++){
js[i] = 1ll * js[i-1] * i % mod;
inv[i] = Pow(js[i], mod - 2);
}
int T; read(T); while(T--) solve();
}