某次在qbxt培训的时候学到的一个很有意思的算法。
前置知识
- 高斯消元
- 辗转相除法
问题
不妨来思考一下为什么传统的高斯消元法会存在精度问题。
我们知道,在四则运算中,唯一一个会损失精度的就是除法了。
那么高斯消元中都有哪些步骤用到了除法呢?
mat[i][i...n+1] /= mat[i][i]
: 主元系数化为1mat[j=i+1...n][k=i...n+1] -= mat[j][i] * mat[i][k]
: 消元
传统的优化方法
列主元法
传统的方法是直接选择mat[i][i]
作为当前主元的系数。列主元法在这一步上进行了一个优化:我们在mat[i...n][i]
中选择一个绝对值最大的作为主元系数。
新的方法
首先注意到,“主元系数化为1”其实是没有必要的,只需要把消元的式子改为mat[j=i+1...n][k=i...n+1] -= (mat[j][i] / mat[i][j]) * mat[i][k]
并在最后回代时/mat[i][i]
即可。
剩下的问题就是消元这一步了。
使用辗转相除法消元
一个辗转相除法的例子:
[(155,120)=(35,120)=(35,15)=(5,15)=(5,0)
]
规律:由原来的两个正整数变成了一个正整数和一个0。
回想高斯消元法中的加减消元这一步,我们的目的不就是把mat[j][i]
变为0吗?
算法
每次消元,我们可以把mat[i][i]
和mat[j][i]
两个数进行辗转相除,在两个数相减的时候,同时把这两个数所在的行也对应地相减。
这样操作完后,mat[i][i]
和mat[j][i]
必定会有一个变成0且另一个不是0。如果mat[i][i]
变为0了,那就把第i行和第j行交换。
另外还有需要注意消元过程中可以出现负数。
参考代码
仅供参考,没有判断无解和多解的情况
#include <cstdio>
#define ll long long
#define re register
#define il inline
#define gc getchar
#define pc putchar
template <class T>
void read(T &x) {
re bool f = 0;
re char c = gc();
while ((c < '0' || c > '9') && c != '-') c = gc();
if (c == '-') f = 1, c = gc();
x = 0;
while (c >= '0' && c <= '9') x = x * 10 + (c ^ 48), c = gc();
f && (x = -x);
}
template <class T>
void print(T x) {
if (x < 0) pc('-'), x = -x;
if (x >= 10) print(x / 10);
pc((x % 10) ^ 48);
}
template <class T>
void prisp(T x) {
print(x);
pc(' ');
}
template <class T>
void priln(T x) {
print(x);
pc('
');
}
int n;
int mat[105][105];
void swap(int &a, int &b) {
a ^= b ^= a ^= b;
}
void swp(int* a, int* b) {
for (int i = 1; i <= n + 1; ++i)
swap(a[i], b[i]);
}
void mul(int* a, int t) {
for (int i = 1; i <= n + 1; ++i)
a[i] *= t;
}
void sub(int* a, int* b, int t) {
for (int i = 1; i <= n + 1; ++i)
a[i] -= b[i] * t;
}
int ans[105];
int main() {
read(n);
for (int i = 1; i <= n; ++i)
for (int j = 1; j <= n + 1; ++j)
read(mat[i][j]);
for (int k = 1; k <= n; ++k) {
// 找主元
for (int i = k; i <= n; ++i) {
if (mat[i][k]) {
if (i != k) swp(mat[i], mat[k]);
break;
}
}
for (int i = k + 1; i <= n; ++i) {
int* a = mat[k];
int* b = mat[i];
// 先化成正数
if (a[k] < 0) mul(a, -1);
if (b[k] < 0) mul(b, -1);
// 辗转相除(迭代版)
while (a[k]) {
if (a[k] > b[k]) {
swp(a, b);
}
sub(b, a, b[k] / a[k]);
swp(a, b);
}
swp(a, b);
}
}
for (int i = n; i >= 1; --i) {
ans[i] = mat[i][n + 1];
for (int j = i + 1; j <= n; ++j)
ans[i] -= ans[j] * mat[i][j];
ans[i] /= mat[i][i];
}
for (int i = 1; i <= n; ++i) priln(ans[i]);
}