zoukankan      html  css  js  c++  java
  • 矩阵乘法的顺序安排问题 Python简单实现

    矩阵乘法的顺序安排问题

    问题背景

    设矩阵 AB 大小分别 (p imes q) , (q imes r) ,则矩阵乘积 AB 需要做的标量乘法次数为 (p imes q imes r) 。我们知道矩阵的乘法运算是不可交换的,但它是可结合的。因此对于多个矩阵的连乘,我们可以以任意顺序添加括号改变其中相邻矩阵乘法的优先级。不同计算顺序下总的标量乘法运算次数是不同的,我们的目标是找到一个最优的矩阵乘法计算顺序。

    给定矩阵乘法序列 (A_1, A_2, ..., A_n),将乘法序列以第 (i) 个矩阵分为前后两部分,则方案数为前后两部分方案数之积。因此乘法计算的顺序个数为

    [T(n) = Sigma_{i=1}^{n-1} T(i) cdot T(n-i) ]

    T(n) 的解为 Catalan数,这里不加证明给出结果为

    [T(n) = ext{C}_{2n}^n - ext{C}_{2n}^{n+1} = frac { ext{C}_{2n}^n} {n+1} ]

    由此可见的矩阵乘法顺序个数为问题规模 (n) 的指数级,显然通过枚举找到最优的乘法顺序是不合适的。

    暴力算法

    首先还是试探一下如何用最朴素的方式解决。

    (M_{i, j}) 表示 第 (i) 个矩阵到第 (j) 个矩阵的最少乘法运算次数,用数学化的语言表达我们的目标,即

    [M_{1, n} = min_i { M_{1, i} + M_{i+1, n} + p imes q imes r } ]

    其中 p、q、r为最后两个矩阵的大小。

    代码很容易实现:

    def minMatrixMultiplication(Mats):
    	"""
    	:param Mats: Mat类型的list
    	:return:	 矩阵乘法的最小乘法次数,及对应的括号位置
    	"""
        
    	
    	if len(Mats)==1:
    		return 0, '[%d,%d]' % (Mats[0].n, Mats[0].m)
    
    	import math
    	minCost = math.inf
    	bestSeq = ''        # 记录添加的括号位置
    	for i in range(0, len(Mats)-1):
    		leftCost, leftSeq = minMatrixMultiplication(Mats[:i+1])
    		rightCost, rightSeq = minMatrixMultiplication(Mats[i+1:])
    		tmpCost = leftCost + rightCost +  Mats[0].n * Mats[i].m * Mats[-1].m
    
    		if tmpCost < minCost:
    			minCost = tmpCost
    			bestSeq = '(' + leftSeq + '*' + rightSeq + ')'
    
    	return minCost, bestSeq
    

    测试用的矩阵类型Mat定义如下:

    class Mat:
    	def __init__(self, mat=None):
    		if mat and isinstance(mat[0], list):
    			self.mat = mat
    			self.n = len(mat)
    			self.m = len(mat[0])
    		else:
    			self.mat = [[]]
    			self.n = 0
    			self.m = 0
    
    	def __init__(self, n, m):
    		self.mat = [[]]
    		self.n = n
    		self.m = m
    

    以上算法的函数调用次数 (f(n) = 1 + f(1)+f(n-1) + f(2)+f(n-2) + ... + f(n-1)+f(1))

    容易验证得到(f(n) = 3^n), 即该算法的复杂度为 O((3^n)),这是不可接受的。

    记忆化

    分析一番可以发现,对于矩阵序列 i~j 之间乘法的最优结果 (M_{i, j}) 只有 ( ext{C}_n^2) 种,那么上述代码的中间很多段都进行了重复计算。如果把中间得到的答案记录下来,可以大大减少计算量。

    在不改变上述算法的框架下,将 i~j 之间的结果 (M_{i, j}) 定义Python嵌套的内部函数。新增了变量 invokeCnt 统计递归函数需要重新计算 (M_{i, j}) 的次数。

    def minMatrixMultiplication2(Mats):
    	siz = len(Mats) + 1
    	# 血的教训:不要使用下面的方法定义二维数组
    	# minCostMem = [[-1]*siz]*siz
    	# bestSeqMem = [['']*siz]*siz
    	minCostMem = [[-1]*siz for i in range(siz)]
    	bestSeqMem = [['']*siz for i in range(siz)]
    
    	invokeCnt = 0  # 统计递归函数重新执行次数
    	def helper(s, t):
    		if s==t:
    			return 0, '[%d,%d]' % (Mats[s].n, Mats[s].m)
    
    		if minCostMem[s][t]!=-1:
    			return minCostMem[s][t], bestSeqMem[s][t]
    
    		nonlocal invokeCnt
    		invokeCnt += 1
    
    		import math
    		minCost = math.inf
    		bestSeq = ''
    		for i in range(s, t):
    			leftCost, leftSeq = helper(s, i)
    			rightCost, rightSeq = helper(i+1, t)
    			tmpCost = leftCost + rightCost +  Mats[s].n * Mats[i].m * Mats[t].m
    			if tmpCost < minCost:
    				minCost = tmpCost
    				bestSeq = '(' + leftSeq + '*' + rightSeq + ')'
    
    		minCostMem[s][t] = minCost
    		bestSeqMem[s][t] = bestSeq
    
    		return minCost, bestSeq
    
    	return helper(0, len(Mats)-1), invokeCnt
    

    动态规划

    (待补充。。。)

    运行对比

    Mats = [Mat(2,3), Mat(3,5), Mat(5,8), Mat(8,2), Mat(2,3), Mat(3,2), Mat(2,5), Mat(5, 3)]
    print(minMatrixMultiplication(Mats))
    # (184, '((([2,3]*([3,5]*([5,8]*[8,2])))*([2,3]*[3,2]))*([2,5]*[5,3]))')
    # 调用次数 3^8 = 2187
    print(minMatrixMultiplication2(Mats))
    # ((184, '((([2,3]*([3,5]*([5,8]*[8,2])))*([2,3]*[3,2]))*([2,5]*[5,3]))'), 28)
    

    注意事项

    Python 定义二维矩阵,千万不要使用注释写法。调试了很久才发现问题。 T^T
    正确的写法为

    • matrix = [[0]*m for i in range(n)]
    • 或使用numpy库
      import numpy
      matrix = numpy.zeros((n, m))

    原因可以简单理解为

    n = 5
    m = 3
    matrix = [[0]*m]*n
    
    # 相当于
    """
    array = [0 0 0]
    matrix = [array]*5
    # matrix内的5个元素都是同一个列表引用
    # 当使用 matrix[3][2] = 1 赋值
    # 则 array[2] = 1
    # 所以 matrix[0~4][2]都为 1
    """
  • 相关阅读:
    Tinkoff Challenge
    Uva 12298 超级扑克2
    BZOJ 1031 字符加密
    HDU 4944 逆序数对
    51nod 1215 数组的宽度
    LA 3126 出租车
    LA 3415 保守的老师
    51nod 1275 连续子段的差异
    Uva 11419 我是SAM
    LA 4043 最优匹配
  • 原文地址:https://www.cnblogs.com/izcat/p/12549542.html
Copyright © 2011-2022 走看看