今天老板出差了，闲来无事，自己写了一个矩阵类权当休息。目标是实现像 MATLAB 一样强大的功能；今天只实现了最基本的部分，以后还得多多改进。
头文件（Matrix.h）：
#ifndef MATRIX_H_
#define MATRIX_H_

// Dense matrix of DataType elements, stored row-major in a heap array of
// rows*cols elements. Element, row and column indices in the public API are
// 1-based, as in MATLAB.
//
// Ownership / semantics notes:
//  * GetRow, GetCol and operator()(int, char) return a Matrix allocated with
//    `new`; the caller is responsible for `delete`-ing it.
//  * The arithmetic operators (+, -, *, /) modify *this in place and return a
//    reference to it, so `a + b` changes `a` (they behave like +=, -=, ...).
template <class DataType>
class Matrix
{
public:
    Matrix();                                   // 1x1 matrix, element zero-initialized
    Matrix(int r, int c);                       // r x c matrix, elements zero-initialized

    // Deep copy. The compiler-generated copy constructor would copy only the
    // `data` pointer, and both objects would then delete[] the same array.
    Matrix(const Matrix<DataType>& other)
        : data(new DataType[other.rows * other.cols]),
          rows(other.rows),
          cols(other.cols)
    {
        for (int k = 0; k < rows * cols; k++)
            data[k] = other.data[k];
    }

    ~Matrix();                                  // releases the element array

    void SetMem(int r, int c, DataType v);      // set element (r, c)
    DataType GetMem(int r, int c) const;        // get element (r, c)
    void SetRow(int r, DataType *a, int len);   // set row r from a[0..len-1]; ignored unless len == column count
    void SetCol(int c, DataType *a, int len);   // set column c from a[0..len-1]; ignored unless len == row count
    void Zeros();                               // set every element to 0
    void Ones();                                // set every element to 1

    Matrix<DataType>* GetRow(int r);            // new 1 x cols matrix holding row r (caller deletes)
    Matrix<DataType>* GetCol(int c);            // new matrix holding column c (caller deletes)

    int GetRows();                              // number of rows
    int GetColumns();                           // number of columns

    // Element-wise arithmetic: each of these modifies *this and returns it.
    Matrix<DataType>& operator + (Matrix<DataType> &rhl);   // matrix + matrix
    Matrix<DataType>& operator + (DataType v);              // matrix + scalar
    Matrix<DataType>& operator - (Matrix<DataType> &rhl);   // matrix - matrix
    Matrix<DataType>& operator - (DataType v);              // matrix - scalar
    Matrix<DataType>& operator * (Matrix<DataType> &rhl);   // element-wise product (MATLAB .*), not the matrix product
    Matrix<DataType>& operator * (DataType v);              // matrix * scalar
    Matrix<DataType>& operator / (Matrix<DataType> &rhl);   // element-wise quotient (MATLAB ./)
    Matrix<DataType>& operator / (DataType v);              // matrix / scalar

    Matrix<DataType>& operator = (Matrix<DataType> &rhl);   // copy rhl's elements into *this via DeepCopy
    Matrix<DataType>& operator = (DataType v);              // set every element to v
    bool operator == (Matrix<DataType> &rhl);               // true if same shape and every element compares equal
    DataType operator () (int x, int y);                    // same as GetMem(x, y)
    Matrix<DataType>* operator () (int x, char flag);       // flag == 'R': GetRow(x); any other flag: GetCol(x)

    // Intended: *this = lhs * rhs (true matrix product).
    Matrix<DataType>& MatrixMul(Matrix<DataType>& lhs, Matrix<DataType>& rhs);

    void Show();                                // print to cout, one row per line

private:
    DataType* data;     // rows*cols elements, row-major
    int rows;           // number of rows
    int cols;           // number of columns
    inline void DeepCopy(Matrix<DataType>& org);  // copy every element of org into *this
};

