题目大意
(T)((Tleq10^5))组询问
每次给出(n,m,l,r),和(n)个数(a_1,a_2,...,a_n),要找出(m)个可重复的在区间([l,r])的数,使(a_1,a_2,...,a_n)和选出的(m)个数组成的序列期望随机排序得到升序序列的次数最多
输出序列最多期望随机排序几轮,模998244353
(nleq2 imes10^5,sum nleq2 imes10^6,mleq10^7,a_ileq10^9)
题解
假设选出(m)个数后,一轮随机排序得到升序序列的概率为(p)
那么就相当于一轮随机排序期望得到(p)个升序序列
题目想要一个升序序列,那么期望(frac{1}{p})轮随机排序后得到一个升序序列
问题转化为求一轮随机排序后得到升序序列的概率
可以直接用(frac{合法方案数}{总方案数})算
长度为(n+m)的序列的总方案数是((n+m)!)
至于合法方案数,可以看成将序列排序后,交换相等的数得到的序列数,那么设值为(i)的数出现了(b_i)次,就有合法方案数=(sum b_i!)
现在要最大化(frac{1}{frac{sum b_i!}{(n+m)!}}),相当于是最小化(sum b_i!)
发现尽量将(m)个数放到出现次数较小的值会更优
因为假设有两个值(x,y),(b_x>b_y),则有新加入一个(=x)的值,会让答案乘上(b_x+1),而如果新加入(=y)的值,就会使答案乘上(b_y+1),取(y)会更优
那么就可以选(m)次区间([l,r])中出现次数最少的数
但是(m)比较大,考虑另一种统计方式
二分一个值(k),将出现次数少于(k)的值取至出现次数等于(k),判断够不够
因为(n)总共只有(2 imes10^6),所以一个(log)能过
并不会不用二分的方法
代码
#include<algorithm>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<ctime>
#include<iomanip>
#include<iostream>
#include<map>
#include<queue>
#include<set>
#include<stack>
#include<vector>
#define rep(i,x,y) for(register int i=(x);i<=(y);++i)
#define dwn(i,x,y) for(register int i=(x);i>=(y);--i)
#define view(u,k) for(int k=fir[u];~k;k=nxt[k])
#define maxn 200017
#define LL long long
#define D double
using namespace std;
int read()
{
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)&&ch!='-')ch=getchar();
if(ch=='-')f=-1,ch=getchar();
while(isdigit(ch))x=(x<<1)+(x<<3)+ch-'0',ch=getchar();
return x*f;
}
void write(int x)
{
if(x==0){putchar('0'),putchar('
');return;}
int f=0;char ch[20];
if(x<0)putchar('-'),x=-x;
while(x)ch[++f]=x%10+'0',x/=10;
while(f)putchar(ch[f--]);
putchar('
');
return;
}
int n,m,l,r,fac[maxn+10000000],ny[maxn+10000000],s1[maxn],tp1,s2[maxn],tp2,ans,num;
const LL mod=998244353;
int mul(int x,int y){int res=1;while(y){if(y&1)res=(LL)res*x%mod;x=(LL)x*x%mod,y>>=1;}return res;}
signed main()
{
fac[0]=1;int tm=0;
rep(i,1,maxn-17+10000000)fac[i]=(LL)fac[i-1]*i%mod;
ny[maxn-17+10000000]=mul(fac[maxn-17+10000000],mod-2);
dwn(i,maxn-18+10000000,0)ny[i]=(LL)ny[i+1]*(i+1)%mod;
int t=read();
while(t--)
{
n=read(),m=read(),l=read(),r=read();tp1=tp2=0;ans=1;
rep(i,1,n)
{
int x=read();
if(x>=l&&x<=r)s1[++tp1]=x;
else s2[++tp2]=x;
}
sort(s1+1,s1+tp1+1),sort(s2+1,s2+tp2+1);
s1[0]=l-1,s1[tp1+1]=r+1;int tmp=1,zero=0,L=1,R=ceil((D)(m+tp1)/(D)(r-l+1));
rep(i,0,tp1)if(s1[i+1]!=s1[i])zero+=s1[i+1]-s1[i]-1;
if(zero<m)
{
ans=-1;int ans2=-1,ans3=-1;
while(L<=R)
{
int mid=(L+R)>>1;int tmp=zero*mid,cnt=0,cnt2=0,cnt3=0;num=0;
rep(i,1,tp1)
{
num++;
if(s1[i]!=s1[i+1]){if(num<mid)tmp+=mid-num,cnt2++,cnt3+=num;num=0;}
}
if(tmp>=m){if(mid<ans||ans==-1)ans3=cnt3,ans=mid,ans2=cnt2;R=mid-1;}
else L=mid+1;
}
int num1=ans*(ans2+zero)-(m+ans3);tmp=ans;
ans=(LL)mul(ny[ans],ans2+zero-num1)*mul(ny[ans-1],num1)%mod;
}
num=0;
rep(i,1,tp1){num++;if(s1[i]!=s1[i+1]){if(num>=tmp)ans=(LL)ans*ny[num]%mod;num=0;}}
num=0;
rep(i,1,tp2){num++;if(i==tp2||s2[i]!=s2[i+1])ans=(LL)ans*ny[num]%mod,num=0;} ;
write((LL)ans*fac[n+m]%mod);
}
return 0;
}