/* Largest matrix order accepted by inv_matlab: it sizes the fixed pivot bookkeeping arrays. */
#define MAX_MATLAB_N (128)

/* Exchange the two doubles that a and b point to. */
void swap(double *a,double *b)
{
    double tmp = *a;
    *a = *b;
    *b = tmp;
}

/* Exchange rows r1 and r2 of the row-major N x N matrix p (helper for inv_matlab). */
static void inv_matlab_swap_rows(double *p, int N, int r1, int r2)
{
    for (int c = 0; c < N; c++)
    {
        swap(p + (r1 * N + c), p + (r2 * N + c));
    }
}

/* Exchange columns c1 and c2 of the row-major N x N matrix p (helper for inv_matlab). */
static void inv_matlab_swap_cols(double *p, int N, int c1, int c2)
{
    for (int r = 0; r < N; r++)
    {
        swap(p + (r * N + c1), p + (r * N + c2));
    }
}

/*
 * Invert the row-major N x N matrix p in place using Gauss-Jordan elimination
 * with full (row and column) pivoting.
 *
 * p : N*N doubles in row-major order; replaced by the inverse on success.
 * N : matrix order, 0 <= N <= MAX_MATLAB_N.
 *
 * Returns true on success (N == 0 is a trivial success).
 * Returns false without touching p if N is out of range, or if p is NULL with N > 0.
 * Returns false after printing a message if a pivot is numerically zero. In that
 * case p has already been partially overwritten and its contents are not meaningful.
 */
bool inv_matlab(double *p,int N)
{
    // is/js are fixed-size arrays; a larger N would index past their end.
    if (N < 0 || N > MAX_MATLAB_N || (N > 0 && p == NULL))
    {
        return false;
    }
    int is[MAX_MATLAB_N] = {0}; // is[k]: row the k-th pivot was taken from
    int js[MAX_MATLAB_N] = {0}; // js[k]: column the k-th pivot was taken from
    for (int k = 0; k < N; k++)
    {
        // 1. Full pivoting: choose the largest-magnitude entry of the trailing
        //    submatrix p[k..N-1][k..N-1].
        double fmax = 0.0;
        for (int i = k; i < N; i++)
        {
            for (int j = k; j < N; j++)
            {
                double mag = fabs(p[i * N + j]);
                if (mag > fmax)
                {
                    fmax = mag;
                    is[k] = i;
                    js[k] = j;
                }
            }
        }
        // The pivot counts as zero when adding it to 1.0 does not change 1.0,
        // i.e. |pivot| <= about DBL_EPSILON/2. This is an absolute threshold, so a
        // matrix whose entries are all tiny is reported singular even when invertible.
        if ((fmax + 1.0) == 1.0)
        {
            // No inverse exists.
            printf("没有逆矩阵\n");
            return false;
        }
        // Move the pivot to (k, k). Swap row k with the pivot's row, then column k
        // with the pivot's column. The column swap must use js[k]; using is[k]
        // selects the wrong pivot whenever is[k] != js[k].
        if (is[k] != k)
        {
            inv_matlab_swap_rows(p, N, k, is[k]);
        }
        if (js[k] != k)
        {
            inv_matlab_swap_cols(p, N, k, js[k]);
        }
        // 2. Replace the pivot with its reciprocal.
        const int kk = k * N + k;
        p[kk] = 1.0 / p[kk];
        const double inv_pivot = p[kk];
        // 3. Scale the rest of the pivot row.
        for (int j = 0; j < N; j++)
        {
            if (j != k)
            {
                p[k * N + j] *= inv_pivot;
            }
        }
        // 4. Eliminate column k from every other row. The pivot row and the
        //    pivot column are handled in steps 3 and 5.
        for (int i = 0; i < N; i++)
        {
            if (i == k)
            {
                continue;
            }
            const double factor = p[i * N + k];
            for (int j = 0; j < N; j++)
            {
                if (j != k)
                {
                    p[i * N + j] -= factor * p[k * N + j];
                }
            }
        }
        // 5. Finish the pivot column.
        for (int i = 0; i < N; i++)
        {
            if (i != k)
            {
                p[i * N + k] *= -inv_pivot;
            }
        }
    }
    // Undo the pivoting permutations in reverse order. A row swap applied to the
    // input becomes a column swap of the inverse, and vice versa.
    for (int k = N - 1; k >= 0; k--)
    {
        if (js[k] != k)
        {
            inv_matlab_swap_rows(p, N, k, js[k]);
        }
        if (is[k] != k)
        {
            inv_matlab_swap_cols(p, N, k, is[k]);
        }
    }
    return true;
}
/*
 * Write the transpose of the row-major N x N matrix p into pOut.
 *
 * In-place use (pOut == p) is supported: mirror elements are swapped across the
 * diagonal. A plain element-by-element copy would overwrite entries before they
 * are read. Apart from that case, p and pOut must not overlap.
 */