#endif  // MATRIX_H_
实现源文件（Matrix.cpp）：
1 #include "Matrix.h" 2 #include <iostream> 3 using namespace std; 4 5 template <class DataType> 6 Matrix<DataType>::Matrix() 7 { 8 rows=1; 9 cols=1; 10 data=new DataType[rows*cols]; 11 memset(data,0,rows*cols*sizeof(DataType)); 12 } 13 14 template <class DataType> 15 Matrix<DataType>::Matrix(int r, int c) 16 { 17 rows=r; 18 cols=c; 19 data=new DataType[rows*cols]; 20 memset(data,0,rows*cols*sizeof(DataType)); 21 } 22 23 template <class DataType> 24 Matrix<DataType>::~Matrix() 25 { 26 delete[] data; 27 } 28 29 template <class DataType> 30 void Matrix<DataType>::Show() 31 { 32 for (int i=0;i<rows;i++) 33 { 34 for (int j=0;j<cols;j++) 35 { 36 cout<<data[i*cols+j]<<" "; 37 if(j==cols-1) 38 cout<<endl; 39 } 40 } 41 } 42 43 template <class DataType> 44 void Matrix<DataType>::SetMem(int r, int c, DataType v) 45 { 46 if(r>rows||c>cols) 47 return; 48 data[(r-1)*cols+(c-1)]=v; 49 } 50 template <class DataType> 51 DataType Matrix<DataType>::GetMem(int r, int c) const 52 { 53 if(r>rows||c>cols) 54 return -1; 55 return data[(r-1)*cols+(c-1)]; 56 } 57 58 template <class DataType> 59 void Matrix<DataType>::SetRow(int r, DataType *a, int len) 60 { 61 if (r>rows||len!=cols) 62 return; 63 64 for (int i=0;i<cols;i++) 65 { 66 data[(r-1)*cols+i]=a[i]; 67 } 68 } 69 70 template <class DataType> 71 void Matrix<DataType>::SetCol(int c, DataType *a, int len) 72 { 73 if(c>cols||len!=rows) 74 return; 75 for (int i=0;i<rows;i++) 76 { 77 data[i*cols+(c-1)]=a[i]; 78 } 79 } 80 81 template <class DataType> 82 Matrix<DataType>* Matrix<DataType>::GetRow(int r) 83 { 84 if (r>rows) 85 { 86 exit(0); 87 } 88 else 89 { 90 Matrix<DataType> *tmp=new Matrix<DataType>(1,this->cols); 91 for (int i=1;i<=this->cols;i++) 92 { 93 DataType tm=0; 94 tm=this->GetMem(r,i); 95 tmp->SetMem(1,i,tm); 96 } 97 return tmp; 98 } 99 } 100 101 template <class DataType> 102 Matrix<DataType>* Matrix<DataType>::GetCol(int c) 103 { 104 if (c>cols) 105 { 106 exit(0); 107 } 108 else 109 { 110 Matrix<DataType> *tmp=new 
Matrix<DataType>(1,this->rows); 111 DataType tm=0; 112 for (int i=1;i<=this->rows;i++) 113 { 114 tm=this->GetMem(i,c); 115 tmp->SetMem(1,i,tm); 116 } 117 return tmp; 118 } 119 } 120 121 template <class DataType> 122 int Matrix<DataType>::GetColumns() 123 { 124 return cols; 125 } 126 127 template <class DataType> 128 int Matrix<DataType>::GetRows() 129 { 130 return rows; 131 } 132 133 template <class DataType> 134 Matrix<DataType>& Matrix<DataType>::operator + (Matrix<DataType>& rhl) 135 { 136 if (rows!=rhl.GetRows()||cols!=rhl.GetColumns()) 137 { 138 exit(0); 139 } 140 else 141 { 142 for (int i=1;i<=rows;i++) 143 { 144 for (int j=1;j<=cols;j++) 145 { 146 this->SetMem(i,j,this->GetMem(i,j)+rhl.GetMem(i,j)); 147 } 148 } 149 return *this; 150 } 151 } 152 153 template <class DataType> 154 Matrix<DataType>& Matrix<DataType>::operator - (Matrix<DataType>& rhl) 155 { 156 if (rows!=rhl.GetRows()||cols!=rhl.GetColumns()) 157 { 158 exit(0); 159 } 160 else 161 { 162 for (int i=1;i<=rows;i++) 163 { 164 for (int j=1;j<=cols;j++) 165 { 166 this->SetMem(i,j,this->GetMem(i,j)-rhl.GetMem(i,j)); 167 } 168 } 169 return *this; 170 } 171 } 172 173 template <class DataType> 174 Matrix<DataType>& Matrix<DataType>::operator * (Matrix<DataType>& rhs) 175 { 176 177 if (rows!=rhs.GetRows()||cols!=rhs.GetColumns()) 178 { 179 exit(0); 180 } 181 for (int i=1;i<=rows;i++) 182 { 183 for (int j=1;j<=cols;j++) 184 { 185 this->SetMem(i,j,this->GetMem(i,j)*rhs.GetMem(i,j)); 186 } 187 } 188 return *this; 189 } 190 191 192 template <class DataType> 193 Matrix<DataType>& Matrix<DataType>::operator / (Matrix<DataType>& rhl) 194 { 195 if (rows!=rhl.GetRows()||cols!=rhl.GetColumns()) 196 { 197 exit(0); 198 } 199 else 200 { 201 for (int i=1;i<=rows;i++) 202 { 203 for (int j=1;j<=cols;j++) 204 { 205 this->SetMem(i,j,this->GetMem(i,j)/rhl.GetMem(i,j)); 206 } 207 } 208 return *this; 209 } 210 } 211 212 template <class DataType> 213 void Matrix<DataType>::DeepCopy(Matrix<DataType>& org) 214 { 215 
if(cols!=org.GetColumns()||rows!=org.GetRows()) 216 return; 217 for (int i=1;i<=rows;i++) 218 { 219 for (int j=1;j<=cols;j++) 220 { 221 this->SetMem(i,j,org.GetMem(i,j)); 222 } 223 } 224 } 225 226 template <class DataType> 227 Matrix<DataType>& Matrix<DataType>::operator = (Matrix<DataType>& rhl) 228 { 229 if(this==&rhl) 230 return *this; 231 DeepCopy(rhl); 232 return *this; 233 } 234 235 template <class DataType> 236 void Matrix<DataType>::Zeros() 237 { 238 for (int i=1;i<=rows;i++) 239 { 240 for (int j=1;j<=cols;j++) 241 { 242 this->SetMem(i,j,0.0); 243 } 244 } 245 } 246 247 template <class DataType> 248 void Matrix<DataType>::Ones() 249 { 250 for (int i=1;i<=rows;i++) 251 { 252 for (int j=1;j<=cols;j++) 253 { 254 this->SetMem(i,j,1.0); 255 } 256 } 257 } 258 259 template <class DataType> 260 Matrix<DataType>& Matrix<DataType>::operator + (DataType v) 261 { 262 for (int i=1;i<=rows;i++) 263 { 264 for (int j=1;j<=cols;j++) 265 { 266 this->SetMem(i,j,this->GetMem(i,j)+v); 267 } 268 } 269 return *this; 270 } 271 272 template <class DataType> 273 Matrix<DataType>& Matrix<DataType>::operator - (DataType v) 274 { 275 for (int i=1;i<=rows;i++) 276 { 277 for (int j=1;j<=cols;j++) 278 { 279 this->SetMem(i,j,this->GetMem(i,j)-v); 280 } 281 } 282 return *this; 283 } 284 285 template <class DataType> 286 Matrix<DataType>& Matrix<DataType>::operator * (DataType v) 287 { 288 for (int i=1;i<=rows;i++) 289 { 290 for (int j=1;j<=cols;j++) 291 { 292 this->SetMem(i,j,this->GetMem(i,j)*v); 293 } 294 } 295 return *this; 296 } 297 298 template <class DataType> 299 Matrix<DataType>& Matrix<DataType>::operator / (DataType v) 300 { 301 for (int i=1;i<=rows;i++) 302 { 303 for (int j=1;j<=cols;j++) 304 { 305 this->SetMem(i,j,this->GetMem(i,j)/v); 306 } 307 } 308 return *this; 309 } 310 311 template <class DataType> 312 Matrix<DataType>& Matrix<DataType>::operator = (DataType v) 313 { 314 315 for (int i=1;i<=rows;i++) 316 { 317 for (int j=1;j<=cols;j++) 318 { 319 this->SetMem(i,j,v); 320 } 
321 } 322 return *this; 323 } 324 325 326 template <class DataType> 327 bool Matrix<DataType>::operator == (Matrix<DataType> &rhl) 328 { 329 if (rows!=rhl.GetRows()||cols!=rhl.GetColumns()) 330 { 331 return false; 332 } 333 else 334 { 335 int count=0; 336 for (int i=1;i<=rows;i++) 337 { 338 for (int j=1;j<=cols;j++) 339 { 340 if(GetMem(i,j)==rhl.GetMem(i,j)) 341 count++; 342 } 343 } 344 if(count==rows*cols) 345 return true; 346 else 347 return false; 348 } 349 } 350 351 template <class DataType> 352 DataType Matrix<DataType>::operator () (int x, int y) 353 { 354 return GetMem(x,y); 355 } 356 357 template <class DataType> 358 Matrix<DataType>* Matrix<DataType>::operator () (int x, char flag) 359 { 360 if (flag=='R') 361 { 362 return GetRow(x); 363 } 364 else 365 { 366 return GetCol(x); 367 } 368 } 369 //before call the function, must use operator new to apply some storage for the object 370 template <class DataType> 371 Matrix<DataType>& Matrix<DataType>::MatrixMul(Matrix<DataType>& lhs, Matrix<DataType>& rhs) 372 { 373 if (lhs.GetColumns()!=rhs.GetRows()) 374 { 375 exit(0); 376 } 377 for (int i=1;i<=lhs.GetRows();i++) 378 { 379 for (int j=1;j<=rhs.GetColumns();j++) 380 { 381 DataType tm=0; 382 for (int k=1;k<=lhs.GetColumns();k++) 383 { 384 tm+=GetMem(i,k)*rhs.GetMem(k,j); 385 } 386 this->SetMem(i,j,tm); 387 } 388 } 389 return *this; 390 }
测试文件（包含 main 函数）：
1 #include <iostream> 2 #include "Matrix.cpp" //use the template, so must include the cpp file 3 using namespace std; 4 5 void main() 6 { 7 Matrix<double> dMat(2,4); 8 dMat.SetMem(2,2,10.0); 9 dMat.SetMem(2,4,5.0); 10 dMat.Show(); 11 12 Matrix<double> dMat1(4,2); 13 dMat1.SetMem(2,1,10.0); 14 dMat1.SetMem(2,2,3.0); 15 dMat1.Show(); 16 17 Matrix<double> *pMat=new Matrix<double>(4,4); 18 pMat->MatrixMul(dMat1,dMat); 19 pMat->Show(); 20 }
以后再继续完善这个类，现在先去吃饭了。(*^__^*) 嘻嘻……