题目大意:给定一个 N*M 的矩阵,现从 N 行中选出 R 行,M 列中选出 C 列,构成一个 R*C 子矩阵,求这个子矩阵相邻元素差的绝对值之和的最小值是多少。
题解:
发现是对行和列的组合生成,若直接暴力的话,时间复杂度为 (O({n choose r}{m choose c}nm))。
代码如下
#include <bits/stdc++.h>
using namespace std;
const int maxn=20;
int n,m,r,c,ans,mp[maxn][maxn];
vector<int> row,col;
inline void calc(){
int ret=0;
for(int i=0;i<r;i++)for(int j=0;j<c-1;j++)ret+=abs(mp[row[i]][col[j]]-mp[row[i]][col[j+1]]);
for(int i=0;i<r-1;i++)for(int j=0;j<c;j++)ret+=abs(mp[row[i]][col[j]]-mp[row[i+1]][col[j]]);
ans=min(ans,ret);
}
void dfsc(int now){
if(col.size()>c||col.size()+m-now+1<c)return;
if(now==m+1){
calc();
return;
}
col.push_back(now);
dfsc(now+1);
col.pop_back();
dfsc(now+1);
}
void dfsr(int now){
if(row.size()>r||row.size()+n-now+1<r)return;
if(now==n+1){
dfsc(1);
return;
}
row.push_back(now);
dfsr(now+1);
row.pop_back();
dfsr(now+1);
}
void read_and_parse(){
scanf("%d%d%d%d",&n,&m,&r,&c);
for(int i=1;i<=n;i++)for(int j=1;j<=m;j++)scanf("%d",&mp[i][j]);
}
void solve(){
ans=1<<30;
dfsr(1);
printf("%d
",ans);
}
int main(){
read_and_parse();
solve();
return 0;
}
进一步考虑,发现若枚举出了 r 行,那么对于每一列来说,可以抽象成下列问题,即:给定一个长度为 N 的序列,现从序列中选出 M 个元素组成的子序列,使得这 M 个元素中相邻两个元素差的绝对值之和最小。发现是一个 dp,对于矩阵来说,将矩阵转化成序列即可,dp 的时间复杂度为 (O(n^3))。总的时间复杂度为 (O({n choose r}m^3))。
代码如下
#include <bits/stdc++.h>
#define cls(a,b) memset(a,b,sizeof(a))
using namespace std;
const int maxn=20;
int n,m,r,c,ans,mp[maxn][maxn];
vector<int> row;
int dp[maxn][maxn],extra[maxn],cost[maxn][maxn];
inline void calc(){
cls(dp,0x3f),cls(extra,0),cls(cost,0);
for(int i=1;i<=m;i++)for(int j=i+1;j<=m;j++)for(auto ro:row)cost[i][j]+=abs(mp[ro][i]-mp[ro][j]);
for(int co=1;co<=m;co++)for(int i=0;i<row.size()-1;i++)extra[co]+=abs(mp[row[i]][co]-mp[row[i+1]][co]);
for(int i=0;i<=m;i++)dp[i][0]=0;
for(int i=1;i<=m;i++)
for(int j=1;j<=i;j++)
for(int k=j-1;k<i;k++)
dp[i][j]=min(dp[i][j],dp[k][j-1]+cost[k][i]+extra[i]);
for(int i=1;i<=m;i++)ans=min(ans,dp[i][c]);
}
void dfs(int now){
if(row.size()>r||row.size()+n-now+1<r)return;
if(now==n+1){
calc();
return;
}
row.push_back(now);
dfs(now+1);
row.pop_back();
dfs(now+1);
}
void read_and_parse(){
scanf("%d%d%d%d",&n,&m,&r,&c);
for(int i=1;i<=n;i++)for(int j=1;j<=m;j++)scanf("%d",&mp[i][j]);
}
void solve(){
ans=1<<30;
dfs(1);
printf("%d
",ans);
}
int main(){
read_and_parse();
solve();
return 0;
}