题意
一个圆环,(n) 个点,编号分别为:(0,1,2,dots ,n-1)。对于环上每个点 (i) 定义它的顺时针方向的下一个点的是 ((i+1) mod n) 。现在进行 (k) 轮调整,每一轮,每个人都会独立地以 (p_1) 的概率向顺时针走一步,以 (p_2) 的概率逆时针走一步,以 (p_3) 的概率留在原地。每个点的人数为:(x_0,x_1,dots ,x_{n-1}) 。求 (k) 轮结束后,每个点上的人数的期望是多少。
(3≤n≤500,0≤k≤10^{18},x_i≤10^6,1 leq a,b,c leq 500)
题目链接:https://ac.nowcoder.com/acm/contest/7079/D
分析
首先可以发现每个位置本质是一样的,所以如果我们想求第 (i) 个位置开始的粉丝经过 (T) 时间到达第 (j) 个位置的粉丝的期望人数可以整体平移到起点为 (0) 的情况,所以我们只需要处理起点为 (0) 的情况就好了。
设 (f[i][j]) 表示在 (i) 时刻一个粉丝从 (0) 位置到 (j) 位置的概率,有:
[f[i][j]=f[i-1][j-1]*p_1+f[i-1][j]*p_3+f[i-1][j+1]*p_2
]
因此可以构造矩阵求解。
例如, (n=5,k=1) 时的转移矩阵:
[left[
egin{matrix}
p_3 p_2 0 0 p_1 \
p_1 p_3 p_2 0 0 \
0 p_1 p_3 p_2 0 \
0 0 p_1 p_3 p_2 \
p_2 0 0 p_1 p_3 \
end{matrix}
ight]*
left[
egin{matrix}
x_0\
x_1\
x_2\
x_3\
x_4
end{matrix}
ight]=
left[
egin{matrix}
X_0\
X_1\
X_2\
X_3\
X_4
end{matrix}
ight]
]
直接用矩阵快速幂的时间复杂度为:(O(n^3logk)),肯定会超时。
可以发现转移矩阵是循环矩阵,进行优化,使得复杂度为:(O(n^2logk))。
循环矩阵
在线性代数中,循环矩阵是一种特殊形式的 Toeplitz 矩阵,它的行向量的每个元素都是前一个行向量各元素依次右移一个位置得到的结果。
性质
- (a+b) 是一个循环矩阵
- (a imes b) 是一个循环矩阵
因此,对于一个循环矩阵,只需要存储矩阵的第一行即可,其它的都可以借助其右移推出。
代码
#include <bits/stdc++.h>
//循环矩阵乘法
using namespace std;
typedef long long ll;
const int N=510;
const int maxn=500;
const int mod=998244353;
int x[N],ans[N],d[N],f[N],n;
ll power(ll a,ll b)
{
ll res=1;
a%=mod;
while(b)
{
if(b&1) res=res*a%mod;
a=a*a%mod;
b>>=1;
}
return res;
}
void mul(int a[],int b[],int c[])
{
for(int i=0;i<n;i++) c[i]=0;
for(int i=0;i<n;i++)//得到矩阵的第1行第1列
{
for(int j=0;j<n;j++)//利用循环的特点,枚举第1列
c[(i+j)%n]=(c[(i+j)%n]+1LL*a[i]*b[j]%mod)%mod;
}
}
void mpower(int a[],ll b)
{
int tmp[N];
while(b)
{
if(b&1)
{
mul(f,a,tmp);
for(int i=0;i<n;i++) f[i]=tmp[i];
}
mul(a,a,tmp);
for(int i=0;i<n;i++) a[i]=tmp[i];
b>>=1;
}
}
int main()
{
int a,b,c,s;
ll k;
scanf("%d%lld",&n,&k);
scanf("%d%d%d",&a,&b,&c);
s=a+b+c;
s=power(s,mod-2);
ll p1=1LL*a*s%mod,p2=1LL*b*s%mod,p3=1LL*c*s%mod;
d[0]=p3,d[1]=p2,d[n-1]=p1,f[0]=1;
mpower(d,k);//f是最终转移矩阵的第1行
for(int i=0;i<n;i++)
scanf("%d",&x[i]);
for(int i=0;i<n;i++)
{
for(int j=0;j<n;j++)
ans[i]=(ans[i]+1LL*f[(n-i+j)%n]*x[j]%mod)%mod;
//ans[i]=(ans[i]+1LL*f[(i+j)%n]*x[j]%mod)%mod;不是向左,是向右
}
for(int i=0;i<n;i++)
printf("%d%c",ans[i],i==n-1?'
':' ');
return 0;
}