题解
题目可以看作有一个长度为 (n) 的 (01) 串,有 (k) 位为 (1),其余为 (0)。每次操作相当于区间 (oplus 1),求变成全 (0) 串的最小操作次数。
操作为区间修改,那么考虑对原 (01) 串进行差分。
例如样例,原串是 (color{red}010001color{red}0),这里设第 (0) 位和第 (n+1) 位都为 (0),用红色数字表示。
则差分之后是 (color{red}011001color{red}1),即设原串第 (i) 位为 (a_i),差分后新串第 (i) 位为 (d_i),则 (a_i=d_0oplus d_1oplus d_2oplus cdots oplus d_i)。
这个时候我们考虑,对原串 ([l,r]) 进行区间 (oplus1),其实就相当于令新串的 (d_lgets d_loplus1,d_{r+1}gets d_{r+1}oplus 1)。
变化有 (4) 种情况:
- 若 (d_l=d_{r+1}=0),则相当于把这两位变成 (1),增加 (2) 个 (1)。
- 若 (d_l=d_{r+1}=1),则相当于把这两位变成 (0),减少 (2) 个 (0)。
- 若 (d_l=1,d_{r+1}=0),则相当于把 (l) 上的 (1) 移到 (r+1),(1) 数量不变。
- 若 (d_l=0,d_{r+1}=1),则相当于把 (r+1) 上的 (1) 移到 (l),(1) 数量不变。
那我们可以对任意 (i) 和 (i+b_j) 连一条边权为 (1) 的无向边,则新串中 (d_lgets d_loplus 1,d_{r+1}gets d_{r+1}gets 1) 的最小操作次数为 (l) 到 (r+1) 的最短路,记作 ( extit{dis}(l,r+1))。则对原串 ([l,r]) 进行区间 (oplus1) 的最小操作次数为 ( extit{dis}(l,r+1))。因为边权为 (1),所以求最短路跑 (n) 次 BFS 即可。
容易得到,新串变为全 (0) 串是原串变为全 (0) 串的充要条件。
考虑 (k) 很小,实际上新串中初始最多只有 (2k) 个 (1),而其他的 (0) 我们根本就不用管,也就是说只用计这 (2k) 个 (1)。则新串只有 (2^{2k}) 种状态,直接上状压 dp。
设 (f(S)) 为新串状态为 (S) 时的最小操作次数,预处理出新串的初始状态 ( extit{st}),则 (f( extit{st})=0)。
显然每次转移,进行上面的第 (1,3,4) 种变化是没有意义的,所以只需要做第 (2) 种变化(即转移中,只转移第 (2) 种变化,实际上已经包括了其他 (3) 种)。
设 (T=S-2^{i-1}-2^{j-1}),则 (f(T)=min {f(S)+ extit{dis}(i,j)})。答案即为 (f(0))。
同样的,我们在 BFS 的时候,只用从初始值为 (1) 的位开始搜,那么最多跑 (2k) 遍 BFS。
时间复杂度为 (O(2knm+2^{2k}(2k)^2))。
实际上每次转移的时候,(i) 可以只取 (S) 中 (1) 的最低位(无论如何都要转移这个最低位,那么可以先转移)。
时间复杂度为 (O(2knm+2^{2k} (2k)))。
#include <cstdio>
#include <cstring>
#include <algorithm>
#define MAX_N (40000 + 5)
#define MAX_M (64 + 5)
#define MAX_K (8 + 5)
#define MAX_S ((1 << 16) + 5)
using std::min;
int n, m, k;
int a[MAX_K * 2], len;
int b[MAX_M];
int dis[MAX_N];
int g[MAX_K * 2][MAX_K * 2];
int q[MAX_N], l, r;
bool vis[MAX_N];
int f[MAX_S];
void BFS() {
int u, v;
for (int I = 1; I <= len; ++I) {
memset(dis, 0x3f, sizeof dis);
memset(vis, 0, sizeof vis);
l = r = 1;
dis[a[I]] = 0;
vis[a[I]] = 1;
q[1] = a[I];
while (l <= r) {
u = q[l++];
for (int i = 1; i <= m; ++i) {
v = u - b[i];
if (v >= 1 && !vis[v]) {
dis[v] = dis[u] + 1;
vis[v] = 1;
q[++r] = v;
}
v = u + b[i];
if (v <= n + 1 && !vis[v]) {
dis[v] = dis[u] + 1;
vis[v] = 1;
q[++r] = v;
}
}
}
for (int i = 1; i <= len; ++i)
g[I][i] = dis[a[i]];
}
}
int main() {
scanf("%d%d%d", &n, &k, &m);
int pos;
for (int i = 1; i <= k; ++i) {
scanf("%d", &pos);
vis[pos] ^= 1;
vis[pos + 1] ^= 1;
}
for (int i = 1; i <= n + 1; ++i)
if (vis[i]) a[++len] = i;
for (int i = 1; i <= m; ++i)
scanf("%d", &b[i]);
BFS();
const int lim = (1 << len) - 1;
memset(f, 0x3f, sizeof f);
f[lim] = 0;
int tmp, T;
for (int S = lim; S; --S) {
tmp = S;
pos = 1;
while ((tmp & 1) == 0) {
++pos;
tmp >>= 1;
}
for (int i = pos + 1; i <= len; ++i) {
if ((S & 1 << i - 1) == 0) continue;
T = S ^ 1 << pos - 1 ^ 1 << i - 1;
f[T] = min(f[T], f[S] + g[pos][i]);
}
}
printf("%d", f[0]);
return 0;
}