洛谷10月月赛 2 t2 深海少女与胖头鱼
参考资料:洛谷10月赛2讲评ppt;
本篇题解考完那天就开始写,断断续续写到今天才写完
本题作为基础的期望dp题,用来学习期望dp还是很不错的
(说是期望dp,不如说是期望递推?)
另外,本题用到了模意义下的除法变乘法,这也是一个基础但重要的概念
1 算法分析
part 1
我们先来考虑(m=0)的情况,也就是说所有的胖头鱼都带有圣盾
(f[i])表示有(i)条圣盾胖头鱼时的期望伤害次数,我们先击破一条胖头鱼的圣盾
变成(i-1)条圣盾胖头鱼,考虑下一步操作,当我们攻击那条不带圣盾的胖头鱼时
该胖头鱼死亡,发生的概率为(frac{1}{i})转移到(f[i-1])的状态
当我们攻击带圣盾的胖头鱼时只不过是将该胖头鱼的圣盾换到了那条没带圣盾的胖头鱼身上
该操作的概率为(frac{i-1}{i})
将上面的思路总结一下,我们就得到了下面的式子
我们对它化简一下,就有了
化简过程就是简单的去分母,移项
根据这个式子,利用递推就可以get到10point
代码如下
#include<iostream>
#define ll long long
#define mod 998244353
using namespace std;
int n,m;
ll f[1000000];
int main(){
cin>>n>>m;
for(int i=1;i<=n;i++){
f[i]=(f[i-1]+i+1)%mod;
}
cout<<f[n]%mod;
return 0;
}
part 2
我们再观察一下这个式子,这不就是等差数列吗?
直接(O(1))求出答案 30point get
代码如下
#include<iostream>
#include<cstring>
#include<cstdio>
#define ll long long
using namespace std;
ll n,m;
ll x,y;
void exgcd(ll a,ll b,ll &x,ll &y){
if(!b){
x=1;
y=0;
}
else{
exgcd(b,a%b,y,x);
y-=a/b*x;
}
}
int main(){
freopen("a.in","r",stdin);
cin>>n>>m;
ll p=998244353L;
exgcd(2,p,x,y);
x=(x%p+p)%p;
ll t=(n%p)*((n+3LL)%p)%p;
cout<<(t*x)%p;
return 0;
}
我们看到了个奇怪的东西exgcd,为什么我们要使用扩欧呢?
我们在这里用到了等差数列求和公式,里面出现了除法,我们要明白,
在模意义下,除法是不能直接除的,我们需要乘以除数的逆元,本部分用了2在%p意义下的逆
元,我们用扩欧求逆元
part 3
我们处理完了(m=0)的所有情况,我们来考虑(m!=0)的所有情况,
设(g[i])为有n个圣盾胖头鱼,i个无圣盾胖头鱼的期望伤害次数
我们考虑执行攻击操作,若攻击到了无圣盾胖头鱼,该鱼死亡,进入到(g[i-1])状态
该操作的概率为(frac{i}{n+i})
假设我们该操作攻击到了带圣盾胖头鱼,该胖头鱼圣盾破裂
但其余胖头鱼全部都具有了圣盾
进入了(f[n+i]-1)状态,为什么要-1?因为该状态和(f[i+n])需要打破一个圣盾才能进入到目前状态
但不要忘了,我们最开始攻击胖头鱼是有一次操作的,需要+1,所以+1就和-1抵消掉了
该操作的可能性为(frac{n}{i+n})
归纳上述式子,得到了
(f)数组考虑用公式或者矩阵乘法进行优化
std写了矩阵乘法,我在此写了公式法
因为需要的逆元比较多,在此考虑费马小定理求逆元,具体细节可以自行百度
本题求逆元对时间复杂度的影响较大,我们可以提前算一些逆元,来降低时间复杂度(最开始就因为
这个,一直tle)
时间复杂度为
(O(m log p))
(p)为模数
代码如下
#include<iostream>
#include<cstdio>
#include<cstring>
#define ll long long
#define maxn 2000000
using namespace std;
ll m,n;
ll mod =998244353;
ll power(ll x,ll t)
{
if(x==0) return 0;
x%=mod;
ll b=1;
while(t)
{
if(t&1) b=b*x%mod;
x=x*x%mod; t>>=1;
}
return b;
}
ll f(ll x){
ll t=(x%mod)*((x+3LL)%mod)%mod;
return (t*power(2,mod-2))%mod;
}
ll g[maxn];
ll ft[maxn];
int main(){
cin>>n>>m;
ll t=(n%mod)*((n+3LL)%mod)%mod;
g[0]=(t*power(2,mod-2))%mod;
n%=mod;
ll inv=power(2,mod-2);
for(int i=1;i<=m;i++){
int x=n+i;
ll t=(x%mod)*((x+3LL)%mod)%mod;
ft[i]=(t*inv)%mod;
}
for(int i=1;i<=m;i++){
inv=power(n+i,mod-2);
g[i]=((i*inv%mod*(g[i-1]+1)%mod)%mod+(n*inv%mod*ft[i]%mod)%mod)%mod;
}
cout<<g[m];
return 0;
}
完结撒花!!!