zoukankan      html  css  js  c++  java
  • Python+MapReduce实现矩阵相乘

    算法原理

    map阶段

    在map阶段,需要做的是进行数据准备。把来自矩阵A的元素aij,标识成p条<key, value>的形式,key="i,k",(其中k=1,2,...,p),value="a:j,aij";把来自矩阵B的元素bij,标识成m条<key, value>形式,key="k,j"(其中k=1,2,...,m),value="b:i,bij"。

    经过处理,用于计算cij需要的a、b就转变为有相同key("i,j")的数据对,通过value中"a:"、"b:"能区分元素是来自矩阵A还是矩阵B,以及具体的位置(在矩阵A的第几列,在矩阵B的第几行)。

    shuffle阶段  

    这个阶段是Hadoop自动完成的阶段,具有相同key的value被分到同一个Iterable中,形成<key,Iterable(value)>对,再传递给reduce。

    reduce阶段

    通过map数据预处理和shuffle数据分组两个阶段,reduce阶段只需要知道两件事就行:

    <key,Iterable(value)>对经过计算得到的是矩阵C的哪个元素?因为map阶段对数据的处理,key(i,j)中的数据对,就是其在矩阵C中的位置,第i行j列。
    Iterable中的每个value来自于矩阵A和矩阵B的哪个位置?这个也在map阶段进行了标记,对于value(x:y,z),只需要找到y相同的来自不同矩阵(即x分别为a和b)的两个元素,取z相乘,然后加和即可。

    过程如下图所示:

    算法实现

    mapper.py

     1 #!/usr/bin/env python3
     2 import sys
     3 
     4 flag = 0       # 0表示输入A、B矩阵信息,1表示处理A矩阵,2表示处理B矩阵
     5 row_a, col_a, row_b, col_b = 0, 0, 0, 0  # A、B矩阵shape
     6 current_row = 1  # 记录现在处理矩阵的第几行
     7 
     8 
     9 def read_input():
    10     for lines in sys.stdin:
    11         yield lines
    12 
    13 
    14 if __name__ == '__main__':
    15     for line in read_input():
    16         if line.count('
    ') == len(line):    # 去空行
    17             pass
    18         data = line.strip().split('	')
    19 
    20         if flag == 0:
    21             flag = 1
    22             row_a = int(data[0])
    23             col_a = int(data[1])
    24             row_b = int(data[2])
    25             col_b = int(data[3])
    26             if row_a == 0 or row_b == 0 or col_a == 0 or col_b ==0 or col_a != row_b:
    27                 print("矩阵输入错误!")
    28                 break
    29 
    30         elif flag == 1:
    31             for i in range(col_b):
    32                 for j in range(col_a):
    33                     print("%s,%s	A:%s,%s" % (current_row, i+1, j+1, data[j]))
    34             current_row += 1
    35             if current_row > row_a:
    36                 flag = 2
    37                 current_row = 1
    38 
    39         elif flag == 2:
    40             for i in range(row_a):
    41                 for j in range(col_b):
    42                     print("%s,%s	B:%s,%s" % (i+1, j+1, current_row, data[j]))
    43             current_row += 1

    reducer.py

    这是我一开始所写的版本。

     1 #!/usr/bin/env python3
     2 import sys
     3 
     4 
     5 last, now = None, None
     6 s = 0.0
     7 count = 0
     8 matrix_a, matrix_b = {}, {}
     9 
    10 
    11 def read_input():
    12     for lines in sys.stdin:
    13         yield lines
    14 
    15 
    16 if __name__ == '__main__':
    17     for line in read_input():
    18         if line.count('
    ') == len(line):    # 去空行
    19             pass
    20         data = line.strip().split('	')
    21         now = data[0]
    22         if last is None:
    23             last = now
    24             count = 0
    25         elif last != now:
    26             for key in matrix_a:
    27                 s += float(matrix_a[key])*float(matrix_b[key])
    28             print("%s	%s" % (last, s))
    29             s = 0.0
    30             count = 0
    31             last = now
    32 
    33         value1 = data[1][0]
    34         value2 = data[1].split(':')[1].split(',')[0]
    35         value3 = data[1].split(',')[1]
    36         if value1 == 'A':
    37             count += 1
    38             matrix_a[value2] = value3
    39         else:
    40             matrix_b[value2] = value3
    41 
    42     for key in matrix_a:
    43         s += float(matrix_a[key])*float(matrix_b[key])
    44     print("%s	%s" % (last, s))

     后来借鉴参考了别人的代码后,学习了groupby,下面的代码就简洁多了。

     1 #!/usr/bin/env python3
     2 import sys
     3 from itertools import groupby
     4 from operator import itemgetter
     5 
     6 
     7 def read_input(splitstr):
     8     for line in sys.stdin:
     9         line = line.strip()
    10         if len(line) == 0:
    11             continue
    12         yield line.split(splitstr)
    13 
    14 
    15 if __name__ == '__main__':
    16     data = read_input('	')
    17     lstg = (groupby(data, itemgetter(0)))
    18     try:
    19         for flag, group in lstg:
    20             matrix_a, matrix_b = {}, {}
    21             total = 0.0
    22             for element, g in group:
    23                 matrix = g.split(':')[0]
    24                 pos = g.split(':')[1].split(',')[0]
    25                 value = g.split(',')[1]
    26                 if matrix == 'A':
    27                     matrix_a[pos] = value
    28                 else:
    29                     matrix_b[pos] = value
    30             for key in matrix_a:
    31                 total += float(matrix_a[key]) * float(matrix_b[key])
    32             print("%s	%s" % (flag, total))
    33     except Exception:
    34         pass

    算法运行

    执行结果为:

    参考:

    [1] 用MapReduce实现矩阵乘法

    [2] python版mapreduce矩阵相乘

    [3] MapReduce实现矩阵乘法

  • 相关阅读:
    算法学习概述(2016.6)
    java异常和错误类总结(2016.5)
    java string 细节原理分析(2016.5)
    MySQL 5.7.18 解压版安装
    Struts2的<s:date>标签使用详解[转]
    jprofile查看hprof文件[转]
    iBatis的Settings节点参数详解[转]
    window.open、window.showModalDialog和window.showModelessDialog 的区别[转]
    oracle 字典表查询
    oracle 表空间操作
  • 原文地址:https://www.cnblogs.com/zyb993963526/p/10586248.html
Copyright © 2011-2022 走看看