zoukankan      html  css  js  c++  java
  • 浅谈二分图匹配

    我觉得我现在还可以

    安利一下 rsx's blog

    正文

    本文权当作学习二分图匹配的一些笔记,个人风格极为严重,请勿深究语文

    匈牙利

    本文不介绍算法的步骤,因为可以 google(逃,写的都是一些自己的总结

    匈牙利的算法主要功能是可以做二分图的最大匹配。它是一种交替的找增广路的算法,其本质就是逆流修改(Singercoder 说的)

    其核心的就是 re 这个数组不会回溯 (当然每次要清)

    为什么呢,其实我觉得和当前当前弧优化有点联系。你想想,他要是回溯的话,都不是多项式算法了(笑)。但是其算法正确性也是有的。re 表示的是这个结点是否寻找过增广路。

    理论上最坏时间复杂度会到 (O(VE)) 还行吧。一般卡不了。

    算法伪代码

    main()
    {
    	last...
    	for 1 to n
    		fill re with 0
    		/* 每一次都要清空 re */
    		if find(i) ans++;
    	next...
    }
    find(u)
    {
    	for any edge of u
    		v = e.v;
    		/* 与暴力算法的区别,和当前弧异曲同工吧 */
    		if re[v]
    			continue
    		/* 打上标记 */
    		re[v] = 1
    		if p[v] == NULL or find(p[v])
    			/* 逻辑要清晰 */
    			p[v] = u
    			return true
    	return false
    }
    

    写的不是很好,但是还算清楚吧。毕竟代码实现的能力也很重要的(误)。

    然后通过这个操作,就可以求出二分图的最大匹配了。ヾ(≧∇≦*)ゝ

    当然了,谈到二分图匹配自然会有 Singercoder 冒出来,你匈牙利是 (O(nm)) ,我最大流跑是 (O(sqrt n m)) (没有严格证明,听dalao说的),客观的说,其实各有千秋吧。

    1. 运行时间:几乎相等,因为没有毒瘤可以把匈牙利卡到 (O(nm)),而且 ISAP 的常数肯定比匈牙利大呀.
    2. 代码实现难度:匈牙利完爆各种最大流
    3. 写完的成就感:最大流完爆匈牙利
    4. 运行空间:会卡这个?

    真正的代码

    bool find(int u) {
      for (int i = h[u]; i != -1; i = edge[i].lac)
      {
        int to = edge[i].to;
        if (re[to]) continue;
        re[to] = 1;
        if (!p[to] || find(p[to])) { p[to] = u; return 1; }
      }
      return 0;
    }
    

    然后我们就可以用它来搞各种与二分图有关的东西。

    下文介绍几种别的扩展

    1. 最少点覆盖
    2. 最少边覆盖
    3. 最大独立子集
    4. 最少不相交路径覆盖

    最小点覆盖

    反证一下,他应该就是最大匹配数,因为如果有边并未覆盖。那么他的左右端点应该都没被选,所以就不是最大匹配了,矛盾

    所以:最小点覆盖 = 最大匹配数

    模板:poj1325(有坑)

    最少边覆盖

    这么思考,贪心的想,先把所有匹配的边选了,然后对于剩下的点随便选一个边

    于是:最少边覆盖 = 点数 - 最大匹配数

    模板:poj3020

    最大独立子集

    理解为一个删点和与其相连的所有边,那么剩下的点就是独立子集。删去最小点覆盖即可,

    于是最大独立子集 = 点数 - 最小点覆盖

    模板:poj1466(据说数据有坑但是没有掉进去...但是还是很坑

    最少不相交路径覆盖

    ps:我实现这个一般都是网络流..因为 luogu 的模板要输出路径...

    这么想,我们最开始有 n 个不相交路径(n 个点),然后要进行连边,连尽可能多的边,就可以减少路径数...

    然后,把一个点分为前点,后点做二分图匹配即可。

    最少路径数 = 点数 - 最大匹配数

    模板:poj1422

    关于算法选择的闲话

    我们可以选择匈牙利来跑,因为 (O(nm)) 的时间复杂度的确可以接受,如果觉得不保险其实跑一个 isap 也未尝不可,所以我还是更喜欢最大流。

    KM算法

    KM算法是可以解决完备匹配下的最大权匹配,个人认为是对匈牙利进行了一番操作)

    完备匹配

    完备匹配就是要求匹配数为 n

    书面的说:所谓的完备匹配就是在二部图中,x 点集中的所有点都有对应的匹配且 y 点集中所有的点都有对应的匹配,则称该匹配为完备匹配。

    浅析算法正确性

    该算法是通过给每个顶点一个标号(叫做顶标)来把求最大权匹配的问题转化为求完备匹配的问题的。设顶点 (xi) 的顶标为 (lx[i]),顶点 (y_j) 的顶标为 (ly[j]) ,边权为 (mp[i][j]) 。在算法执行过程中的任一时刻,对于任一条Edge (<i,j>)(lx[i] + ly[j] ge mp[i][j]) 始终成立。

    给出一条定理:若由二分图中所有满足(lx[i] + ly[j] = mp[i][j]) 的边 (<i,j>) 构成的子图(称做相等子图)有完备匹配,那么这个完备匹配就是二分图的最大权匹配。

    这是显然的,因为对于每个匹配他其他的选择的边权一定满足 (mp[i][j] le lx[i] + ly[j]) 要保证其有最大值,只能选相等子图的

    其最佳答案即为 (sum lx[i] +ly[i]) 所以

    初始时为了使 (lx[i] + ly[j] ge mp[i][j]) 恒成立,令 (lx[i]) 为所有与顶点 (x_i) 关联的边的最大权,(ly[j] = 0)。如果当前的相等子图没有完备匹配,就按下面的方法修改顶标以使扩大相等子图,直到相等子图具有完备匹配为止。

    我们求当前相等子图的完备匹配失败了,是因为对于某个 (x) 顶点,我们找不到一条从它出发的交错路。这时我们获得了一棵交错树,它的叶子结点全部是 (x) 顶点。我们把交错树中 (x) 顶点的顶标全都减小某个值 (d), (y) 顶点的顶标全都增加同一个值 (d) ,那么我们会发现:

    1)两端都在交错树中的边 (<i,j>) (lx + ly)(以下简称xy)的值没有变化。也就是说,它原来属于相等子图,仍属于相等子图。

    2)两端都不在交错树中的边 (<i, j>)(xy) 都没有变化。也就是说,它原来属于(或不属于)相等子图,仍属于(或不属于)相等子图。

    3)(x) 端不在交错树中,(y) 端在交错树中的边 (<i,j>) ,它的 (xy) 的值有所增大。它原来不属于相等子图,仍不属于相等子图。

    4)(x) 端在交错树中,(y) 端不在交错树中的边 (<i,j>) ,它的 (xy) 的值有所减小。它原来不属于相等子图,可能进入了相等子图,因而使相等子图得到了扩大。

    5)到最后,(x) 端每个点至少有一条线连着,(y) 端每个点有一条线连着,说明最后补充完的相等子图一定有完备匹配。所以,最大权子图子图是一步一步变大的。

    而显然的为了满足性质有: (d = min(lx[i]+ly[j]-mp[i][j])) 其中 $ x_i$ 在交替树里,(y_i) 不在。

    为啥?我们要扩大相等子图。

    对于 (d) 的求解可以在求解交替树时完成

    为啥我上面的都划掉了,ssinger 提出了一种反例,观察第三个,我们可以知道,还会有的边原来在子图里,但是他的x端减小了!所以就会出子图,但对算法的正确性没有什么影响,因为

    代码

    bool find(int u)
    {
    	visx[u] = 1;
    	for(int i = 1; i <= n; ++i)
    	{
    		if(visy[i]) continue;
    		if(lx[u] + ly[i] == mp[u][i])
    		{
    			visy[i] = 1;
    			if(!p[i] || find(p[i]))
    			{
    				p[i] = u;
    				return 1;
    			}
    		}
    		else d = min(d, lx[u] + ly[i] - mp[u][i]);
            /* 在这里计算d */
    	}
    	return 0;
    }
    
    /* 当然了,如果可以匹配成功,那就不用d了 */
    
    main:
    
    	for(int k = 1; k <= n; ++k)
    	{
    		while(1)
    		{
    			memset(visx, 0, sizeof visx);
    			memset(visy, 0, sizeof visy);
    			d = 1 << 30;
    			if(find(k)) break;
    			for(int i = 1; i <= n; ++i)
    			{
    				if(visx[i]) lx[i] -= d;
    				if(visy[i]) ly[i] += d;
    			}
    		}
    	}
    

    简单口胡下时间复杂度:

    循环 (n) 次 每次找增广修改 (n) 次顶标 用 (n) 次算 (d) 所以为 (O(n^3)) 但是怎么会跑到...

    之前口胡错了。应该是 (n^4) 因为不是 (n) 次算 (d)

    所以似乎之前写的不是多好用了有一个 (d) 是全程量,但是这并不是很好,因为计算的不是很准确。在下文会介绍一个使用 (slac[ ]) 数组来更精准的来算出delta(lamda)

    话说,为啥不用费用流来跑,因为 ekzkw 都挺慢的,KMzkw 复杂度在上界上都为 (O(n^4)),但是 KM 一般会更快

    在随机数据下,使用 slac 数组可以在随机数据下平均复杂度约为 (O(n^3))

    介绍一下,定义 (slac[v]) 表示对于你 (find()) 函数中把 (d) 改为 (slac[v]) 主要原因是要精准,那么我们的最后的 (d) 是啥,应该是 (!visy[v]) 时里的最小值。就是说你要填几个边才行..就是要找不在交替树的最小 slac 从 x 加边

    代码:(hdu3435)

    #define _CRT_SECURE_NO_WARNINS
    
    #include <bits/stdc++.h>
    
    using namespace std;
    
    template <typename T>
    inline T read()
    {
    	T x = 0;
    	char ch = getchar();
    	bool f = 0;
    	while(ch < '0' || ch > '9')
    	{
    		f = (ch == '-');
    		ch = getchar();
    	}
    	while(ch <= '9' && ch >= '0')
    	{
    		x = (x << 1) + (x << 3) + (ch - '0');
    		ch = getchar();
    	}
    	return  f? -x : x;
    }
    
    template <typename T>
    void put(T x)
    {
    	if(x < 0)
    	{
    		x = -x;
    		putchar('-');
    	}
    	if(x < 10) {
    		putchar(x + 48);
    		return;
    	}
    	put(x / 10);
    	putchar(x % 10 + 48);
    	return ;
    }
    
    const int Maxn = 1001, inf = 0x3f3f3f3f;
    
    int cnt, h[Maxn], lx[Maxn], ly[Maxn], slac[Maxn], p[Maxn], t, n, m;
    
    bool visx[Maxn], visy[Maxn];
    
    struct Edge
    {
    	int to, lac, wg;
    	void insert(int x, int y, int z) { to = y; lac = h[x]; h[x] = cnt++; wg = z; }
    }edge[Maxn << 6];
    
    void add_edge(int w, int v, int u)
    {
    	edge[cnt].insert(u, v, w);
    	edge[cnt].insert(v, u, w);
    	lx[u] = max(lx[u], w);
    	lx[v] = max(lx[v], w);
    }
    
    bool find(int u)
    {
    	visx[u] = 1;
    	for(int i = h[u]; i != -1; i = edge[i].lac)
    	{
    		int to = edge[i].to;
    		if(visy[to]) continue;
    		// 注意这里的位置
    		if(lx[u] + ly[to] == edge[i].wg)
    		{
    			visy[to] = 1;
    			if(!p[to] || find(p[to]))
    			{
    				p[to] = u;
    				return 1;
    			}
    		}
    		else slac[to] = min(slac[to], lx[u] + ly[to] - edge[i].wg);
    	}
    	return 0;
    }
    
    int KM()
    {
    	memset(p, 0, sizeof p);
    	for(int k = 1; k <= n; ++k)
    	{
    		memset(slac, 63, sizeof slac);
    		while(1)
    		{
    			memset(visx, 0, sizeof visx);
    			memset(visy, 0, sizeof visy);
    			int minz = 0x3f3f3f3f;
    			if(find(k)) break;
    			for(int i = 1; i <= n; ++i) if(!visy[i] && minz > slac[i]) minz = slac[i];
    			if(minz == 0x3f3f3f3f) return -1;
    			for(int i = 1; i <= n; ++i)
    			{
    				if(visx[i]) lx[i] -= minz;
    				if(visy[i]) ly[i] += minz;
    				else slac[i] -= minz;
    			}
    		}
    	}
    	int ans = 0;
    	for(int i = 1; i <= n; ++i) ans += lx[i] + ly[i];
    	return -ans;
    }
    
    int main()
    {
    	freopen("in.txt", "r", stdin);
    	int cnt1 = 0;
    	t = read <int> ();
    	while(t--)
    	{
    		n = read <int> ();
    		m = read <int> ();
    		memset(h, -1, sizeof h); cnt = 0;
    		for(int i = 1; i <= n; ++i) lx[i] = -inf, ly[i] = 0;
    		for(int i = 1; i <= m; ++i) add_edge(-read <int> (), read <int> () + 1, read <int> () + 1);
    		int ans = KM();
    		if(ans == -1) printf("Case %d: NO
    ", ++cnt1);
    		else printf("Case %d: %d
    ", ++cnt1, ans);
    	}
    	return 0;
    }
    

    关于判无解,我本人提出了一种想法,就是先跑匈牙利(笑),singercoder 提出了一种想法,当 d = inf 时,但是由于他那个神奇的常数问题(所以导致了死循环)....神奇吧....其实这个想法是对的-.-

    所以真的是很神奇,KM 的坑实在是太多了。

    当然了,由于 (O(n^4)) 的复杂度无法接受,有的人研究出了一种 bfs 的操作

    这个码是 uoj#80

    #define _CRT_SECURE_NO_WARNINS
    
    #include <bits/stdc++.h>
    
    using namespace std;
    
    template <typename T>
    inline T read()
    {
    	T x = 0;
    	char ch = getchar();
    	bool f = 0;
    	while(ch < '0' || ch > '9')
    	{
    		f = (ch == '-');
    		ch = getchar();
    	}
    	while(ch <= '9' && ch >= '0')
    	{
    		x = (x << 1) + (x << 3) + (ch - '0');
    		ch = getchar();
    	}
    	return  f? -x : x;
    }
    
    template <typename T>
    void put(T x)
    {
    	if(x < 0)
    	{
    		x = -x;
    		putchar('-');
    	}
    	if(x < 10) {
    		putchar(x + 48);
    		return;
    	}
    	put(x / 10);
    	putchar(x % 10 + 48);
    	return ;
    }
    
    const int Maxn = 404;
    
    int mp[Maxn][Maxn], n, bl, br, m, lx[Maxn], ly[Maxn], link[Maxn], p[Maxn], slac[Maxn], pre[Maxn], mat[Maxn];
    
    bool visx[Maxn], visy[Maxn];
    
    // 把增广路上的取反
    
    void aug(int x)
    {
    	if(!x) return;
    	link[x] = pre[x];
    	aug(mat[pre[x]]);
    	mat[pre[x]] = x;
    	return;
    }
    
    // 通过 bfs 来寻找增广路
    
    void bfs(int u)
    {
    	int delta;
    	memset(slac, 63, sizeof slac);
    	memset(visx, 0, sizeof visx);
    	memset(visy, 0, sizeof visy);
    	queue <int> q;
    	q.push(u);
    	while(1)
    	{
    		// 通过广搜来寻找
    		while(!q.empty())
    		{
    			int fr =q.front();
    			q.pop();
    			visx[fr] = 1;
    			for(int i = 1; i <= n; ++i)
    			{
    				// 这里就跟dfs的差不多,但是需要用pre来记录路径
    				// 在这里 visy 的位置十分关键
    				if(visy[i]) continue;
    				delta = lx[fr] + ly[i] - mp[fr][i];
    				if(!delta)
    				{
    					visy[i] = 1;
    					pre[i] = fr;
    					if(!link[i]) { aug(i); return ; }
    					q.push(link[i]);
    				}
    				else if(slac[i] > delta) slac[i] = delta, pre[i] = fr;
    			}
    		}
    		delta = 0x3f3f3f3f;
    		for(int i = 1; i <= n; ++i) if(!visy[i]) delta = min(delta, slac[i]);
    		if(delta == 0x3f3f3f3f) return;
    		for(int i = 1; i <= n; ++i)
    		{
    			if(visx[i]) lx[i] -= delta;
    			// 左边的点顶标减去 delta
    			if(visy[i]) ly[i] += delta;
    			// 遍历的点加上 delta
    			else slac[i] -= delta;
    			// 记住没有打标记的点就 slac -= delta
    		}
    		// 把新增的点添到队列中
    		for(int i = 1; i <= n; ++i)
    		{
    			if(!visy[i] && !slac[i])
    			{
    				visy[i] = 1;
    				if(!link[i]) { aug(i); return ; }
    				// 如果可以就返回
    				q.push(link[i]);
    			}
    		}
    	}
    }
    
    void KM()
    {
    	for(int i = 1; i <= n; ++i)
    	{
    		lx[i] = 0;
    		for(int j = 1; j <= n; ++j) lx[i] = max(lx[i], mp[i][j]);
    	}
    	for(int i = 1; i <= n; ++i)
    		bfs(i);
    	long long ans = 0;
    	for(int i = 1; i <= br; ++i)
    	{
    		if(!mp[link[i]][i]) continue;
    		ans += mp[link[i]][i];
    		p[link[i]] = i;
    	}
    	put <long long> (ans); putchar('
    ');
    	for(int i = 1; i <= bl; ++i) put <int> (p[i]), putchar(' ');
    	return ;
    }
    
    int main()
    {
    
    #ifdef _DEBUG
    	freopen("in.txt", "r", stdin);
    #endif
    	bl = read <int> ();
    	br = read <int> ();
    	n = max(bl, br);
    	m = read <int> ();
    	while(m--) mp[read <int> ()][read <int> ()] = read <int> ();
    	KM();
    	return 0;
    }
    

    这个时间复杂度是 (O(n^3))

    我没有那么好的天份 ,我只是多下点力气

    关于时间复杂度的证明,好像 algorithm 里给了一个主定理之类的东西吧。去学习下,要肝日校了,可能会比较晚写

  • 相关阅读:
    oracle比较常用的函数
    生成GUID
    字符串操作
    Visual Studio常用快捷键
    c#保存异常日志
    c#的Trim方法
    c#之文件操作
    Python可视化库matplotlib.pyplot里contour与contourf的区别
    python linspace
    神经网络实现连续型变量的回归预测(python)
  • 原文地址:https://www.cnblogs.com/zhltao/p/12549489.html
Copyright © 2011-2022 走看看