-
看到 (mle 12) 和 (cle 6) ,容易想到状压 DP
-
考虑转化成 (3^{nm}) 减去不合法的方案数,轮廓线 DP :(f[i][j][S][k][h]) 表示 DP 到了第 (i) 行第 (j) 格,轮廓线上 (m) 格的状态为 (S) 的方案数,(k) 表示最大的 (x) 使得模板矩阵第一行的前 (x) 个字符和目标矩阵第 (i-1) 行的后 (x) 个字符相同,(h) 表示最大的 (x) 使得模板矩阵第二行的前 (x) 个字符和目标矩阵第 (i) 行的后 (x) 个字符相同
-
于是我们每次枚举下一个格子填什么字符,(k) 和 (h) 利用 KMP 进行转移,任何时候保证 (k) 和 (h) 不全为 (c) 即可
-
但是这样做的复杂度为 (O(qnmc^23^m)) ,无法通过此题
-
我们发现复杂度主要消耗在 (3^m) 上,考虑如何才能不储存前面每个字符的状态
-
考虑仍然逐行转移 (f[i][S]) 表示前 (i) 行状态为 (S) ,但这里的 (S) 是一个集合,表示如果第 (i) 行以第 (x) 个格子为结尾的长度为 (c) 的子串与模板矩阵第一行相等,那么 (xin S) ,否则 (x otin S)
-
然后每一列内逐格转移:(g[j][S][k][h]) 表示前 (j) 个格子,轮廓线上一排格子状态为 (S) ,并且前 (j) 个格子的后缀与模板矩阵第一行和第二行前缀匹配的最大长度分别为 (k) 和 (h) 的方案数
-
边界 (g[0][S][0][0]=f[i-1][S])
-
还是每次枚举下一个格子填什么,注意若 (g[j][S][k][h]) 满足 (j+1in S) 且 (h) 下一步转移到了 (c) 则此步不能转移,若 (k) 下一步转移到了 (c) 则转移后的 (S) 集合中包含 (j+1) ,否则不包含
-
由于 (S) 中的任意元素都在 ([c,m]) 内,故复杂度 (O(qnmc^22^{m-c})) ,可以通过此题
Code
#include <bits/stdc++.h>
const int N = 105, E = 14, C = (1 << 12) + 5, R = 8, rqy = 1e9 + 7;
const char ch[] = {'W', 'B', 'X'};
int n, m, c, q, Cm, f[N][C], g[E][C][R][R], nxt[2][R], tr[2][R][3], sum = 1;
char s[2][R];
void KMP(int n, char *s, int *nxt)
{
nxt[1] = 0;
for (int i = 2, j = 0; i <= n; i++)
{
while (j && s[i] != s[j + 1]) j = nxt[j];
if (s[i] == s[j + 1]) j++;
nxt[i] = j;
}
}
inline void add(int &a, const int &b)
{
a += b; if (a >= rqy) a -= rqy;
}
int DP()
{
f[0][0] = 1;
for (int i = 1; i <= n; i++)
{
for (int j = 0; j <= m; j++)
for (int S = 0; S < Cm; S++)
for (int k = 0; k <= c; k++)
for (int h = 0; h <= c; h++)
g[j][S][k][h] = 0;
for (int S = 0; S < Cm; S++)
g[0][S][0][0] = f[i - 1][S];
for (int j = 0; j < m; j++)
for (int S = 0; S < Cm; S++)
for (int k = 0; k <= c; k++)
for (int h = 0; h <= c; h++)
{
if (!g[j][S][k][h]) continue;
for (int w = 0; w < 3; w++)
{
int T = S;
if (j >= c - 1)
{
if (((S >> j - c + 1) & 1) && tr[1][h][w] == c)
continue;
T &= Cm - 1 ^ (1 << j - c + 1);
if (tr[0][k][w] == c) T |= 1 << j - c + 1;
}
add(g[j + 1][T][tr[0][k][w]][tr[1][h][w]], g[j][S][k][h]);
}
}
for (int S = 0; S < Cm; S++)
for (int k = 0; k <= c; k++)
for (int h = 0; h <= c; h++)
add(f[i][S], g[m][S][k][h]);
}
int res = 0;
for (int S = 0; S < Cm; S++) add(res, f[n][S]);
return res;
}
int main()
{
std::cin >> n >> m >> c >> q; Cm = 1 << m - c + 1;
for (int i = 1; i <= n * m; i++) sum = 3ll * sum % rqy;
while (q--)
{
for (int i = 0; i < 2; i++) scanf("%s", s[i] + 1);
KMP(c, s[0], nxt[0]); KMP(c, s[1], nxt[1]);
memset(f, 0, sizeof(f));
for (int T = 0; T < 2; T++)
for (int i = 0; i <= c; i++)
for (int w = 0; w < 3; w++)
{
int j = i;
while (j && s[T][j + 1] != ch[w]) j = nxt[T][j];
tr[T][i][w] = s[T][j + 1] == ch[w] ? j + 1 : j;
}
printf("%d
", (sum - DP() + rqy) % rqy);
}
return 0;
}