zoukankan      html  css  js  c++  java
  • 矩阵类的python实现

    科学计算离不开矩阵的运算。当然,python已经有非常好的现成的库:numpy。

    我写这个矩阵类,并不是打算重新造一个轮子,只是作为一个练习,记录在此。

    :这个类的函数还没全部实现,慢慢在完善吧。

    全部代码:

      1 import copy
      2 
      3 class Matrix:
      4     '''矩阵类'''
      5     def __init__(self, row, column, fill=0.0):
      6         self.shape = (row, column)
      7         self.row = row
      8         self.column = column
      9         self._matrix = [[fill]*column for i in range(row)]
     10         
     11     # 返回元素m(i, j)的值:  m[i, j]
     12     def __getitem__(self, index):
     13         if isinstance(index, int):
     14             return self._matrix[index-1]
     15         elif isinstance(index, tuple):
     16             return self._matrix[index[0]-1][index[1]-1]
     17 
     18     # 设置元素m(i,j)的值为s:  m[i, j] = s
     19     def __setitem__(self, index, value):
     20         if isinstance(index, int):
     21             self._matrix[index-1] = copy.deepcopy(value)
     22         elif isinstance(index, tuple):
     23             self._matrix[index[0]-1][index[1]-1] = value
     24         
     25     def __eq__(self, N):
     26         '''相等'''
     27         # A == B
     28         assert isinstance(N, Matrix), "类型不匹配,不能比较"
     29         return N.shape == self.shape  # 比较维度,可以修改为别的
     30     
     31     def __add__(self, N):
     32         '''加法'''
     33         # A + B
     34         assert N.shape == self.shape, "维度不匹配,不能相加"
     35         M = Matrix(self.row, self.column)
     36         for r in range(self.row):
     37             for c in range(self.column):
     38                 M[r, c] = self[r, c] + N[r, c]
     39         return M
     40     
     41     def __sub__(self, N):
     42         '''减法'''
     43         # A - B
     44         assert N.shape == self.shape, "维度不匹配,不能相减"
     45         M = Matrix(self.row, self.column)
     46         for r in range(self.row):
     47             for c in range(self.column):
     48                 M[r, c] = self[r, c] - N[r, c]
     49         return M
     50     
     51     def __mul__(self, N):
     52         '''乘法'''
     53         # A * B (或:A * 2.0)
     54         if isinstance(N, int) or isinstance(N,float):
     55             M = Matrix(self.row, self.column)
     56             for r in range(self.row):
     57                 for c in range(self.column):
     58                     M[r, c] = self[r, c]*N
     59         else:
     60             assert N.row == self.column, "维度不匹配,不能相乘"
     61             M = Matrix(self.row, N.column)
     62             for r in range(self.row):
     63                 for c in range(N.column):
     64                     sum = 0
     65                     for k in range(self.column):
     66                         sum += self[r, k] * N[k, r]
     67                     M[r, c] = sum
     68         return M
     69     
     70     def __div__(self, N):
     71         '''除法'''
     72         # A / B
     73         pass
     74     def __pow__(self, k):
     75         '''乘方'''
     76         # A**k
     77         assert self.row == self.column, "不是方阵,不能乘方"
     78         M = copy.deepcopy(self)
     79         for i in range(k):
     80            M = M * self 
     81         return M
     82 
     83     def rank(self):
     84         '''矩阵的秩'''
     85         pass
     86     
     87     def trace(self):
     88         '''矩阵的迹'''
     89         pass
     90     
     91     def adjoint(self):
     92         '''伴随矩阵'''
     93         pass
     94     
     95     def invert(self):
     96         '''逆矩阵'''
     97         assert self.row == self.column, "不是方阵"
     98         M = Matrix(self.row, self.column*2)
     99         I = self.identity() # 单位矩阵
    100         I.show()#############################
    101         
    102         # 拼接
    103         for r in range(1,M.row+1):
    104             temp = self[r]
    105             temp.extend(I[r])
    106             M[r] = copy.deepcopy(temp)
    107         M.show()#############################
    108         
    109         # 初等行变换
    110         for r in range(1, M.row+1):
    111             # 本行首元素(M[r, r])若为 0,则向下交换最近的当前列元素非零的行
    112             if M[r, r] == 0:
    113                 for rr in range(r+1, M.row+1):
    114                     if M[rr, r] != 0:
    115                         M[r],M[rr] = M[rr],M[r] # 交换两行
    116                     break
    117 
    118             assert M[r, r] != 0, '矩阵不可逆'
    119             
    120             # 本行首元素(M[r, r])化为 1
    121             temp = M[r,r] # 缓存
    122             for c in range(r, M.column+1):
    123                 M[r, c] /= temp
    124                 print("M[{0}, {1}] /=  {2}".format(r,c,temp))
    125             M.show()
    126                 
    127             # 本列上、下方的所有元素化为 0
    128             for rr in range(1, M.row+1):
    129                 temp = M[rr, r] # 缓存
    130                 for c in range(r, M.column+1):
    131                     if rr == r:
    132                         continue
    133                     M[rr, c] -= temp * M[r, c]
    134                     print("M[{0}, {1}] -= {2} * M[{3}, {1}]".format(rr, c, temp,r))
    135                 M.show()    
    136             
    137         # 截取逆矩阵
    138         N = Matrix(self.row,self.column)
    139         for r in range(1,self.row+1):
    140             N[r] = M[r][self.row:]
    141         return N
    142             
    143         
    144     def jieti(self):
    145         '''行简化阶梯矩阵'''
    146         pass
    147         
    148         
    149     def transpose(self):
    150         '''转置'''
    151         M = Matrix(self.column, self.row)
    152         for r in range(self.column):
    153             for c in range(self.row):
    154                 M[r, c] = self[c, r]
    155         return M
    156     
    157     def cofactor(self, row, column):
    158         '''代数余子式(用于行列式展开)'''
    159         assert self.row == self.column, "不是方阵,无法计算代数余子式"
    160         assert self.row >= 3, "至少是3*3阶方阵"
    161         assert row <= self.row and column <= self.column, "下标超出范围"
    162         M = Matrix(self.column-1, self.row-1)
    163         for r in range(self.row):
    164             if r == row:
    165                 continue
    166             for c in range(self.column):
    167                 if c == column:
    168                     continue
    169                 rr = r-1 if r > row else r
    170                 cc = c-1 if c > column else c
    171                 M[rr, cc] = self[r, c]
    172         return M
    173     
    174     def det(self):
    175         '''计算行列式(determinant)'''
    176         assert self.row == self.column,"非行列式,不能计算"
    177         if self.shape == (2,2):
    178             return self[1,1]*self[2,2]-self[1,2]*self[2,1]
    179         else:
    180             sum = 0.0
    181             for c in range(self.column+1):
    182                 sum += (-1)**(c+1)*self[1,c]*self.cofactor(1,c).det()
    183             return sum
    184     
    185     def zeros(self):
    186         '''全零矩阵'''
    187         M = Matrix(self.column, self.row, fill=0.0)
    188         return M
    189     
    190     def ones(self):
    191         '''全1矩阵'''
    192         M = Matrix(self.column, self.row, fill=1.0)
    193         return M
    194     
    195     def identity(self):
    196         '''单位矩阵'''
    197         assert self.row == self.column, "非n*n矩阵,无单位矩阵"
    198         M = Matrix(self.column, self.row)
    199         for r in range(self.row):
    200             for c in range(self.column):
    201                 M[r, c] = 1.0 if r == c else 0.0
    202         return M
    203     
    204     def show(self):
    205         '''打印矩阵'''
    206         for r in range(self.row):
    207             for c in range(self.column):
    208                 print(self[r+1, c+1],end='  ')
    209             print()
    210     
    211 
    212 if __name__ == '__main__':
    213     m = Matrix(3,3,fill=2.0)
    214     n = Matrix(3,3,fill=3.5)
    215 
    216     m[1] = [1.,1.,2.]
    217     m[2] = [1.,2.,1.]
    218     m[3] = [2.,1.,1.]
    219     
    220     p = m * n
    221     q = m*2.1
    222     r = m**3
    223     #r.show()
    224     #q.show()
    225     #print(p[1,1])
    226     
    227     #r = m.invert()
    228     #s = r*m
    229     
    230     print()
    231     m.show()
    232     print()
    233     #r.show()
    234     print()    
    235     #s.show()
    236     print()
    237     print(m.det())
  • 相关阅读:
    POJ 1273:Drainage Ditches(EK 最大流)
    牛客假日团队赛6 H:Charm Bracelet (01背包)
    牛客假日团队赛6 F:Mud Puddles
    牛客假日团队赛6 E:对牛排序
    牛客假日团队赛6 D:迷路的牛
    牛客假日团队赛6 C:Bookshelf 2
    牛客假日团队赛6 B:Bookshelf
    牛客假日团队赛6 A:Card Stacking (模拟)
    UVA
    (转载)Mysql查找如何判断字段是否包含某个字符串
  • 原文地址:https://www.cnblogs.com/hhh5460/p/4314231.html
Copyright © 2011-2022 走看看