令$f_{i,j}$表示序列${x_{i},x_{i+1},...,x_{n+1}}$的个数,满足$x_{i}=j$且$forall ile kle n,a_{k}x_{k}le x_{k+1}$
关于转移方程,显然有$f_{i,j}=egin{cases}sum_{a_{i}jle k}f_{i+1,k}&(jle m)\0&(j>m)end{cases}$(其中$jin Z^{+}$),初始状态为$f_{n+1,j}=[jle m]$
令$R_{i}=lfloorfrac{m}{prod_{j=i}^{n}a_{j}} floor$,有$forall j>R_{i},f_{i,j}=0$,再定义$deg(F)$表示多项式$F$最高非0项的次数
$forall 1le ile n+1$,存在多项式$F_{i}$,满足$deg(F_{i})le n-i+1$且$forall jin [1,R_{i}],F_{i}(j)=f_{i,j}$
类似地,还存在多项式$S_{i}$,满足$deg(S_{i})le n-i+2$且$forall jin [0,R_{i}],S_{i}(j)=sum_{k=1}^{j}f_{i,k}$
(关于这两个结论的正确性,从后往前归纳即可,具体证明略)
根据这两个结论,当我们知道所有$f_{i,j}$的值后(其中$jin [1,min(n-i+2,R_{i})]$),即可通过拉格朗日插值法在$o(n)$的时间内对某个$j$算出$f_{i,j}$和$sum_{k=1}^{j}f_{i,k}$
考虑维护所有$f_{i,j}$(其中$jin [1,min(n-i+2,R_{i})]$),并进行转移,转移即$f_{i,j}=sum_{a_{i}jle k}f_{i+1,k}$
对其差分,即$f_{i,j}=S_{i+1}(R_{i+1})-S_{i+1}(a_{i}j-1)$,根据上述插值,即可在$o(n^{2})$的时间内算出$f_{i,j}$,由于对每一个$i$都有此计算,复杂度为$o(n^{3})$,无法通过
但事实上,并不是所有$j$计算复杂度都是$o(n)$,当$jin [1,min(n-i+2,R_{i})]$复杂度显然为$o(1)$,因此$a_{i}=1$时仅需要插一个$S_{i+1}(R_{i+1})$,复杂度为$o(n)$,而$a_{i}ge 2$一共只有$o(log m)$个,总复杂度即$o(n^{2}log m)$
1 #include<bits/stdc++.h> 2 using namespace std; 3 #define N 1005 4 #define ll long long 5 #define mod 998244353 6 int n,a[N],fac_pre[N],fac_suf[N],inv[N],f[N][N],sum[N][N]; 7 ll m,r[N]; 8 int calc_S(int i,ll x){ 9 x=min(x,r[i])%mod; 10 if (x<=n-i+2)return sum[i][x]; 11 fac_pre[0]=x; 12 for(int j=1;j<=n-i+2;j++)fac_pre[j]=1LL*fac_pre[j-1]*(x+mod-j)%mod; 13 fac_suf[n-i+3]=1; 14 for(int j=n-i+2;j>=0;j--)fac_suf[j]=1LL*fac_suf[j+1]*(x+mod-j)%mod; 15 int ans=0; 16 for(int j=0;j<=n-i+2;j++){ 17 int s=1LL*sum[i][j]*inv[j]%mod*inv[n-i+2-j]%mod*fac_suf[j+1]%mod; 18 if (j)s=1LL*s*fac_pre[j-1]%mod; 19 if ((n-i-j)&1)s=mod-s; 20 ans=(ans+s)%mod; 21 } 22 return ans; 23 } 24 int main(){ 25 inv[0]=inv[1]=1; 26 for(int i=2;i<N;i++)inv[i]=1LL*(mod-mod/i)*inv[mod%i]%mod; 27 for(int i=1;i<N;i++)inv[i]=1LL*inv[i-1]*inv[i]%mod; 28 scanf("%d%lld",&n,&m); 29 for(int i=1;i<=n;i++)scanf("%d",&a[i]); 30 r[n+1]=m; 31 for(int i=n;i;i--)r[i]=r[i+1]/a[i]; 32 f[n+1][1]=sum[n+1][1]=1; 33 for(int i=n;i;i--){ 34 int s=calc_S(i+1,r[i+1]); 35 for(int j=1;j<=min(n-i+2LL,r[i]);j++){ 36 f[i][j]=(s+mod-calc_S(i+1,1LL*a[i]*j-1))%mod; 37 sum[i][j]=(sum[i][j-1]+f[i][j])%mod; 38 } 39 } 40 printf("%d",calc_S(1,r[1])); 41 }