这是我见过的为数不多的良心九怜题之一。
题目大意
给定一个长度为$n$序列,你要在序列末尾加入$m$个$[L,R]$之间的数$mleq 10^7,L,Rleq 10^9$,使得该序列猴子排序的轮数(一轮是指随机打乱整个序列,不断重复操作直到否有序)期望最大,求这个最大的期望。
题解
假设序列元素互不相同,那么有序的排列方式只有一个,而排列方式的数量有$n!$种,每轮成功的概率是$frac {1}{n!}$,所以期望轮数是$n!$。
考虑第$i$种元素有$cnt_i$个的序列有多少个。
元素是互不相同的,所以$cnt_i$个第$i$个元素在有序序列中的相对位置是固定的,而元素在这些位置的排列是任意的,所有有序的排列方式有$prod cnt_i!$个,期望是$frac {n!}{prod cnt_i!}$。
若使上式子最大,由于$(n+1)!(n-1)!>(n!)^2$很明显希望使得$max cnt_i$最小且平均。
那么我们统计原序列中$[L,R]$之间的数,每次加入数量最小的元素,直到加入了$m$个数,最后计算答案即可,这个排一下序优化一下就行了。
#include<algorithm> #include<iostream> #include<cstring> #include<cstdio> #include<cmath> #define LL long long #define M 200020 #define MAXN 10200010 #define mod 998244353 using namespace std; namespace IO{ const int BS=(1<<20); int Top=0; char Buffer[BS],OT[BS],*OS=OT,*HD,*TL,SS[20]; const char *fin=OT+BS-1; char Getchar(){if(HD==TL){TL=(HD=Buffer)+fread(Buffer,1,BS,stdin);} return (HD==TL)?EOF:*HD++;} void flush(){fwrite(OT,1,OS-OT,stdout);} void Putchar(char c){*OS++ =c;if(OS==fin)flush(),OS=OT;} void write(int x){ if(!x){Putchar('0');return;} if(x<0) x=-x,Putchar('-'); while(x) SS[++Top]=x%10,x/=10; while(Top) Putchar(SS[Top]+'0'),--Top; } int read(){ int nm=0,fh=1; char cw=Getchar(); for(;!isdigit(cw);cw=Getchar()) if(cw=='-') fh=-fh; for(;isdigit(cw);cw=Getchar()) nm=nm*10+(cw-'0'); return nm*fh; } } using namespace IO; int add(int x,int y){return (x+y>=mod)?x+y-mod:x+y;} int mul(int x,int y){return (x==1||y==1)?x+y-1:(LL)x*(LL)y%mod;} int qpow(int x,int sq){ int res=1; for(;sq;sq>>=1,x=mul(x,x)) if(sq&1) res=mul(res,x); return res; } int p[M],fac[MAXN],ifac[MAXN],cnt[M],t[M]; int main(){ fac[0]=1; for(int i=1;i<MAXN;++i) fac[i]=mul(fac[i-1],i); ifac[MAXN-1]=qpow(fac[MAXN-1],mod-2); for(int i=MAXN-1;i;--i) ifac[i-1]=mul(ifac[i],i); for(int T=read();T;--T){ int n=read(),m=read(),l=read(),r=read(); int now=r-l+1,num=0,tot=1,ans,tmp=0,fin; for(int i=1;i<=n;i++) cnt[i]=0,p[i]=read(); sort(p+1,p+n+1),ans=fac[n+m]; for(int i=1;i<=n;i++,tot++){ while(p[i]==p[i+1]&&i<n) i++,cnt[tot]++; cnt[tot]++; if(p[i]>=l&&p[i]<=r) t[++tmp]=cnt[tot],now--; else ans=mul(ans,ifac[cnt[tot]]); } tot--,sort(t+1,t+tmp+1); for(fin=1;fin<=tmp;fin++){ if((LL)(t[fin]-num)*(LL)now>(LL)m) break; m-=(t[fin]-num)*now,now++,num=t[fin]; } num+=m/now,m%=now; ans=mul(ans,mul(qpow(ifac[num+1],m),qpow(ifac[num],now-m))); for(int i=fin;i<=tmp;i++) ans=mul(ans,ifac[t[i]]); write(ans),Putchar(' '); }flush(); return 0; }