problem
求满足(sum_i[p_i=max_{j=1}^i p_j]=a),(sum_i[p_i=max_{j=i}^n p_j]=b)的1到n的排列p的个数。
solution
设f[i,j]为从大到小地向序列中加入i个数,形成了j个前缀最大值的情况,转移有
[egin{aligned}
f[0,0]=1,&&f[i,j]=f[i-1,j-1]+(i-1)f[i-1,j]
end{aligned}
]
显然这恰是第一类斯特林数,即(f[i,j]=s(i,j))。
一个数集与一个操作方案能对应一个序列。考虑枚举数n的位置,那么答案为
[sum_{i=1}^ns(i-1,a-1)s(n-i,b-1) imes C(n-1,i-1)
]
这相当于是把1到n-1给分成a+b-2个环的方案数(其中环有两类,每类分别由a+1个和b+1个)即答案
[s(n-1,a+b-2) imes C(a+b-2,a-1)
]
至此问题已完结。
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N=2e5+10;
const int mod=998244353;
const int inf=0x3f3f3f3f;
inline ll qpow(ll x,ll y) {
ll c=1;
for(; y; y>>=1,x=x*x%mod)
if(y&1) c=x*c%mod;
return c;
}
int p,pcur,rev[N];
inline void ntt_init(int len) {
for(p=1,pcur=0; p<(len<<1);) p<<=1,pcur++;
for(int i=0; i<p; ++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<(pcur-1));
}
inline void ntt(ll*a,int tp) {
for(int i=0; i<p; ++i) if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int m=1; m<p; m<<=1) {
int wm=qpow(3,(mod-1)/(m<<1)); if(tp<0) wm=qpow(wm,mod-2);
for(int i=0; i<p; i+=(m<<1)) { ll w=1,tmp;
for(int j=0; j<m; ++j,w=w*wm%mod) {
tmp=w*a[i+j+m]%mod;
a[i+j+m]=(a[i+j]-tmp+mod)%mod;
a[i+j]=(a[i+j]+tmp)%mod;
}
}
}
if(tp<0) {
ll tmp=qpow(p,mod-2);
for(int i=0; i<p; ++i) a[i]=tmp*a[i]%mod;
}
}
inline void chm(ll*A,ll*B) {
ntt(A,1); ntt(B,1);
for(int i=0; i<p; ++i) (A[i]*=B[i])%=mod;
ntt(A,-1);
}
ll fac[N],fav[N],A[N],B[N];
void calc(int n,ll*s) {
if(n==0) {s[0]=1; return;}
if(n==1) {s[1]=1; return;}
int m(n/2); calc(m,s); ntt_init(m+1);
for(int i=0; i<=m; ++i) A[m-i]=fac[i]*s[i]%mod;
for(int i=0; i<=m; ++i) B[i]=fav[i]*qpow(m,i)%mod;
for(int i=m+1; i<p; ++i) A[i]=B[i]=0;
chm(A,B);
for(int i=0; i<=m; ++i) B[i]=A[m-i]*fav[i]%mod;
for(int i=0; i<=m; ++i) A[i]=s[i];
for(int i=m+1; i<p; ++i) A[i]=B[i]=0;
chm(A,B);
for(int i=0; i<=m+m; ++i) s[i]=A[i];
if(n&1)
for(int i=n; i>=0; --i) s[i]=((i?s[i-1]:0)+(n-1)*s[i]%mod)%mod;
}
ll s[N];
int main() {
fac[0]=fac[1]=fav[0]=fav[1]=1;
for(int i=2; i<N; ++i) fav[i]=fav[mod%i]*(mod-mod/i)%mod;
for(int i=2; i<N; ++i) fav[i]=fav[i-1]*fav[i]%mod,fac[i]=fac[i-1]*i%mod;
//int n; scanf("%d",&n); calc(n,s);
//for(int i=0; i<=n; ++i) printf("s(%d,%d)=%d
",n,i,s[i]);
int n,a,b;
scanf("%d%d%d",&n,&a,&b);
calc(n-1,s);
printf("%lld",fac[a+b-2]*fav[a-1]%mod*fav[b-1]%mod*s[a+b-2]%mod);
return 0;
}