zoukankan      html  css  js  c++  java
  • 算法导论学习——分治矩阵乘法

    头文件 结构的定义

    stdafx.h

    // stdafx.h : 标准系统包含文件的包含文件,
    // 或是经常使用但不常更改的
    // 特定于项目的包含文件
    //
    
    #pragma once
    
    #include "targetver.h"
    
    #include <stdio.h>
    #include <tchar.h>
    
    
    
    // TODO:  在此处引用程序需要的其他头文件
    #include <iostream>
    
    using namespace std;
    
    
    #define Maxelem 10
    
    //求两个数的最大值
    int inline max(int a, int b){
    	return a >= b ? a : b;
    }
    //求三个数的最大值
    int inline max(int a, int b, int c){
    	return max(max(a, b), c);
    }
    
    //求min(2^x),ST. 2^x>=number。
    int inline L2n(int number){
    	if ((number & number - 1) == 0)
    	{
    		return number;
    	}
    	else {
    		return pow(2, (int)log2(number) + 1);
    	}
    }
    
    //定义矩阵的结构
    class Matrix{
    
    	
    public:
    	int data[Maxelem][Maxelem];
    	int M, N;
    	//以数组复制的方式构造矩阵对象,其中起始位置为0.
    	Matrix(int array[Maxelem][Maxelem], int m, int n){
    		M = m;
    		N = n;
    		for (int i = 0; i < m; i++){
    			for (int j = 0; j < n; j++){
    				this->data[i][j] = array[i][j];
    			}
    		}
    	}
    	
    	//以数组复制的方式构造矩阵对象,以strat1,end1,strat2,end2为区间
    	Matrix(int array[Maxelem][Maxelem], int start1, int end1, int start2, int end2){
    		M = end1 - start1 + 1;
    		N = end2 - start2 + 1;
    		int p = 0, q = 0;
    		for (int i = 0; i < M; i++){
    			for (int j = 0; j < N; j++){
    				data[i][j] = array[start1 + i][start2 +j];
    			}
    			
    		}
    			
    	}
    	//仅构造矩阵,不填充数据
    	Matrix(int m, int n) :M(m), N(n){}
    
    
    	/*
    	 打印输出
    	*/
    	void print() const {
    		for (int i = 0; i <M; i++){
    			for (int j = 0; j < N; j++){
    				cout << data[i][j] << " ";
    			} 
    			cout << endl;
    
    		}
    	}
    	//为矩阵补充0,使其成为标准方阵
    		void fill(){
    		
    		if (!(M == N && ((M & M - 1) == 0))){
    			int n = L2n(max(M, N));
    			
    			for (int i = 0; i < n; i++){
    				for (int j = 0; j < n; j++){
    					data[i][j] = (i < M&&j < N ? data[i][j] : 0);
    				}
    			}
    			M = n;
    			N = n;
    			
    		}
    	}
    
    
    	/*重载二维运算符[][]*/
    	int * const operator[](const int i)
    	{
    		return data[i];
    	}
    
    	Matrix friend operator +(Matrix m1, Matrix m2){
    
    	
    		Matrix op = Matrix(m1.M, m1.N);
    		for (int i = 0; i < m1.M; i++)
    		{
    			for (int j = 0; j < m1.M; j++){
    				op[i][j] = m1[i][j] + m2[i][j];
    			}
    		}
    		return op;
    	}
    	Matrix friend operator -(Matrix m1, Matrix m2){
    
    
    		Matrix op = Matrix(m1.M, m1.N);
    		for (int i = 0; i < m1.M; i++)
    		{
    			for (int j = 0; j < m1.M; j++){
    				op[i][j] = m1[i][j] - m2[i][j];
    			}
    		}
    		return op;
    	}
    	/*
    	将计算过程中补充的0清除。计算完毕后才能用的方法,不加也能得到结果,不过行列数不对。
    	*/
    	void clean(int m, int n){
    		M = m;
    		N = n;
    	}
    };
    

      算法的实现:

    // strassenAlgorithm.cpp : 定义控制台应用程序的入口点。
    //
    
    #include "stdafx.h"
    
    
    
    
    //该方法只能相乘最简单的2*2矩阵。
    Matrix mutilSimple(Matrix A, Matrix B){
    	int a = A[0][0],b=A[0][1],c=A[1][0],d=A[1][1];
    	int e = B[0][0], f = B[0][1], g = B[1][0], h = B[1][1];
    
    	int p1 = a*(f - h);
    	int p2 = (a + b)*h;
    	int p3 = (c + d)*e;
    	int p4 = d*(g - e);
    	int p5 = (a + d)*(e + h);
    	int p6 = (b - d)*(g + h);
    	int p7 = (a - c)*(e + f);
    	
    	Matrix returnValue = Matrix(2, 2);
    	returnValue[0][0] = p5 + p4 - p2 + p6;
    	returnValue[0][1] = p1 + p2;
    	returnValue[1][0] = p3 + p4;
    	returnValue[1][1] = p1 + p5 - p3 - p7;
    
    	return returnValue;
    }
    
    //矩阵乘法 必须用了fill方法才能相乘
    Matrix mutilMerge(Matrix A,Matrix B){
    	if (A.M == 2 && B.M == 2){
    		return mutilSimple(A, B);
    	}
    	int k = A.M;
    	Matrix 
    		a = Matrix(A.data, 0, k / 2 - 1, 0, k / 2 - 1),
    		b = Matrix(A.data, 0, k / 2 - 1, k / 2, k - 1),
    		c = Matrix(A.data, k / 2, k - 1, 0, k / 2 - 1),
    		d = Matrix(A.data, k / 2, k - 1, k / 2, k - 1),
    	
    		e = Matrix(B.data, 0, k / 2 - 1, 0, k / 2 - 1),
    		f = Matrix(B.data, 0, k / 2 - 1, k / 2, k - 1),
    		g = Matrix(B.data, k / 2, k - 1, 0, k / 2 - 1),
    		h = Matrix(B.data, k / 2, k - 1, k / 2, k - 1),
    		op = Matrix(k, k),
    	
    		p1 = mutilMerge(a, f - h),
    		p2 = mutilMerge(a + b, h),
    		p3 = mutilMerge(c + d, e),
    		p4 = mutilMerge(d, g - e),
    		p5 = mutilMerge(a + d, e + h),
    		p6 = mutilMerge(b - d, g + h),
    		p7 = mutilMerge(a - c, e + f),
    	
    		
    		op1=p5+p4-p2+p6,
    		op2=p1+p2,
    		op3=p3+p4,
    		op4 = p1 + p5 - p3 - p7; 
    
    
    
    	int x1 = 0, y1 = 0, x2 = 0, y2 = 0, x3 = 0,y3=0,x4=0, y4 = 0; //4个变量的游标
    	int u = 0, v = 0;
    	for (int i = 0; i < k; i++)
    	{
    		for (int j = 0; j < k; j++){
    		
    			if (i >= 0 && i <= k / 2 - 1 && j >= 0 && j <= k / 2 - 1){
    				op[i][j] = op1[x1][y1];
    				y1++;
    				if (y1 == op1.M) { y1 = 0; x1++; }
    			}
    			if (i >= 0 && i <= k / 2 - 1 && j >= k / 2 && j <= k - 1){
    				op[i][j] = op2[x2][y2];
    				y2++;
    				if (y2 == op2.M) { y2 = 0; x2++; }
    			}
    			if (i >= k/2 && i <= k - 1 && j >= 0 && j <= k / 2 - 1){
    
    				op[i][j] = op3[x3][y3];
    				y3++;
    				if (y3 == op3.M) { y3 = 0; x3++; }
    			}
    			
    			if (i >= k / 2 && i <= k - 1 && j >= k/2 && j <= k - 1){
    				op[i][j] = op4[x4][y4];
    				y4++;
    				if (y4 == op4.M) { y4 = 0; x4++; }
    			}
    
    		}
    	}
    
    
    	
    	return op;
    	
    }
    Matrix zeroclear( Matrix result,int M,int N){
    	result.M = M;
    	result.N = N;
    	return result;
    }
    int _tmain(int argc, _TCHAR* argv[])
    {
    
    	int matrixA[Maxelem][Maxelem] = {
    		{ 10, 3, 3 ,7,4},
    		{ 5, 3, 8,2,1 },
    		{ -2, 3, 7, 5, 2 }, 
    		{1,10,-2,1,8},
    		{3,3,3,3,3}
    	};
    
    	int matrixB[Maxelem][Maxelem] = {
    		{ -4, 6, 1,2,1 },
    		{ 9, 10, 8,0,3 },
    		{ 2, 3, -7 ,-1,-1},
    		{ 1, -6, 2, 1, 7 }, 
    		{1,2,3,4,5}
    	};
    
    	Matrix ma = Matrix(matrixA,5,5);
    	Matrix mb = Matrix(matrixB,5,5);
    
    	
    	cout << "A="<<endl;
    	
    	ma.print();
    	cout << "B="<<endl;
    	mb.print();
    
    	ma.fill();
    	mb.fill();
    	Matrix result = mutilMerge(ma, mb);
    	
    	cout << "A×B=" << endl;
    	result.clean(5, 5);
    	result.print();
    	cout << "B×A=" << endl;
    	result = mutilMerge(mb, ma);
    	result.clean(5, 5);
    	result.print();
    	system("pause");
    
    	return 0;
    }
    

      

  • 相关阅读:
    slf4j日志框架绑定机制
    Btrace使用入门
    JVM反调调用优化,导致发生大量异常时log4j2线程阻塞
    [转载]Javassist 使用指南(三)
    [转载]Javassist 使用指南(二)
    [转载]Javassist 使用指南(一)
    数组升序排序的方法Arrays.sort();的应用
    copyOfRange的应用
    copyOf数组复制方法的使用(数组扩容练习)
    binarySearch(int[] a,int fromIndex,int toIndex, int key)的用法
  • 原文地址:https://www.cnblogs.com/xcr1234/p/4875451.html
Copyright © 2011-2022 走看看