void tra_matlab(double *p, double *pOut, int N)
{
    if (p == pOut)
    {
        for (int i = 0; i < N; i++)
        {
            for (int j = i + 1; j < N; j++)
            {
                double tmp = p[i * N + j];
                p[i * N + j] = p[j * N + i];
                p[j * N + i] = tmp;
            }
        }
        return;
    }
    for (int i = 0; i < N; i++)
    {
        for (int j = 0; j < N; j++)
        {
            pOut[i * N + j] = p[j * N + i];
        }
    }
}
/*
 * Element-wise sum of two row-major N x N matrices: pOut = p1 + p2.
 * The matrices are treated as flat arrays of N*N doubles, visited in row-major order.
 */
void add_matlab(double *p1,double *p2, double *pOut, int N)
{
    if (N <= 0)
    {
        return; // nothing to add
    }
    const int count = N * N;
    for (int idx = 0; idx < count; ++idx)
    {
        pOut[idx] = p1[idx] + p2[idx];
    }
}
/*
 * Element-wise difference of two row-major N x N matrices: pOut = p1 - p2.
 */
void dec_matlab(double *p1,double *p2, double *pOut, int N)
{
    for (int row = 0; row < N; ++row)
    {
        const int base = row * N; // flat index of the row's first element
        for (int col = 0; col < N; ++col)
        {
            const int at = base + col;
            pOut[at] = p1[at] - p2[at];
        }
    }
}
/*
 * Matrix product of two row-major N x N matrices: pOut = p1 * p2.
 * Each output element is written once its dot product is complete. pOut must
 * therefore not alias p1 or p2, or later products would read overwritten values.
 */
void mul_matlab(double *p1,double *p2, double *pOut, int N)
{
    for (int row = 0; row < N; ++row)
    {
        const int rowBase = row * N;
        for (int col = 0; col < N; ++col)
        {
            double acc = 0.0;
            for (int t = 0; t < N; ++t)
            {
                acc += p1[rowBase + t] * p2[t * N + col];
            }
            pOut[rowBase + col] = acc;
        }
    }
}
/*
 * Matrix product of a row-major m1 x n1 matrix p1 and a row-major m2 x n2 matrix p2.
 * The m1 x n2 result is written to pOut. The inner dimensions must agree (n1 == m2);
 * otherwise a message is printed and pOut is left untouched.
 */
void mul_matlab(double *p1, int m1,int n1, double *p2, int m2,int n2, double *pOut)
{
    if (m2 != n1)
    {
        printf("p1 的列数 必须等于 p2 的行数\n");
        return;
    }
    for (int r = 0; r < m1; ++r)
    {
        const int lhsRow = r * n1; // start of row r in p1
        const int outRow = r * n2; // start of row r in pOut
        for (int c = 0; c < n2; ++c)
        {
            double acc = 0.0;
            for (int t = 0; t < n1; ++t)
            {
                acc += p1[lhsRow + t] * p2[t * n2 + c];
            }
            pOut[outRow + c] = acc;
        }
    }
}