题目描述
有一个长为(n),高为1001的网格,每个格子有(p)的概率为1,((1-p))的概率0,定义一个网格的价值为极大的全一矩形,且这个矩形的底要贴着网格的底,求这个网格的价值为(K)的概率。
题解
我们可以考虑设一个(dp)。
我们定义每一列的高度为这一列最高的位置满足这个位置及以下的位置都为1。
设(dp[i][j])表示已经做到了前(i)列,此时最低的位置为(j)并且此时的最大价值不超过(K)的概率。
可以看出这是一个前缀和的形式,我们需要在最外面用(K)的答案减去(K-1)的答案来得到最终的答案。
那么我们的(dp)转移可以枚举最靠前的一列满足(j+1)处是0的列。
那么转移为:
(dp[i][j]=[i imes jleq K](1-p)p^jsum_{x=1}^iBigl( (sum_{kgeq j+1}dp[x-1][k]) imes(sum_{lgeq j}dp[i-x][l])Bigr))
我们要求的答案为(dp[n][0])。
后面的两个求和部分可以发现是一个后缀和形式,考虑用后缀和优化转移。
比如说我们已经算完了等号右边的部分,我们可以把它加到(dp[i][l](lleq j))就可以了。
前面枚举(i),然后枚举(K/i),再去枚举(i),总的复杂度为(K^2)。
我们发现(n)非常大,所以我们需要发现一些别的性质。
发当(ngeq k)时,第二维只能为0,所以我们令(g[n]=dp[n][0])。
[g[n]=(1-p)sum_{i=1}^{ileq K+1}dp[i-1][1] imes g[n-i]
]
[g[n]=sum_{i=1}^{ileq K+1}(dp[i-1][1] imes (1-p)) imes g[n-i]
]
这样我们把它转化为一个(K+1)次的常系数齐次线性递推。
可以用矩阵乘法优化为(O(K^3logn)),期望得分90,使用多项式取模可以优化至(O(K^2logn)sim(KlogKlogn))期望得分100。
注意特判(K=0)。
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#define N 1009
using namespace std;
typedef long long ll;
const int mod=998244353;
int pos,k,n;
ll tmp[N<<1],p[N],dp[N][N],P,y,now[N],num[N],ans[N];
inline ll rd(){
ll x=0;char c=getchar();bool f=0;
while(!isdigit(c)){if(c=='-')f=1;c=getchar();}
while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
return f?-x:x;
}
inline ll power(ll x,ll y){
ll ans=1;
while(y){if(y&1)ans=ans*x%mod;x=x*x%mod;y>>=1;}
return ans;
}
inline void MOD(ll &x){while(x>=mod)x-=mod;}
inline void mul(ll *a,ll *b,ll *c){
for(int i=0;i<=2*pos;++i)tmp[i]=0;
for(int i=0;i<pos;++i)
for(int j=0;j<pos;++j)(tmp[i+j]+=a[i]*b[j]%mod)%=mod;
for(int i=2*pos-2;i>=pos;--i){
for(int j=0;j<pos;++j)tmp[i-pos+j]=(tmp[i-pos+j]-tmp[i]*p[j]%mod+mod)%mod;
tmp[i]=0;
}
for(int i=0;i<pos;++i)c[i]=tmp[i];
}
inline ll calc(int k){
if(!k){return power(1-P+mod,n);}
memset(dp,0,sizeof(dp));
memset(num,0,sizeof(num));
memset(ans,0,sizeof(ans));
for(int i=0;i<=k+1;++i)dp[0][i]=1;
for(int i=1;i<=k;++i)
for(int j=0;i*j<=k;++j){
ll sum=0;
for(int x=1;x<=i;++x)(sum=sum+dp[x-1][j+1]*dp[i-x][j]%mod)%=mod;
sum=sum*power(P,j)%mod*(1+mod-P)%mod;
for(int x=0;x<=j;++x)(dp[i][x]+=sum)%=mod;
}
k++;pos=k;
for(int i=1;i<=k;++i)now[i]=dp[i-1][1]*(mod+1-P)%mod;
for(int i=1;i<=k;++i)p[k-i]=mod-now[i];
num[1]=1;ans[0]=1;
int nn=n;
while(nn){
if(nn&1)mul(ans,num,ans);
mul(num,num,num);nn>>=1;
}
ll res=0;
for(int i=0;i<k;++i)MOD(res+=ans[i]*dp[i][0]%mod);
return res;
}
int main(){
n=rd();k=rd();P=rd();y=rd();P=P*power(y,mod-2)%mod;
printf("%lld",(calc(k)-calc(k-1)+mod)%mod);
return 0;
}