最近集中学习了一下矩阵树定理,自己其实还是没有太明白原理(证明)类的东西,但想在这里总结一下应用中的一些细节,矩阵树定理的一些引申等等。
首先,矩阵树定理用于求解一个图上的生成树个数。实现方式是:(A)为邻接矩阵,(D)为度数矩阵,则基尔霍夫(Kirchhoff)矩阵即为:(K = D - A)。具体实现中,记 (a) 为Kirchhoff矩阵,则若存在 (E(u, v)) ,则(a[u][u] ++, a[v][v] ++, a[u][v] --, a[v][u] --) 。即(a[i][i]) 为 (i) 点的度数,(a[i][j]) 为 (i, j)之间边的条数的相反数。
这样构成的矩阵的行列式的值,就为生成树的个数。而求解行列式的快速方法为使用高斯消元进行消元消处上三角矩阵,则有对角线上的值的乘积 = 行列式的值。一般而言求解生成树个数的题目数量会非常庞大,需要取模处理。取模处理中,不能出现小数,于是使用辗转相除法:(其中因为消的是行列式,所以与消方程有所不同。交换两行行列式的值变号,且消元只能将一行的数 * k 之后加到别的行上。)
int Gauss() { int ans = 1; for(int i = 1; i < tot; i ++) { for(int j = i + 1; j < tot; j ++) while(f[j][i]) { int t = f[i][i] / f[j][i]; for(int k = i; k < tot; k ++) f[i][k] = (f[i][k] - t * f[j][k] + mod) % mod; swap(f[i], f[j]); ans = - ans; } ans = (ans * f[i][i]) % mod; } return (ans + mod) % mod; }
变元矩阵树定理:求所有生成树的总边积的和。和矩阵树的求法相同,不过行列式中 (a[i][i]) 记录的是总的边权和,(a[i][j]) 记录 (i, j) 之间边权的相反数。
以下为几道题目:
1.HEOI2015 小Z的房间 2.SHOI2016 黑暗前的幻想乡
3.SDOI2014 重建 4.JSOI2008 最小生成树计数
1.HEOI2015 小Z的房间(妥妥的模板题一个)
#include <bits/stdc++.h> using namespace std; #define maxn 90 #define int long long #define mod 1000000000 int n, m, f[maxn][maxn]; int tot, Map[maxn][maxn]; int read() { int x = 0, k = 1; char c; c = getchar(); while(c < '0' || c > '9') { if(c == '-') k = -1; c = getchar(); } while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar(); return x * k; } void add(int x, int y) { if(x > y) return; f[x][x] ++, f[y][y] ++; f[x][y] --, f[y][x] --; } int Gauss() { int ans = 1; for(int i = 1; i < tot; i ++) { for(int j = i + 1; j < tot; j ++) while(f[j][i]) { int t = f[i][i] / f[j][i]; for(int k = i; k < tot; k ++) f[i][k] = (f[i][k] - t * f[j][k] + mod) % mod; swap(f[i], f[j]); ans = - ans; } ans = (ans * f[i][i]) % mod; } return (ans + mod) % mod; } signed main() { n = read(), m = read(); for(int i = 1; i <= n; i ++) { char c; for(int j = 1; j <= m; j ++) { cin >> c; if(c == '.') Map[i][j] = ++ tot; } } for(int i = 1; i <= n; i ++) for(int j = 1; j <= m; j ++) { int tem, u; if(!(u = Map[i][j])) continue; if(tem = Map[i - 1][j]) add(u, tem); if(tem = Map[i + 1][j]) add(u, tem); if(tem = Map[i][j - 1]) add(u, tem); if(tem = Map[i][j + 1]) add(u, tem); } printf("%lld ", Gauss()); return 0; }
2.SHOI2016黑暗前的幻想乡
容斥+矩阵树定理。与模板的不同之处在于每一家公司都要参与修建,则合法方案数 = 总的方案数 - 有一个公司未修建的方案数 + 有两个公司未修建的方案数……暴力重构矩阵求解即可。
#include <bits/stdc++.h> using namespace std; #define ll long long const int mod = 1000000007; int n; ll g[20][20]; vector<pair<int , int > > q[20]; int read() { int x = 0, k = 1; char c; c = getchar(); while(c < '0' || c > '9') { if(c == '-') k = -1; c = getchar(); } while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar(); return x * k; } int Gauss() { ll ans = 1; for(int i = 1; i < n; i ++) { for(int j = i + 1; j < n; j ++) while(g[j][i]) { ll t = g[i][i] / g[j][i]; for(int k = i; k < n; k ++) g[i][k] = (g[i][k] - g[j][k] * t) % mod; swap(g[i], g[j]); ans = -ans; } ans = (ans * g[i][i]) % mod; if(!ans) return 0; } return (ans + mod) % mod; } int main() { n = read(); for(int i = 1; i < n; i ++) { int m = read(); for(int j = 1; j <= m; j ++) { int x = read(), y = read(); q[i].push_back(make_pair(x, y)); } } int ans = 0, CNST = 1 << (n - 1); for(int i = 0; i < CNST; i ++) { int cnt = 0; memset(g, 0, sizeof(g)); for(int j = 1; j < n; j ++) if(i & (1 << (j - 1))) { for(int k = 0; k < q[j].size(); k ++) { int x = q[j][k].first, y = q[j][k].second; g[x][x] ++, g[y][y] ++; g[x][y] --, g[y][x] --; } cnt ++; } if((n - cnt) & 1) ans = (ans + Gauss()) % mod; else ans = (ans - Gauss() + mod) % mod; } printf("%d ", ans); return 0; }
3.SDOI2014重建
化式子 + 变元矩阵树定理。将概率的式子写出来变形即可得到矩阵树定理求 (prod frac{p(u, v)}{1 - p(u, v)})
#include <bits/stdc++.h> using namespace std; #define maxn 100 #define db double #define eps 0.000001 int n; db ans = 1.0, a[maxn][maxn]; db Gauss(int n) { db ans = 1.0; for(int i = 1; i <= n; i ++) { for(int j = i + 1; j <= n; j ++) { int t = i; if(fabs(a[j][i]) > fabs(a[t][i])) t = j; if(t != i) swap(a[t], a[i]), ans = -ans; } for(int j = i + 1; j <= n; j ++) { db t = a[j][i] / a[i][i]; for(int k = i; k <= n; k ++) a[j][k] -= t * a[i][k]; } ans *= a[i][i]; } return fabs(ans); } int main() { scanf("%d", &n); for(int i = 1; i <= n; i ++) for(int j = 1; j <= n; j ++) scanf("%lf", &a[i][j]); for(int i = 1; i <= n; i ++) for(int j = 1; j <= n; j ++) { db t = fabs(1.0 - a[i][j]) < eps ? eps : (1.0 - a[i][j]); if(i < j) ans *= t; a[i][j] = a[i][j] / t; } for(int i = 1; i <= n; i ++) for(int j = 1; j <= n; j ++) if(i != j) { a[i][i] += a[i][j], a[i][j] = -a[i][j]; } printf("%.10lf ", Gauss(n - 1) * ans); return 0; }
4.JSOI2008最小生成树计数
这题虽然最早年,然而也最强啊……个人认为这位博主解释得很好了 Z-Y-Y-S的博客
两个性质 mark 一下:
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #include<cmath> using namespace std; #define maxn 200 #define mod 31011 int n, m, ans = 1, tmp[maxn]; int sum, fa[maxn], set[maxn]; int a[maxn][maxn]; struct edge { int u, v, w; }E[maxn * 20], e[maxn * 20]; int read() { int x = 0, k = 1; char c; c = getchar(); while(c < '0' || c > '9') { if(c == '-') k = -1; c = getchar(); } while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar(); return x * k; } bool cmp(edge a, edge b) { return a.w < b.w; } int find(int x) { return set[x] == x ? x : set[x] = find(set[x]); } int find2(int x) { return fa[x] == x ? x : fa[x] = find2(fa[x]); } int Gauss(int n) { int ans = 1; for(int i = 1; i <= n; i ++) for(int j = 1; j <= n; j ++) a[i][j] = (a[i][j] + mod) % mod; for(int i = 1; i <= n; i ++) { for(int j = i + 1; j <= n; j ++) while(a[j][i]) { int t = a[i][i] / a[j][i]; for(int k = i; k <= n; k ++) a[i][k] = (a[i][k] - 1ll * t * a[j][k] % mod + mod) % mod; swap(a[i], a[j]); ans = - ans; } ans = 1ll * ans * a[i][i] % mod; } return (ans + mod) % mod; } void Cal(int S, int T) { int cnt = 0; for(int i = S; i <= T; i ++) { e[i] = E[i]; int p = find(e[i].u), q = find(e[i].v); e[i].u = p, e[i].v = q; if(p == q) continue; tmp[++ cnt] = p, tmp[++ cnt] = q; } sort(tmp + 1, tmp + 1 + cnt); cnt = unique(tmp + 1, tmp + cnt + 1) - tmp - 1; memset(a, 0, sizeof(a)); for(int i = 1; i <= cnt; i ++) fa[i] = i; for(int i = S; i <= T; i ++) { if(e[i].u == e[i].v) continue; int p = find(e[i].u), q = find(e[i].v); if(p != q) -- sum, set[p] = q; int u = lower_bound(tmp + 1, tmp + cnt + 1, e[i].u) - tmp; int v = lower_bound(tmp + 1, tmp + cnt + 1, e[i].v) - tmp; a[u][u] ++, a[v][v] ++; a[u][v] --, a[v][u] --; p = find2(u), q = find2(v); if(p != q) fa[p] = q; } for(int i = 2; i <= cnt; i ++) if(find2(i) != find2(i - 1)) { int p = find2(i), q = find2(i - 1); a[p][p] ++, a[q][q] ++; a[p][q] --, a[q][p] --; fa[p] = q; } ans = 1ll * ans * Gauss(cnt - 1) % mod; } int main() { n = read(), m = read(); for(int i = 1; i <= m; i ++) E[i].u = read(), E[i].v = read(), E[i].w = read(); sort(E + 1, E + 1 + m, cmp); for(int i = 1; i <= n; i ++) set[i] = i; sum = n; for(int i = 1, j; i <= m; i = j) { for(j = i; j <= m; j ++) if(E[j].w != E[i].w) break; if(j - i > 1) Cal(i, j - 1); else { int p = find(E[i].u), q = find(E[i].v); if(p != q) set[p] = q; sum --; } } if(sum > 1) printf("0"); else printf("%d ", ans); return 0; }