description
给你两个矩阵(A_{i,j})和(B_{i,j}),你需要求一个排列(p_i),最小化$$sum_{i=1}^nA_{i,p_i} imes sum_{i=1}^nB_{i,p_i}$$
(nle70)
sol
最小乘积(KM)。
运用数形结合的思想,令(X=sum_{i=1}^nA_{i,p_i},Y=sum_{i=1}^nB_{i,p_i}),那么每一个排列就对应了笛卡尔坐标系中的一个点。
因为我们要最小化(XY),所以我们希望所求排列对应的点尽可能接近原点。
先分别以(A_{i,j})和(B_{i,j})作为权值做二分图最小权匹配求出横坐标/纵坐标最小的点,设为(A,B)。
我们希望能够找到一个点(C)使答案更优。
这个(C)应该尽量靠近坐标原点,或者说,应该尽量离(vec{AB})尽量远。
等同于是要求一个点(C),最小化(vec{AB} imes vec{AC})。
那么我们就可以以((B_x-A_x) imes B_{i,j}+(A_y-B_y) imes A_{i,j})作为新的权值做最小权匹配,这样就可以求出(C)了。
求出(C)以后,分别递归处理(A,C)与(C,B)。递归的边界就是(vec{AB} cdotvec{AC}=0)
因为是最小权匹配所以跑(KM)的时候边权要取反。
code
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
int gi(){
int x=0,w=1;char ch=getchar();
while ((ch<'0'||ch>'9')&&ch!='-') ch=getchar();
if (ch=='-') w=0,ch=getchar();
while (ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
return w?x:-x;
}
const int N = 100;
const int inf = 1e9;
struct point{
int x,y;
point(){x=y=0;}
point(int _x,int _y){x=_x,y=_y;}
point operator - (point b)
{return point(x-b.x,y-b.y);}
bool operator < (point b)
{return x*y<b.x*b.y;}
int operator * (point b)
{return x*b.y-y*b.x;}
}ans;
int n,a[N][N],b[N][N],val[N][N],mat[N],lv[N],rv[N],slack[N],vis[N],pre[N];
void aug(int s){
for (int i=0;i<=n;++i) slack[i]=inf,vis[i]=pre[i]=0;
int u=0;mat[u]=s;
do{
int now=mat[u],d=inf,nxt;vis[u]=1;
for (int v=1;v<=n;++v)
if (!vis[v]){
if (lv[now]+rv[v]-val[now][v]<slack[v])
slack[v]=lv[now]+rv[v]-val[now][v],pre[v]=u;
if (d>slack[v]) d=slack[v],nxt=v;
}
for (int i=0;i<=n;++i)
if (vis[i]) lv[mat[i]]-=d,rv[i]+=d;
else slack[i]-=d;
u=nxt;
}while (mat[u]);
while (u) mat[u]=mat[pre[u]],u=pre[u];
}
point KM(){
memset(mat,0,sizeof(mat));
for (int i=1;i<=n;++i){
lv[i]=-inf;rv[i]=0;
for (int j=1;j<=n;++j)
lv[i]=max(lv[i],val[i][j]);
}
for (int i=1;i<=n;++i) aug(i);
point res=point(0,0);
for (int i=1;i<=n;++i)
res.x+=a[mat[i]][i],res.y+=b[mat[i]][i];
if (res<ans) ans=res;return res;
}
void solve(point A,point B){
for (int i=1;i<=n;++i)
for (int j=1;j<=n;++j)
val[i][j]=-(B.x-A.x)*b[i][j]+(B.y-A.y)*a[i][j];
point C=KM();
if (A.x>B.x) exit(1);
if ((B-A)*(C-A)>=0) return;
solve(A,C);solve(C,B);
}
int main(){
int Case=gi();while (Case--){
n=gi();ans=point(40000,40000);
for (int i=1;i<=n;++i)
for (int j=1;j<=n;++j)
a[i][j]=gi();
for (int i=1;i<=n;++i)
for (int j=1;j<=n;++j)
b[i][j]=gi();
for (int i=1;i<=n;++i)
for (int j=1;j<=n;++j)
val[i][j]=-a[i][j];
point ma=KM();
for (int i=1;i<=n;++i)
for (int j=1;j<=n;++j)
val[i][j]=-b[i][j];
point mb=KM();
solve(ma,mb);
printf("%d
",ans.x*ans.y);
}
return 0;
}