好久没做网络流了……
定义四个矩阵分别为a[i][j], b[i][j], c[i][j], d[i][j]
定义一个点为(x),源点向(x)连容量为a[i][j]
边,(x)向汇点连容量为b[i][j]
的边
因为相邻的选择同一科目会有加成,所以对于c[i][j]
,我们可以新建一个节点(y),源点向(y)连容量为c[i][j]
的边,(y)向它周围的点连容量为(inf)的边。对于d[i][j]
也是同理,(y)向汇点连容量为c[i][j]
的边,周围的点向(y)连容量为(inf)的边
将四个矩阵的价值全加起来,减去最小割即为答案
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
using namespace std;
#define LL long long
#define INF 2147483647
int read() {
int k = 0, f = 1; char c = getchar();
while(c < '0' || c > '9') {
if(c == '-') f = -1;
c = getchar();
}
while(c >= '0' && c <= '9') {
k = k * 10 + c - 48, c = getchar();
}
return k * f;
}
struct zzz {
int t, len, nex;
}e[500010 << 1]; int head[100010], tot = 1;
void add(int x, int y, int z) {
e[++tot].t = y;
e[tot].len = z;
e[tot].nex = head[x];
head[x] = tot;
e[++tot].t = x;
e[tot].len = 0;
e[tot].nex = head[y];
head[y] = tot;
}
int s, t, n, m, vis[100010];
bool bfs() {
queue <int> q; q.push(s);
memset(vis, 0, sizeof(vis)); vis[s] = 1;
while(!q.empty()) {
int k = q.front(); q.pop();
for(int i = head[k]; i; i = e[i].nex) {
if(!vis[e[i].t] && e[i].len) {
vis[e[i].t] = vis[k] + 1; q.push(e[i].t);
if(e[i].t == t) return 1;
}
}
}
return vis[t];
}
int dfs(int flow, int pos) {
if(!flow || pos == t) return flow;
int fl, rest = 0;
for(int i = head[pos]; i; i = e[i].nex) {
if(vis[e[i].t] == vis[pos] + 1 && (fl = dfs(min(e[i].len, flow-rest), e[i].t))) {
e[i].len -= fl, e[i^1].len += fl, rest += fl;
if(rest == flow) return rest;
}
}
if(rest < flow) vis[pos] = 0;
return rest;
}
int dinic() {
int ans = 0;
while(bfs()) ans += dfs(INF, s);
return ans;
}
inline int calc(int x, int y) {
return (x - 1) * m + y + 2;
}
int fx[5] = {0, 1, 0, -1, 0},
fy[5] = {0, 0, 1, 0, -1};
int sum;
int main() {
n = read(), m = read(), s = 1, t = 2;
for(int i = 1; i <= n; ++i)
for(int j = 1; j <= m; ++j) {
int x = read(); sum += x;
add(s, calc(i, j), x);
}
for(int i = 1; i <= n; ++i)
for(int j = 1; j <= m; ++j) {
int x = read(); sum += x;
add(calc(i, j), t, x);
}
int pos = calc(n, m);
for(int i = 1; i <= n; ++i)
for(int j = 1; j <= m; ++j) {
int val = read();
sum += val;
add(s, ++pos, val);
for(int k = 0; k <= 4; ++k) {
if(i+fx[k] < 1 || i+fx[k] > n || j+fy[k] < 1 || j+fy[k] > m) continue;
int y = calc(i+fx[k], j+fy[k]);
add(pos, y, INF);
}
}
for(int i = 1; i <= n; ++i)
for(int j = 1; j <= m; ++j) {
int val = read();
sum += val;
add(++pos, t, val);
for(int k = 0; k <= 4; ++k) {
if(i+fx[k] < 1 || i+fx[k] > n || j+fy[k] < 1 || j+fy[k] > m) continue;
int y = calc(i+fx[k], j+fy[k]);
add(y, pos, INF);
}
}
printf("%d", sum - dinic());
return 0;
}