CF1093F Vasya and Array 更优秀的做法
摘要
本题有一个经典的 DP + 容斥做法,时间复杂度是 (O(nk))。
本文作者在此基础上,创新性地发掘题目性质,简化 DP 状态,利用数据结构优化 DP,提出了一个时间复杂度 (O(n)) 的做法。显然相比经典做法,是更加优秀的。
本文将分别介绍这两种做法。
题目大意
给出一段长度为 (n) 的整数序列,一个正整数 (k) ,一个正整数 ( ext{len})。序列中的所有数要么在 ([1,k]) 之间,要么等于 (-1)。
我们称一个序列是好的,当且仅当不存在 ( ext{len}) 个连续的相同的数字。
你可以将每个 (-1),替换成任意一个 ([1,k]) 之间的整数。求有多少种方案,使得最终序列是好的。答案对 (998244353) 取模。
数据范围:(1leq nleq 10^5),(1leq kleq 100),(1leq ext{len}leq n)。
经典做法
设 ( ext{dp}_0(i,j)) 表示考虑了前 (i) 个位置,最后一个位置上填的数是 (j),前 (i) 个位置组成的序列合法的方案数。设 ( ext{sdp}(i)=sum_{j=1}^{k} ext{dp}_0(i,j))。
首先,当 (a_i eq -1) 且 (a_i eq j) 时,( ext{dp}_0(i,j)=0)。
否则,我们暂时令 ( ext{dp}_0(i,j)= ext{sdp}(i - 1))。但是这会把一些不合法的方案算进去。具体来说,这样有可能出现 ([a_{i- ext{len}+1}dots a_i]) 全部相同的情况。这种情况会出现当且仅当如下两个条件都满足:
- (igeq ext{len})。
- ([a_{i- ext{len}+1}dots a_i]) 中每个数都等于 (-1) 或 (j)。
所以,我们还要减去这种情况的数量:( ext{sdp}(i- ext{len})- ext{dp}_0(i- ext{len},j))。其中 ( ext{sdp}(i - ext{len})) 表示 ([a_{i- ext{len}+1}dots a_i]) 全部等于 (j) 时,前 (i - ext{len}) 位的填写方案。不过这些方案中,有一些方案可能在前 (i-1) 位就已经导致不合法了。这些提前不合法的方案本来就没有被算在 ( ext{sdp}(i-1)) 中,所以不需要被减去,它们的数量是 ( ext{dp}_0(i - ext{len},j))。
综上所述,可以得到转移式:
时间复杂度(O(nk))。
更优秀的做法
首先,当 ( ext{len} = 1) 时,答案一定是 (0)。以下只讨论 ( ext{len} > 1) 的情况。
先考虑一种朴素的 DP。设 ( ext{dp}_1(i,j,l)) 表示考虑了前 (i) 个位置,第 (i) 位上填的数是 (j),最后一个 ( eq j) 的位置是 (l),此时使得前 (i) 位组成的序列合法的方案数。
转移时,考虑当前位填了什么:
初始状态为 ( ext{dp}_1(0,0,0) = 1)。答案是 (sum_{j = 1}^{k}sum_{l = n - ext{len}+1}^{n - 1} ext{dp}_1(n,j,l))。
这个朴素 DP 的时间复杂度是 (O(n^2k))。
优化它!设上一位填的数为 (j),当前位填的数为 (x)。我们发现,当 (x) 和 (j) 不同时,(x) 的值具体是什么其实不重要:对所有 (x eq j),它们的转移是一模一样的。这就给了我们简化的空间。
定义 (a_{n+1} = k+1)。定义 ( ext{nxt}_i) 表示位置 (i) 后面第一个 (a_{i'} eq -1) 的位置 (i')。设 ( ext{dp}_2(i,jin{0,1},l)) 表示考虑了前 (i) 个位置,第 (i) 位上填的数是 / 否等于 (a_{ ext{nxt}_i}),前 (i) 位里最后一个填的数与第 (i) 位上不同的位置是 (l),此时使得前 (i) 位组成的序列合法的方案数。
转移分 (a_i) 是否为 (-1) 两种情况。
当 (a_i eq -1) 时,枚举 (l)。则有如下转移:
当 (a_i = -1) 时,显然 ( ext{nxt}_{i} = ext{nxt}_{i-1})。枚举 (l)。我们分别考虑如下情况:
- 第 (i-1) 位填的数与 (a_{ ext{nxt}_{i}}) 不同:
- 第 (i) 位上填的数与第 (i-1) 位上填的数相同。
- 第 (i) 位上填的数与第 (a_{ ext{nxt}_i}) 相同。
- 第 (i) 位上填的数,既不等于第 (i-1) 位上填的数,也不等于 (a_{ ext{nxt}_{i}})。
- 第 (i-1) 位填的数与 (a_{ ext{nxt}_{i}}) 相同:
- 第 (i) 位填的数与 (a_{ ext{nxt}_{i}}) 不同。
- 第 (i) 位填的数与 (a_{ ext{nxt}_{i}}) 相同。
这五种情况分别对应如下转移:
上述式子里默认 (1 < i < ext{nxt}_i leq n)。当 (i = 1) 或 ( ext{nxt}_i = n+1) 时,有一些特殊情况要考虑。为了表述简洁,这里就不细写了。
现在,这个 DP 的时间复杂度是 (O(n^2))。虽然无法 AC,但这是迈向 (O(n)) 做法的关键一步。我将这一 DP 的代码附在了本文末尾:点击跳转。
在这个 DP 里,第 (1) 维的枚举不可避免,第 (2) 维的状态数已经被我们优化到 (O(1))。考虑优化第 (3) 维。
观察转移式,发现从 (i-1) 变成 (i) 时,第 (3) 维的转移,相当于做如下操作:
- 区间求和。对所有 (lin[i - ext{len},i - 2]) 求和。
- 单点加。
- 把一段区间的值覆盖为 (0)。
- 在 (a_i eq -1) 且 (a_i eq a_{ ext{nxt}_i}) 时,需要把 ( ext{dp}_2(i,1)) 中的一段 (l),复制到 ( ext{dp}_2(i,0)) 对应的位置上。
按第二维的 (0,1),维护两棵线段树,通过打懒标记即可实现这四种操作。
时间复杂度 (O(nlog n))。
继续优化,发现区间操作都是假的。
- 操作 (3) 要么是单点清空,要么是全局清空。
- 操作 (1) 和操作 (4),因为其他位置都清空了,所以区间求和、区间复制,其实就是全局求和、全局交换。
所以我们只需要用两个数组来维护。
时间复杂度 (O(n))。
参考代码
最终代码
友情提醒:使用读入、输出优化可以使代码更快,详见本博客公告。
// problem: CF1093F
#include <bits/stdc++.h>
using namespace std;
#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fi first
#define se second
#define SZ(x) ((int)(x).size())
typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
template<typename T> inline void ckmax(T& x, T y) { x = (y > x ? y : x); }
template<typename T> inline void ckmin(T& x, T y) { x = (y < x ? y : x); }
const int MAXN = 1e5;
const int MOD = 998244353;
inline int mod1(int x) { return x < MOD ? x : x - MOD; }
inline int mod2(int x) { return x < 0 ? x + MOD : x; }
inline void add(int &x, int y) { x = mod1(x + y); }
inline void sub(int &x, int y) { x = mod2(x - y); }
int n, K, len, a[MAXN + 5];
struct FantasticDataStructure {
int sum;
int arr[MAXN + 5];
int TIM;
int tim[MAXN + 5];
void upd(int p) {
if (tim[p] < TIM) {
tim[p] = TIM;
arr[p] = 0;
}
}
void point_add(int p, int v) {
upd(p);
add(arr[p], v);
add(sum, v);
}
void point_set0(int p) {
upd(p);
sub(sum, arr[p]);
arr[p] = 0;
}
void global_set0() {
TIM++;
sum = 0;
}
int query() {
return sum;
}
FantasticDataStructure() {}
};
FantasticDataStructure S[2];
int id[2];
// DP 转移: a[i] != -1 / a[i] == -1
void trans1(int p, int nxtval) {
int v = S[id[0]].query();
if (a[p] == nxtval) {
S[id[0]].global_set0();
if (p - len >= 0) {
S[id[1]].point_set0(p - len);
}
S[id[1]].point_add(p - 1, v);
} else {
swap(id[0], id[1]);
S[id[1]].global_set0();
if (p - len >= 0) {
S[id[0]].point_set0(p - len);
}
S[id[0]].point_add(p - 1, v);
}
}
void trans2(int p, int flag) {
int v0 = S[id[0]].query();
int v1 = S[id[1]].query();
if (p == 1) {
S[id[0]].global_set0();
} else {
if (p - len >= 0) {
S[id[0]].point_set0(p - len);
}
}
int toadd = 0;
if (flag + (p != 1) + 1 <= K) {
toadd = (ll)v0 * (K - flag - (p != 1)) % MOD;
}
add(toadd, (ll)v1 * (K - 1) % MOD);
S[id[0]].point_add(p - 1, toadd);
if (p - len >= 0) {
S[id[1]].point_set0(p - len);
}
if (flag) {
S[id[1]].point_add(p - 1, v0);
}
}
int main() {
cin >> n >> K >> len;
for (int i = 1; i <= n; ++i) {
cin >> a[i];
}
if (len == 1) {
cout << 0 << endl;
return 0;
}
id[0] = 0;
id[1] = 1;
S[id[1]].point_add(0, 1);
a[0] = K + 1;
a[n + 1] = K + 2;
for (int i = 1, j = 2; i <= n; ++i) {
ckmax(j, i + 1);
while (a[j] == -1)
++j;
if (a[i] != -1) {
trans1(i, a[j]);
} else {
trans2(i, j != n + 1);
}
}
cout << S[id[0]].query() << endl;
return 0;
}
n^2 DP
为了帮助读者更好地理解题解,这里附上 (O(n^2)) 朴素 DP 的代码。
#include <bits/stdc++.h>
using namespace std;
#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fi first
#define se second
#define SZ(x) ((int)(x).size())
typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
template<typename T> inline void ckmax(T& x, T y) { x = (y > x ? y : x); }
template<typename T> inline void ckmin(T& x, T y) { x = (y < x ? y : x); }
const int MAXN = 1000;
const int MOD = 998244353;
inline int mod1(int x) { return x < MOD ? x : x - MOD; }
inline int mod2(int x) { return x < 0 ? x + MOD : x; }
inline void add(int &x, int y) { x = mod1(x + y); }
inline void sub(int &x, int y) { x = mod2(x - y); }
int n, K, len, a[MAXN + 5];
int dp[MAXN + 5][2][MAXN + 5];
void trans1(int p, int nxtval) {
for (int i = max(0, p - len); i <= max(0, p - 2); ++i) {
add(dp[p][a[p] == nxtval][p - 1], dp[p - 1][0][i]);
if (i >= p - len + 1)
add(dp[p][a[p] == nxtval][i], dp[p - 1][1][i]);
}
}
void trans2(int p, int flag) {
for (int i = max(0, p - len); i <= max(0, p - 2); ++i) {
// a[p - 1] 和 nxtval 不同
if (p != 1 && i >= p - len + 1) {
add(dp[p][0][i], dp[p - 1][0][i]); // a[p] 和 a[p - 1] 相同
}
if (flag) {
add(dp[p][1][p - 1], dp[p - 1][0][i]); // a[p] 和 nxtval 相同
}
if (flag + (p != 1) + 1 <= K) {
// a[p] 和 nxtval, a[p - 1] 都不同
add(dp[p][0][p - 1], (ll)dp[p - 1][0][i] * (K - flag - (p != 1)) % MOD);
}
// a[p - 1] 和 nxtval 相同
add(dp[p][0][p - 1], (ll)dp[p - 1][1][i] * (K - 1) % MOD);
if (i >= p - len + 1) {
add(dp[p][1][i], dp[p - 1][1][i]); // a[p] 和 a[p - 1], nxtval 都相同
}
}
}
int main() {
cin >> n >> K >> len;
for (int i = 1; i <= n; ++i) {
cin >> a[i];
}
dp[0][0][0] = 1;
a[0] = K + 1;
a[n + 1] = K + 2;
for (int i = 1, j = 2; i <= n; ++i) {
ckmax(j, i + 1);
while (a[j] == -1)
++j;
if (a[i] != -1) {
trans1(i, a[j]);
} else {
trans2(i, j != n + 1);
}
}
int ans = 0;
for (int i = max(0, n - len + 1); i <= n - 1; ++i) {
add(ans, dp[n][0][i]);
}
cout << ans << endl;
return 0;
}