zoukankan      html  css  js  c++  java
  • dp方法论——由矩阵相乘问题学习dp解题思路

    前篇戳:dp入门——由分杆问题认识动态规划

    导语

    刷过一些算法题,就会十分珍惜“方法论”这种东西。Leetcode上只有题目、讨论和答案,没有方法论。往往答案看起来十分切中要害,但是从看题目到得到思路的那一段,就是绕不过去。楼主有段时间曾把这个过程归结于智商和灵感的结合,直到有天为了搞懂Leetcode上一位老兄的题型总结,花两天时间学习了回溯法,突然有种惊为天人的感觉——原来真正掌握一个算法是应该触类旁通的,而不是将题中一个细节换掉就又成了新题……

    掌握方法论绝对是一种很爽的感觉。看起来好像很花费时间,其实是一种“因为慢,所以快”的方法。以前可能你学习一个dp题目要大半天;当你花了半个周时间,学会了dp的套路,你会发现,有些medium的dp题甚至不需要半个小时就能做完,而且从头到尾不需提示,全靠自己!

    方法论

    那么,怎么从一个看起来毫无头绪的问题出发,找到解题的思路并用dp将问题解出来呢?本文以矩阵相乘问题为例,给出dp问题的一般解题思路。

    当然,按照思路解题的前提是你已经知道这道题要用dp去解,如何确定一个问题可以用dp去解,则是下一篇要讨论的话题。

    下面就是动态规划的一般解题思路:

    1. 分析最优解的特征。
    2. 递归地定义最优解的值。
    3. 计算最优解的值。
    4. 根据计算好的信息构造最优解。

    看起来非常抽象是吧?在这里不需要完全理解。等你看完全文再回来,保你会有不一样的感受。

    矩阵相乘问题

    问题

    这是一个看起来可能有点抽象的数学问题,但请你耐心往下看。当你看完解法时,你会惊异于动态规划的魔力。

    题目:给出一个由n个矩阵组成的矩阵链<A1,A2,...,An>,矩阵Ai的秩为pi-1×pi。将A1A2...An这个乘积全括号化,使得计算这个乘积所需要的的标量乘法最少。

    全括号化是以一种递归的形式定义的:

    一个全括号化的乘积只有两种可能:一是一个单个矩阵;二是两个全括号化的乘积的乘积。

    天啦也太绕了,举个例子吧。对于矩阵链<A1,A2,A3,A4>的乘积,共有五种全括号化的方法:

    (A1(A2(A3A4))),

    (A1((A2A3)A4)),

    ((A1A2)(A3A4)),

    (((A1A2)A3)A4),

    ((A1(A2A3))A4)

    我们知道矩阵乘法是满足结合律的,所以以上五个式子的乘积相等,但是它们的运算时间是否相等呢?

    矩阵乘法的运算时间

    我们知道,矩阵乘法的定义是:

    两个互相兼容的矩阵A,B可以相乘。互相兼容是指A的列数与B的行数相等。假如A是一个p×q的矩阵,而B是一个q×r的矩阵,则乘积C是一个p×r的矩阵且有

    cij = ∑ aik·bkj, k = 1,...,q.

    由于要对C中的每一个元素进行计算(共q·r个元素),而每次运算要做q次乘法,所以总的运算时间为pqr。

    来看看让乘积中的不同因子结合对运算时间有什么影响。假设我们有 <A1,A2,A3>这个矩阵链,三个矩阵的秩分别为10×100, 100×5和5×50。则

    • ((A1A2)A3)的运算时间为10×100×5+10×5×50=7500;
    • (A1(A2A3))的运算时间为100×5×50+10×100×50=75000。

    按照不同的顺序做矩阵乘法,所需要的乘法次数竟相差10倍。

    初步分析

    按照惯例,我们来感受一下穷举的算法复杂度。

    假设有一个长度为n的矩阵链,我们通过遍历所有的全括号化的可能性来解题。设全括号化的可能性数目为P(n)。当n为1时,矩阵链只有一个矩阵,符合全括号化的定义;当n>=2时,全括号化后为两个矩阵的乘积,即((...)(...))的形式。用递归的思路去分析,则中间两个括号的分界位置有n-1种可能,如下面竖线所示

    A1|A2|A3|...|An

    当分界线将矩阵链分为长度为k和n-k的两个子矩阵链时,全括号化可能性为P(k)P(n-k)。我们对所有的k值求和,就得出给整个矩阵链全括号化的数目:

    P(n) = ∑ P(k)P(n-k), k=1...n-1   (n>=2)

    这是一个卡塔兰数(Catalan Number),它的增长速率为Ω(4n/n3/2),它的渐进值为Ω(2n)

    对渐进值还不太熟,如果有小伙伴明白“增长速率”和“渐进值”之间的关系,欢迎指教。

    总的来说,如果对这个题目使用穷举法,算法复杂度是指数的。后面我们分析了dp的算法复杂度,再来比较。

    用dp方法论解题

    算法的学习永远没有“手把手”这一说。如果你在认真学习这篇文章,希望你能做到比你看到的小节思路提前一点。比如,在看第一步前,先对这个题目有一点大致思路,明白让自己迷茫的点在哪里;看第x步前,对第x步的内容在心中有一个猜测。这样做比起完全放弃思考,只是跟着文章的思路走,收获会大很多。

    第一步:分析最优解的特征

    这一步的精髓是分析最优子解如何构成最优解

    在上一节中已经提到,对于n>=2的情况,全括号化后为((chain_1)(chain_2))的形式。这样,问题自然而然地分成了两个子问题:求前后两个子括号中的最优解。

    假设对于某种特定的分割(即chain_1chain_2之间的分界线位置固定),chain_1的秩为m×p,其内部的标量乘法数目为x;chain_2的秩为p×n,其内部的标量乘法数目为y。则整个矩阵链的乘法次数为x+y+mpn。由于m,p,n是固定的,我们需要让x和y为最小值从而使整个矩阵链的乘法次数最小。即,对于某种特定的分割,两个子括号中的最优解构成整个问题的最优解的一个选项

    总结来说,我们将矩阵乘积简略地看成两个子矩阵的乘积,这两个子矩阵的分界有n-1种可能。对每一种可能,问题被分割成两个子问题,即求左右两个子矩阵链的最优解。如果遍历这n-1种可能并选出最好的一个,那就是整个问题的最优解。

    第二步:递归地定义最优解的值

    第二步非常关键,是我们将前后思路打通的一步。

    第一步中提出了一个比较简单的思路,即把矩阵链分割成左右两个子矩阵链。既然有了这个初步思路,我们就来涂鸦一番,看看这个思路是否可行。

    对于递归性的问题,一个很好的方法是画递归树,这样会使得问题看起来比较具象,而且也会暴露一些算法上的问题,比如重叠子树等。画递归树的时候,最好举一个实际的例子。这里我们假设有一个长度为4的矩阵链<A1,A2,A3,A4>,简单地画一下它的子问题分割:

     

    上图中的数字表示子矩阵链的长度,根为4,即初始矩阵链;它可以分为1+3,2+2,3+1三种情况,这三种情况又可以各自细分。

    这里暴露了一个问题,请看图中的两个涂色的子树。两个子树的节点数字是一样的。但是左边这个子树的根节点3代表的是A2A3A4这个乘积;而右边这个代表的是A1A2A3这个乘积。由于A1,A2,A3,A4四个矩阵的秩是未知的,它们很可能不相同,则A1A2A3A2A3A4的最优解也很有可能不同。换言之,它们并不是同一个子问题,它们的子子树也并不相同。

    这个问题意味着我们对子问题的定义不够严谨——子问题不能只用长度这个变量来确定。也就是说,如果在bottom-up的dp中用一个数组记录子问题的值,那么这个数组应该是一个二维数组。子问题不仅应该由子矩阵链的长度确定,还要加上起始index这样的信息。

    为了更通用一些,我们不用起始index+长度,而选用起始index+结束index的定义方法,这是二维dp的惯用套路,在许多字符串和数组有关的问题中都有用到。

    设用一个二位矩阵dp[][]存取子问题的解。定义dp[i][j](1<=i<=j<=n)的值为Ai...Aj的最小乘法次数。则按照以上的思路,可以把Ai...Aj再递归细分为子问题Ai...AkAk+1...Aj(i<=k<j),则Ai...Aj的最优解值为两个子问题最优解的和+两个子矩阵链相乘的乘法次数。即有

    i==j时,dp[i][j] = 0;

    i <j时,dp[i][j] = min{dp[i][k] + dp[k+1][j] + pi-1pkpj}, k = i...j-1 (p为各个矩阵的秩,见题目一节)

    到此为止,最关键的一步顺利完成啦(楼主写得好累,击掌╭(○`∀´○)╯╰(○'◡'○)╮)。在这一步中,我们递归地定义了子问题最优解的值,完成了算法最核心的设计部分。在后面两步中,我们只要把上面这两个式子翻译成代码,再注意一些实现细节就可以了。

    第三步:计算最优解的值

    细节一

    从第二步顺理成章,我们会在一个二维数组里记录子问题的解。但是按照什么顺序去填这个二维数组是个问题。

    还是举例子,在<A1,A2,A3,A4>这个矩阵链中,我们会有一个5×5的二维数组,随便挑选dp[1][4]这个元素举例。根据第二步中的状态转移方程,有

    dp[1][4] = min{(dp[1][1]+dp[2][4]+...),(dp[1][2]+dp[3][4]+...),(dp[1][3]+dp[4][4]+...)}

    省略号表示我们此处不需关注pi-1pkpj这一项,只需要看这个格子对其它格子的依赖是什么样子。

    由上图可以看出,要计算某一个元素(粉色边框),我们需要其左边下面的元素(同样深度的蓝色表示一组数据)。

    所以,我们的遍历方向是从下到上,从左到右

    细节二

    细心的读者可能注意到还有一个问题,就是我们一直在求“最优解的值”,也就是“最小的乘法次数”,可是题目中要求的是“最优解”,也就是“加括号的方式”。

    这两者并不矛盾,专注于求解前者可以让我们先思考相对简单的问题,通常在求解前者的过程中,我们也找出了后者,只是没有将它记录下来。

    在此题中,我们可以选择用一个同样的二维矩阵s[][]来记录后者,其中s[i][j]中记录Ai...Aj的分割分界线k。

    代码

     1     int matrixChain(int[] p){
     2         int n = p.length - 1; //number of matrices
     3         int[][] dp = new int[n + 1][n + 1]; //we need dp[1][n]
     4         int[][] s = new int[n + 1][n + 1];    //for storing of k
     5         for(int[] row : dp)
     6             Arrays.fill(row, Integer.MAX_VALUE);
     7 
     8         for(int i = 1; i <= n; i++)
     9             dp[i][i] = 0;    //dp[i][j] = 0 when i == j
    10         
    11         for(int i = n; i >= 1; i--)
    12             for(int j = i; j <= n; j++){
    13                 if(i == j){
    14                     dp[i][j] = 0;
    15                 }else{
    16                     for(int k = i; k < j; k++){
    17                         int count = dp[i][k] + dp[k+1][j] + p[i-1]*p[k]*p[j];
    18                         if(count < dp[i][j]){
    19                             dp[i][j] = count; //record optimal solution value
    20                             s[i][j] = k;      //record splitting point k
    21                         }
    22                     }
    23                 }
    24             }
    25         return dp[1][n];
    26     }

    运行一个例子:

    即输入的数组p为{30,35,15,5,10,20,25}。

    如果在return之前打印出dp[][]和s[][]的值,结果为:

          

    从左图可看出最优解为dp[1][6] = 15,125,即最少可以进行一万五千多次乘法。右图记录了对于每一个[i,j]决定的子矩阵链如何进行括号分割。

    顺便分享一个ArrayPrinter的util,可以直接用,能打印出上图那样的二维int数组。

     1 public class ArrayPrinter {
     2     public static void print(int[] arr){
     3         printReplacing(false, arr, 0,"");
     4     }
     5     
     6     public static void print(int[][] matrix){
     7         printReplacing(false, matrix, 0,"");
     8     }
     9     
    10     public static void printReplacing(int[] arr, int before, String after){
    11         printReplacing(true, arr, before, after);
    12     }
    13     
    14     public static void printReplacing(int[][] matrix, int before, String after){
    15         printReplacing(true, matrix, before, after);
    16     }
    17     
    18     /*--------------------------private utils-------------------------------*/
    19     
    20     private static void printReplacing(boolean replace, int[] arr, int before, String after){
    21         int maxLen = maxLength(arr);
    22         if(replace){
    23             for(int i : arr)
    24                 print(((i==before)?after:number(i)), maxLen);
    25         }else{
    26             for(int i : arr)
    27                 print(number(i), maxLen);
    28         }
    29         print("
    ", maxLen);
    30     }
    31     
    32     public static void printReplacing(boolean replace, int[][] matrix, int before, String after){
    33         int maxLen = maxLength(matrix);
    34         if(replace){
    35             for(int[] row : matrix){
    36                 for(int i : row)
    37                     print(((i==before)?after:number(i)), maxLen);
    38                 print("
    ", maxLen);
    39             }
    40         }else{
    41             for(int[] row : matrix){
    42                 for(int i : row)
    43                     print(number(i), maxLen);
    44                 print("
    ", maxLen);
    45             }
    46         }
    47     }
    48 
    49     private static int maxLength(int[] arr){
    50         int maxLen = 0;
    51         for(int aint : arr)
    52             maxLen = Math.max(Integer.toString(aint).length(), maxLen);
    53         return maxLen;
    54     }
    55     
    56     private static int maxLength(int[][] matrix){
    57         int maxLen = 0;
    58         for(int row[] : matrix)
    59             maxLen = Math.max(maxLength(row), maxLen);
    60         return maxLen;
    61     }
    62     
    63     //actual printing 
    64     private static void print(String s, int length){
    65         System.out.print(String.format("%1$"+(length+1)+"s", s));
    66     }
    67     
    68     //formatting of number
    69     private static String number(int i){
    70         return NumberFormat.getNumberInstance(Locale.US).format(i);
    71     } 
    72 }
    ArrayPrinter

    使用方法:

    1 ArrayPrinter.printReplacing(dp, Integer.MAX_VALUE, "/");
    2 ArrayPrinter.print(s);

    第四步:根据计算好的信息构造最优解

    还差一步就大功告成。这一步我们要拿着上一步计算出的矩阵s把最终的全括号矩阵乘积打印出来。递归打印即可。

     1     private void printParenthesis(int[][] s, int i, int j) {
     2         if(i == j)
     3             print("A"+i);
     4         else{
     5             print("(");
     6             printParenthesis(s, i, s[i][j]);
     7             printParenthesis(s, s[i][j]+1, j);
     8             print(")");
     9         }
    10     }

    打印结果:

    复杂度

    前面说过,穷举法的复杂度大概是O(2n)。在以上的dp算法中,主算法需要填满一个(n+1)×(n+1)的二维数组的上半部分,每填一个元素需要一个长度为j-i的循环,可通过这个思路对j-i进行求和(i=0...n, j=i...n),也可以通过大概估算得到时间复杂度为O(n3),远好于穷举法。

    空间复杂度主要由二维数组决定,为O(n2)。

    总结

    本文主要介绍了解一个dp问题的思路。

    dp问题一般有两个显著特点,这一点下一篇会详细讲述:

    • 问题的最优解由子问题的最优解构成
    • 子问题互相重叠

    也再复习一下解题的四个步骤,看你现在有没有更深刻的理解:

    1. 分析最优解的特征。               (分析最优子解如何构成最优解)
    2. 递归地定义最优解的值。               (画递归树,定义子问题,写状态转移方程)
    3. 计算最优解的值。                        (写代码求出最优解,如果有要求的话,记录额外信息,为第4步作准备)
    4. 根据计算好的信息构造最优解。       (从第3步记录的信息中构建最优解,在本题中就是括号的写法)

    参考资料

    算法导论(英文版)3rd Ed. 15.2

  • 相关阅读:
    MSSQL锁定1.Isolation level (myBased)
    等待状态CXPACKET分析
    拒绝了对对象 'sp_sdidebug'(数据库 'master',所有者 'dbo')的 EXECUTE 权限
    Oracle CBO 统计信息的收集与执行计划的选择
    Oracle 11gR1 on Win7
    读书笔记 <<你的知识需要管理>>
    ORA01555 总结
    Buffer Cache Management
    如何选择合适的索引
    书评 <SQL Server 2005 Performance Tuning性能调校> 竟然能够如此的不用心........
  • 原文地址:https://www.cnblogs.com/mozi-song/p/9629137.html
Copyright © 2011-2022 走看看