题目大意
给定两个仅由小写字母组成的字符串 (x) 和 (y)。
如果一个序列仅包含 (|x|) 个 (0) 和 (|y|) 个 (1),则称这个序列为合并序列。
字符串 (z) 初始为空,按如下规则由合并序列 (a) 生成:
- 如果 (a_i=0),则把 (x) 开头的一个字符加到 (z) 的末尾;
- 如果 (a_i=1),则把 (y) 开头的一个字符加到 (z) 的末尾。
两个合并序列 (a) 和 (b) 被认为是不同的,如果存在某个 (i),使得 (a_i eq b_i)。
若一个字符串任意两个相邻位置上的字符都不同,则我们称该字符串是混乱的。
定义 (f(l_1,r_1,l_2,r_2)) 表示能从 (x) 的子串 (x[l_1,r_1]) 和 (y) 的子串 (y[l_2,r_2]) 生成混乱的字符串的不同的合并序列的数量,要求子串非空。
求 (sum limits_{1 le l_1 le r_1 le |x| , 1 le l_2 le r_2 le |y|} f(l_1, r_1, l_2, r_2)),答案对 (998244353) 取模。
((1leq |x|,|y|leq 1000))
题解
直接枚举子串然后算有多少个不同的合并序列肯定不好做,所以我们考虑dp。
设 (dp[i][j][0]) 表示以第 (i) 个位置结尾的 (x) 的子串和以第 (j) 个位置结尾的 (y) 的子串进行合并,并且合并后的混乱的字符串以 (x) 的第 (i) 个位置结尾,不同的合并序列的数量。
设 (dp[i][j][1]) 表示以第 (i) 个位置结尾的 (x) 的子串和以第 (j) 个位置结尾的 (y) 的子串进行合并,并且合并后的混乱的字符串以 (y) 的第 (j) 个位置结尾,不同的合并序列的数量。
那么我们只需枚举倒数第二个位置上是什么,同时要满足它和最后一个位置上的字符不同。
若 (x[i-1]
eq x[i]),则 (dp[i][j][0]+=dp[i-1][j][0])。
若 (y[j]
eq x[i]),则 (dp[i][j][0]+=dp[i-1][j][1])。
若 (x[i]
eq y[j]),则 (dp[i][j][1]+=dp[i][j-1][0])。
若 (y[j-1]
eq y[j]),则 (dp[i][j][1]+=dp[i][j-1][1])。
注意到以上转移必须满足上一个状态中 (x) 和 (y) 的两个子串都非空。但我们可以只取一个字符作为一个子串合并到最后,所以上一个状态的 (x) 或 (y) 是可以为空的,但是空串我们又不计入答案。所以我们维护 (dpx[i]) 表示 (x) 中有多少个以第 (i) 个位置结尾的混乱的子串,(dpy[j]) 表示 (y) 中有多少个以第 (j) 个位置结尾的混乱的子串,则有:
若 (x[i]
eq x[i-1]),则 (dpx[i]=dpx[i-1]+1),否则 (dpx[i]=1)。
若 (y[j]
eq y[j-1]),则 (dpy[j]=dpy[j-1]+1),否则 (dpy[j]=1)。
若 (x[i]
eq y[j]), (dp[i][j][0]+=dpy[j])。
若 (x[i]
eq y[j]), (dp[i][j][1]+=dpx[i])。
时间复杂度 (O(|x||y|))。
Code
#include <bits/stdc++.h>
using namespace std;
#define RG register int
#define LL long long
const LL MOD = 998244353;
char x[1005], y[1005];
LL dp[1001][1001][2], dpx[1001], dpy[1001];
int n, m;
int main() {
scanf("%s", x + 1);
scanf("%s", y + 1);
n = strlen(x + 1);
m = strlen(y + 1);
LL ans = 0;
for (int i = 1;i <= m;++i) {
dpy[i] = 1;
if (y[i - 1] != y[i]) dpy[i] = (dpy[i] + dpy[i - 1]) % MOD;
}
for (int i = 1;i <= n;++i) {
dpx[i] = 1;
if (x[i - 1] != x[i]) dpx[i] = (dpx[i] + dpx[i - 1]) % MOD;
for (int j = 1;j <= m;++j) {
if (x[i] != y[j]) dp[i][j][0] = (dp[i][j][0] + dpy[j]) % MOD;
if (x[i] != y[j]) dp[i][j][1] = (dp[i][j][1] + dpx[i]) % MOD;
if (x[i - 1] != x[i]) dp[i][j][0] = (dp[i][j][0] + dp[i - 1][j][0]) % MOD;
if (y[j] != x[i]) dp[i][j][0] = (dp[i][j][0] + dp[i - 1][j][1]) % MOD;
if (x[i] != y[j]) dp[i][j][1] = (dp[i][j][1] + dp[i][j - 1][0]) % MOD;
if (y[j - 1] != y[j]) dp[i][j][1] = (dp[i][j][1] + dp[i][j - 1][1]) % MOD;
ans = (ans + dp[i][j][0] + dp[i][j][1]) % MOD;
}
}
printf("%I64d
", ans);
return 0;
}