题目描述
(CF)题面:https://codeforces.com/problemset/problem/960/G。
洛谷题面(带翻译):https://www.luogu.org/problemnew/show/CF960G。
Solution
考虑序列可以被前缀(后缀)最大值分成(a+b-2)个块,注意我们忽略了中间的大小为(n)的数。
设这些最大值为(p_i),那么每个块就是([p_i,p_{i+1}-1])。
注意到我们可以随意分配,每次还要把最大的放在最前面(最后面),所以可以注意到这是个圆排列,所以分成这么多块的方案数就是(s(n-1,a+b-2)),(s)为第一类斯特林数。
然后我们要把(a-1)的块放在前面,所以乘上一个组合数,答案就是:
[s(n-1,a+b-2)cdot inom{a+b-2}{a-1}
]
斯特林数可以分治(FFT)求,复杂度(O(nlog ^2 n))。
第一类斯特林数的求法如下,我们可以构造生成函数:
[prod _{i=0}^{n-1}(x+i)
]
那么这个生成函数的第(k)项就是(s(n,k))。
#include<bits/stdc++.h>
using namespace std;
void read(int &x) {
x=0;int f=1;char ch=getchar();
for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
}
void print(int x) {
if(x<0) putchar('-'),x=-x;
if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(int x) {if(!x) putchar('0');else print(x);putchar('
');}
#define lf double
#define ll long long
const int maxn = 6e5+10;
const int inf = 1e9;
const lf eps = 1e-8;
const int mod = 998244353;
int qpow(int a,int x) {
int res=1;
for(;x;x>>=1,a=1ll*a*a%mod) if(x&1) res=1ll*res*a%mod;
return res;
}
int f[maxn],a,b,n,w[maxn],rw[maxn],pos[maxn],N,mxn,bit,fac[maxn];
void prepare() {
w[0]=1,w[1]=qpow(3,(mod-1)/mxn);
for(int i=2;i<=mxn;i++) w[i]=1ll*w[i-1]*w[1]%mod;
rw[0]=1,rw[1]=qpow(qpow(3,mod-2),(mod-1)/mxn);
for(int i=2;i<=mxn;i++) rw[i]=1ll*rw[i-1]*rw[1]%mod;
}
void ntt(int *r,int op) {
for(int i=1;i<N;i++) if(pos[i]>i) swap(r[i],r[pos[i]]);
for(int i=1,d=mxn>>1;i<N;i<<=1,d>>=1)
for(int j=0;j<N;j+=i<<1)
for(int k=0;k<i;k++) {
int x=r[j+k],y=1ll*r[i+j+k]*(op==1?w:rw)[k*d]%mod;
r[j+k]=(x+y)%mod,r[i+j+k]=(x-y+mod)%mod;
}
if(op==-1) {
int inv=qpow(N,mod-2);
for(int i=0;i<N;i++) r[i]=1ll*r[i]*inv%mod;
}
}
int tmp[18][maxn],tmp1[maxn],tmp2[maxn];
void solve(int l,int r,int d) {
if(l==r) return tmp[d][0]=l,tmp[d][1]=1,void();
int mid=(l+r)>>1;
solve(l,mid,d+1);
for(int i=0;i<=mid-l+1;i++) tmp[d][i]=tmp[d+1][i];
solve(mid+1,r,d+1);
for(int i=0;i<=r-mid;i++) tmp2[i]=tmp[d+1][i];
for(bit=0,N=1;N<(r-l+1)<<1;N<<=1,bit++);
for(int i=1;i<N;i++) pos[i]=pos[i>>1]>>1|((i&1)<<(bit-1));
for(int i=mid-l+2;i<N;i++) tmp[d][i]=0;
for(int i=r-mid+1;i<N;i++) tmp2[i]=0;
ntt(tmp[d],1),ntt(tmp2,1);
for(int i=0;i<N;i++) tmp[d][i]=1ll*tmp[d][i]*tmp2[i]%mod;
ntt(tmp[d],-1);
for(int i=r-l+2;i<N;i++) tmp[d][i]=0;
}
int main() {
read(n),read(a),read(b);
if(!a||!b||n<a+b-1) return 0*puts("0");
if(n==1) return 0*puts("1");
fac[0]=1;
for(int i=1;i<=n;i++) fac[i]=1ll*fac[i-1]*i%mod;
for(mxn=1;mxn<=(n-1)<<1;mxn<<=1);
prepare();
solve(0,n-2,0);
write(1ll*tmp[0][a+b-2]*fac[a+b-2]%mod*qpow(1ll*fac[a-1]*fac[b-1]%mod,mod-2)%mod);
return 0;
}