题目大意
在(n*m)的网格上,一只马在点((1,1)),点((i,j))可以跳到((i-1,j+k))或((i,j+k))或((i+1,j+k)),其中(k)是一个奇数,求跳到((n,m))的方案数。
解析
设:
(f_{i,j})表示跳到((j,i))的方案数(为了方便我换了一下(i,j)的顺序,相当于按列为阶段转移)
(a_{i,j}=sum_{k=0}f_{i,j-2k})
(b_{i,j}=sum_{k=0}f_{i,j-2k-1})
得到三者之间关系是:
(f_{i,j}=a_{i-1,j}+a_{i-1,j-1}+a_{i-1,j+1})
(a_{i,j}=b_{i-1,j}+f_{i,j})
(b_{i,j}=a_{i-1,j})
仅根据这三条式子,就能够用(O(nm))的时间复杂度求出(f_{i,j})了。
但(mleq 10^9),还需优化。
原来是三个状态的转移,现在我们变一下式子:
(a_{i,j}=a_{i-2,j}+a_{i-1,j}+a_{i-1,j-1}+a_{i-1,j+1})
现在只剩(a)一个状态的转移了,最后求(f_{i,j}=a_{i,j}-a_{i-2,j})即可。
由于(m)很大,考虑使用矩阵乘法。
转移矩阵的构造方法很巧妙,我们把(a_{i-1,1 sim n})还有(a_{i-2,1 sim n})放在初始矩阵的第一行,其它位置全部填(0)。
例如(n=3)时,初始矩阵为:
(egin{Bmatrix} a_{i-1,1} & a_{i-1,2} & a_{i-1,3} & a_{i-2,1} & a_{i-2,2} & a_{i-2,3} \ 0 & 0 & 0 & 0 & 0 & 0 \ 0 & 0 & 0 & 0 & 0 & 0 \ 0 & 0 & 0 & 0 & 0 & 0 \ 0 & 0 & 0 & 0 & 0 & 0 \ 0 & 0 & 0 & 0 & 0 & 0 end{Bmatrix}
quad)
转移矩阵为:
(egin{Bmatrix} 1 & 1 & 0 & 1 & 0 & 0 \ 1 & 1 & 1 & 0 & 1 & 0 \ 0 & 1 & 1 & 0 & 0 & 1 \ 1 & 0 & 0 & 0 & 0 & 0 \ 0 & 1 & 0 & 0 & 0 & 0 \ 0 & 0 & 1 & 0 & 0 & 0 end{Bmatrix}
quad)
这样时间复杂度降为(O(n^3logm)),问题解决了。
Code
#include <cstdio>
#include <cstring>
const int N = 57, P = 30011;
int max(int a, int b) { return a > b ? a : b; }
int min(int a, int b) { return a < b ? a : b; }
int n, m;
struct matrix
{
int num[N * 2][N * 2];
matrix operator*(matrix a)
{
matrix c; memset(c.num, 0, sizeof(c.num));
for (int i = 0; i < 2 * n; i++)
for (int j = 0; j < 2 * n; j++)
for (int k = 0; k < 2 * n; k++)
c.num[i][j] = (c.num[i][j] + num[i][k] * a.num[k][j] % P) % P;
return c;
}
} bas, mov, ret;
int getit(int m, int n)
{
if (m <= 0) return 0;
memset(bas.num, 0, sizeof(bas.num));
memset(mov.num, 0, sizeof(mov.num));
memset(ret.num, 0, sizeof(ret.num));
for (int j = 0; j < n; j++) for (int i = max(j - 1, 0); i <= min(j + 1, n - 1); i++) mov.num[i][j] = 1;
for (int i = n; i < 2 * n; i++) mov.num[i][i - n] = 1;
for (int j = n; j < 2 * n; j++) mov.num[j - n][j] = 1;
bas.num[0][0] = 1;
for (int i = 0; i < 2 * n; i++) ret.num[i][i] = 1;
m--;
while (m)
{
if (m & 1) ret = ret * mov;
mov = mov * mov, m >>= 1;
}
bas = bas * ret;
return bas.num[0][n - 1];
}
int main()
{
scanf("%d%d", &n, &m);
printf("%d
", (getit(m, n) - getit(m - 2, n) + P) % P);
return 0;
}