XXIV.CF960G Bandit Blues
我们注意到,\(n\)一定是前缀最大值中最靠右的一个以及后缀最大值中最靠左的一个。换句话说,我们在位置\(n\)可以将整个排列划成两半,前一半中恰有\(a-1\)个前缀最大值,而后一半中恰有\(b-1\)个后缀最大值。
显然两半的问题是相同的,因为后缀最大值在翻转序列后就一定会变成前缀最大值。所以我们只需考虑一个长度为\(i\)的序列中恰有\(j\)个前缀最大值的方案数即可。我们设\(f[i][j]\)表示这一概念。
我们发现如果按照常规思路让状态转移为“加入数\(n\)的操作”会让转移方程非常恶心;但是我们如果改变思路,往序列中加入\(1\)的话,就会是两种可能:
-
\(1\)加在了序列开头,前缀最大值数量加一;
-
\(1\)加在了序列中间,前缀最大值数量不变。
所以我们就可以列出转移方程
我们发现这正是第一类斯特林数\(\left[\begin{matrix}i\\j\end{matrix}\right]\)的转移式。因此我们可以得出结论,即
之后我们就可以统计答案了。
我们枚举\(n\)放在哪个位置,则有答案为
其中两个斯特林数的意义显然;二项式系数的意义是我们从\(n-1\)个未选择的数中选出\(i-1\)个填到前一半中。
我们考虑改为枚举\(i-1\),并且令\(N=n-1,A=a-1,B=b-1\),则有答案为
我们考虑该式的实际意义:第一个斯特林数的意义是\(i\)个数围成\(A\)个环的方案数,第二个斯特林数的意义是剩下\(N-i\)个数围成\(B\)个环的方案数;外面从\(0\)到\(N\)的枚举,正是枚举了有多少个数分到了\(A\)一边,而二项式系数是枚举选出哪些数放到\(A\)那边去。
因此我们最终会发现它的意义就等价于将\(n\)个数分成\(A+B\)个环,并且选出\(A\)个环分到左边去。
而上面那段话写成数学语言就是
则我们最终的答案就是上式的值。二项式系数可以直接套公式求得;但是第一类斯特林数咋求呢?
很遗憾,第一类斯特林数并没有直接的通项公式;但是我们可以通过求出一整行或者一整列的斯特林数来得到单个斯特林数。明显行更为简单,故我们直接求出一整行斯特林数即可。
(附:实际上,在式子推到\(\sum\limits_{i=0}^{N}\left[\begin{matrix}i\\A\end{matrix}\right]\left[\begin{matrix}N-i\\B\end{matrix}\right]\dbinom{N}{i}\)时就可以通过求出列上的第一类斯特林数来得出答案)
代码:
#include<bits/stdc++.h>
using namespace std;
const int N=1<<20;
const int mod=998244353;
const int G=3;
int rev[N],fac[N],inv[N],lim,invlim,pre,suf;
int ksm(int x,int y){
int rt=1;
for(;y;x=(1ll*x*x)%mod,y>>=1)if(y&1)rt=(1ll*rt*x)%mod;
return rt;
}
void NTT(int *a,int tp){
for(int i=0;i<lim;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
for(int md=1;md<lim;md<<=1){
int rt=ksm(G,(mod-1)/(md<<1));
if(tp==-1)rt=ksm(rt,mod-2);
for(int stp=md<<1,pos=0;pos<lim;pos+=stp){
int w=1;
for(int i=0;i<md;i++,w=(1ll*w*rt)%mod){
int x=a[pos+i],y=(1ll*w*a[pos+md+i])%mod;
a[pos+i]=(x+y)%mod;
a[pos+md+i]=(x-y+mod)%mod;
}
}
}
if(tp==-1)for(int i=0;i<lim;i++)a[i]=(1ll*a[i]*invlim)%mod;
}
int A[N],B[N];
void mul(int *a,int *b,int *c,int len){
for(int i=0;i<lim;i++)A[i]=B[i]=0;
for(int i=0;i<=len;i++)A[i]=a[i],B[i]=b[i];
NTT(A,1),NTT(B,1);
for(int i=0;i<lim;i++)A[i]=1ll*A[i]*B[i]%mod;
NTT(A,-1);
for(int i=0;i<=(len<<1);i++)c[i]=A[i];
}
int a[N],b[N],c[N];
void func(int m){
if(!m){a[0]=1;return;}
int half=(m>>1);
func(half);
int LG=0;
while((1<<LG)<=(half<<1))LG++;
lim=(1<<LG),invlim=ksm(lim,mod-2);
for(int i=0;i<lim;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(LG-1));
b[half]=1;
for(int i=half-1;i>=0;i--)b[i]=1ll*b[i+1]*half%mod;
for(int i=0;i<=half;i++)b[i]=1ll*b[i]*inv[half-i]%mod;
for(int i=0;i<=half;i++)c[i]=1ll*fac[i]*a[i]%mod;
mul(b,c,b,half);
for(int i=0;i<=half;i++)b[i]=1ll*b[i+half]*inv[i]%mod;
mul(b,a,a,half);
if(!(m&1))return;
a[m]=0;
for(int i=m;i;i--)a[i]=(1ll*a[i]*(m-1)%mod+a[i-1])%mod;
a[0]=1ll*a[0]*(m-1)%mod;
}
int n;
int main(){
scanf("%d%d%d",&n,&pre,&suf),n--;
if(n<pre+suf-2){puts("0");return 0;}
fac[0]=1;
for(int i=1;i<=n;i++)fac[i]=1ll*fac[i-1]*i%mod;
inv[n]=ksm(fac[n],mod-2);
for(int i=n-1;i>=0;i--)inv[i]=1ll*inv[i+1]*(i+1)%mod;
func(n);
printf("%d\n",1ll*a[pre+suf-2]*fac[pre+suf-2]%mod*inv[pre-1]%mod*inv[suf-1]%mod);
return 0;
}