\(\text{Solution}\)
这还是 [Lydsy2017省队十连测] 的题
不得不说 \(FFT\) 在字符串匹配中的妙啊!
前面做了道一维的题,现在这是二维的
从题目入手,不考虑可不可达
如果舰队从天而降,考虑其可以落到以那些点为左上角的点
先将地图压成一维,一行接着一行,礁石处为 \(1\) 其余为 \(0\)
抽出包含舰队的最小矩形,按照原来地图行长,从矩阵开头到结尾压成一维,舰队处为 \(1\),其余为 \(0\)
构造匹配函数 \(F_i\) 表示大矩形以 \(i\) 为左上角能否匹配,\(F_i =\sum_j A_{i+j}\cdot B_j\)
则 \(F_i = 0\) 是可行,将 \(A\) 翻转,\(F_i=\sum_j A_{n\cdot m-1-i-j}\cdot B_j\)
这是一个多项式卷积形式,\(NTT\) 即可
于是我们得到了若干个可放入舰队的左上角
我们还要知道那些左上角可达,从小矩阵左上角开始 \(BFS\) 即可
然而我们仍需要那些空地可以被舰队掠过
考虑将可达左上角即为 \(1\),其余为 \(0\),\(B\) 中舰队为 \(1\)
那么 \(i\) 位置可达,需要存在一个可达左上角 \(i-j\)
考虑构造函数 \(F_i = \sum_j A_{i-j} \cdot B_{j}\)
当 \(F_i > 0\) 则该位置可达
发现仍是多项式卷积形式,继续 \(NTT\) 即可
\(\text{Code}\)
#include <cstdio>
#include <iostream>
#define RE register
#define IN inline
using namespace std;
typedef long long LL;
const int N = 705, P = 998244353, g = 3;
char str[N][N];
int n, m, vis[N][N], x0 = N, y0 = N, x1 = 0, y1 = 0, rev[N * N * 2];
int fx[4][2] = {{0, -1}, {-1, 0}, {1, 0}, {0, 1}};
LL a[N * N * 2], b[N * N * 2];
struct Point{int x, y;}Q[N * N];
IN int Get(int x, int y){return (x - 1) * m + y - 1;}
IN int get(int x, int y){return (x - x0) * m + y - y0;}
IN int fpow(LL x, int y){LL s = 1; for(; y; y >>= 1, x = x * x % P) if (y & 1) s = s * x % P; return s;}
IN void NTT(LL *a, int n, int inv)
{
if (n == 1) return;
for(RE int i = 0; i < n; i++) if (i < rev[i]) swap(a[i], a[rev[i]]);
for(RE int mid = 1; mid < n; mid <<= 1)
{
int I = fpow(g, (P - 1) / (mid << 1));
if (inv == -1) I = fpow(I, P - 2);
for(RE int i = 0; i < n; i += (mid << 1))
{
LL W = 1;
for(RE int j = 0, x, y; j < mid; j++, W = W * I % P)
x = a[i + j], y = W * a[i + j + mid] % P,
a[i + j] = (x + y) % P, a[i + j + mid] = (x - y + P) % P;
}
}
}
void BFS()
{
int head = 0, tail = 1; Q[1] = Point{x0, y0}, vis[x0][y0] = 0;
while (head < tail)
{
Point z = Q[++head]; a[Get(z.x, z.y)] = 1;
for(RE int k = 0; k < 4; k++)
{
int x = z.x + fx[k][0], y = z.y + fx[k][1];
if (x > 0 && x <= n && y > 0 && y <= m && vis[x][y]) vis[x][y] = 0, Q[++tail] = Point{x, y};
}
}
}
int main()
{
freopen("sailing.in", "r", stdin), freopen("sailing.out", "w", stdout);
scanf("%d%d", &n, &m);
for(RE int i = 1; i <= n; i++) scanf("%s", str[i] + 1);
for(RE int i = 1; i <= n; i++)
for(RE int j = 1; j <= m; j++)
if (str[i][j] == 'o') x0 = min(x0, i), y0 = min(y0, j), x1 = max(x1, i), y1 = max(y1, j);
else if (str[i][j] == '#') a[n * m - 1 - Get(i, j)] = 1;
for(RE int i = 1; i <= n; i++)
for(RE int j = 1; j <= m; j++) if (str[i][j] == 'o') b[get(i, j)] = 1;
int lim = 1; while (lim < n * m) lim <<= 1; int inv = fpow(lim, P - 2);
int bit = 0; while ((1 << bit) < lim) bit++;
for(RE int i = 0; i < lim; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
NTT(a, lim, 1), NTT(b, lim, 1);
for(RE int i = 0; i < lim; i++) a[i] = a[i] * b[i] % P;
NTT(a, lim, -1); for(RE int i = 0; i < lim; i++) a[i] = a[i] * inv % P;
for(RE int i = 1; i <= n + x0 - x1; i++)
for(RE int j = 1; j <= m + y0 - y1; j++)
if (!a[n * m - 1 - Get(i, j)]) vis[i][j] = 1;
for(RE int i = 0; i < lim; i++) a[i] = 0;
BFS(), NTT(a, lim, 1);
for(RE int i = 0; i < lim; i++) a[i] = a[i] * b[i] % P;
NTT(a, lim, -1); for(RE int i = 0; i < lim; i++) a[i] = a[i] * inv % P;
int ans = 0;
for(RE int i = 0; i < n * m; i++) if (a[i] > 0) ans++;
printf("%d\n", ans);
}