将序列对应到笛卡尔树,发现每棵笛卡尔树只对应一种合法序列。因为在笛卡尔树上往左走其对应的数至少减 (1),往右走不一定减 (1),所以这棵笛卡尔树从根节点往左走的次数要 (leqslant m),题目就转化为了统计有多少棵 (n) 个节点的合法笛卡尔树。
笛卡尔树是二叉树,因为二叉树和括号序列都可以用卡特兰数计数,所以笛卡尔树能转化为括号序列。考虑中序遍历,每往左走就加入一个左括号,回溯回来时加入一个右括号,往右走时不操作,到叶子节点时加入一对完整括号,那么其就对应到了一个 (2n) 的括号序列。
设 (s_i) 为位置 (i) 之前左括号个数减右括号个数,得其需要满足 (forall i in [1,2n],0 leqslant s_i leqslant m)。将其进一步转化为网格图上路径计数,即从 ((0,0)) 走到 ((n,n)),只能向上向右走,且不能碰到直线 (A:y=x+1,B:y=x-m-1) 的方案数。
考虑翻折法来容斥计数,将连续碰到一条直线看作只碰到一次,如 (A A A B B A) 看作 (A B A),也就是只考虑第一次碰到直线的贡献。不合法方案形如:
考虑要减去以 (A) 开头的方案数,就减去以 (A,A B) 结尾的方案数,加上以 (B A,B A B) 结尾的方案数,一直这样下去,直到方案数为 (0)。这一过程可以看作将终点 ((x,y)) 沿 (A) 翻折,减去从 ((0,0)) 到终点不受限制的方案数,再沿 (B) 翻折,加上从 ((0,0)) 到终点不受限制的方案数,直到终点 ((x,y)) 超出边界。减去以 (B) 开头的方案数同理。
(m) 为两直线间的距离,得复杂度为 (O(frac{n}{m}))。
还有另一种做法,设 (f_{i,j}) 为 (i) 个节点往左走的次数 (leqslant j) 的笛卡尔树个数,得:
其生成函数为 (F_j(x)=sumlimits_{i geqslant 0}f_{i,j}x^i),得:
设 (F_j(x)=frac{A_j(x)}{B_j(x)}),得:
得 (A_j(x)=B_{j-1}(x),B_{j}(x)=B_{j-1}(x)-xA_{j-1}(x)),可以用矩阵快速幂求出 (A_m(x),B_m(x)),这里先用点值表示后再进行快速幂即可。
复杂度为 (O(n log n)),明显没有第一种方法优秀。
#include<bits/stdc++.h>
#define maxn 400010
#define all 400000
#define p 998244353
using namespace std;
typedef long long ll;
template<typename T> inline void read(T &x)
{
x=0;char c=getchar();bool flag=false;
while(!isdigit(c)){if(c=='-')flag=true;c=getchar();}
while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
if(flag)x=-x;
}
int n,m,x,y;
ll ans;
ll fac[maxn],ifac[maxn];
ll inv(ll x)
{
ll v=1,y=p-2;
while(y)
{
if(y&1) v=v*x%p;
x=x*x%p,y>>=1;
}
return v;
}
ll C(int n,int m)
{
return (n<m||n<0||m<0)?0:fac[n]*ifac[m]%p*ifac[n-m]%p;
}
void c1(int &x,int &y)
{
swap(x,y),x--,y++;
}
void c2(int &x,int &y)
{
swap(x,y),x+=m+1,y-=m+1;
}
void init()
{
fac[0]=ifac[0]=1;
for(int i=1;i<=all;++i) fac[i]=fac[i-1]*i%p;
ifac[all]=inv(fac[all]);
for(int i=all-1;i;--i) ifac[i]=ifac[i+1]*(i+1)%p;
}
int main()
{
init(),read(n),read(m),ans=C(2*n,n);
if(n<m)
{
puts("0");
return 0;
}
x=y=n;
while(x>=0&&y>=0) c1(x,y),ans=(ans+p-C(x+y,y))%p,c2(x,y),ans=(ans+C(x+y,y))%p;
x=y=n;
while(x>=0&&y>=0) c2(x,y),ans=(ans+p-C(x+y,y))%p,c1(x,y),ans=(ans+C(x+y,y))%p;
printf("%lld",ans);
return 0;
}