zoukankan      html  css  js  c++  java
  • mxnet下如何查看中间结果

    https://blog.csdn.net/disen10/article/details/79376631

    固定权重:https://www.cnblogs.com/chenyliang/p/6780019.html

    固定权重:https://discuss.gluon.ai/t/topic/1164

    查看权重

    在训练过程中,有时候我们为了debug而需要查看中间某一步的权重信息,在mxnet中,我们可以很方便的调用get_params()方法来得到权重信息。

    1.  
      '''
    2.  
      查看权重示例代码
    3.  
      转载时注明地址:http://blog.csdn.net/u010414386?viewmode=contents
    4.  
      '''
    5.  
      import mxnet as mx
    6.  
      sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-50',0)#载入模型
    7.  
      mod = mx.mod.Module(symbol=sym,context=mx.gpu()) #创建Module
    8.  
      mod.bind(for_training=False,data_shapes=[('data',(1,3,224,224))]) #绑定,此代码为预测代码,所以training参数设为False
    9.  
      mod.set_params(arg_params,aux_params)
    10.  
      import numpy as np
    11.  
      import cv2
    12.  
      def get_image(filename):
    13.  
      img = cv2.imread(filename)
    14.  
      img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
    15.  
      img = cv2.resize(img,(224,224))
    16.  
      img = np.swapaxes(img,0,2)
    17.  
      img = np.swapaxes(img,1,2)
    18.  
      img = img[np.newaxis,:]
    19.  
      return img
    20.  
      from collections import namedtuple
    21.  
      Batch = namedtuple('Batch',['data'])
    22.  
      img = get_image('val_1000/0.jpg') #获取图片
    23.  
      mod.forward(Batch([mx.nd.array(img)])) #预测结果
    24.  
      ################################################
    25.  
      #debug模式下,获取权重信息
    26.  
      keys = mod.get_params()[0].keys() # 列出所有权重名称
    27.  
      conv_w = mod.get_params()[0]['conv0_weight'] #获取想要查看的权重信息,如conv_weight
    28.  
      print conv_w.asnumpy() #查看具体数值
    29.  
      ################################################
    30.  
      prob = mod.get_outputs()[0].asnumpy()
    31.  
      y = np.argsort(np.squeeze(prob))[::-1]
    32.  
      print('truth label %d; top-1 predict label %d' % (val_label[0], y[0]))
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33

    查看中间输出结果

    由于mxnet的网络由symbol组成,而symbol又属于符号式编程,所以我们不能像上面查看权重一样直接查看,我们需要把我们想看的输出结果保存下来。

    1.  
      '''
    2.  
      方法一
    3.  
      查看中间结果代码
    4.  
      转载时注明地址:http://blog.csdn.net/u010414386?viewmode=contents
    5.  
      '''
    6.  
      import mxnet as mx
    7.  
      net = mx.symbol.Variable('data')
    8.  
      fc1 = mx.symbol.FullyConnected(data=net, name='fc1', num_hidden=128)
    9.  
      net = mx.symbol.Activation(data=fc1, name='relu1', act_type="relu")
    10.  
      net = mx.symbol.FullyConnected(data=net, name='fc2', num_hidden=64)
    11.  
      out = mx.symbol.SoftmaxOutput(data=net, name='softmax')
    12.  
      # 通过把两个输出组成一个group来得到自己需要查看的中间层输出结果
    13.  
      group = mx.symbol.Group([fc1, out])
    14.  
      print group.list_outputs()
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    1.  
      '''
    2.  
      方法二
    3.  
      有时候我们使用别人的模型,所以无法像方法一一样在定义模型的时候就确定需要查看的中间层输出结果,
    4.  
      这时候我们使用get_internals()方法来查找自己需要查看的中间层
    5.  
      转载时注明地址:http://blog.csdn.net/u010414386?viewmode=contents
    6.  
      '''
    7.  
      import mxnet as mx
    8.  
      sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-50',0)#载入模型
    9.  
      ########################################################################
    10.  
      args = sym.get_internals().list_outputs() #获得所有中间输出
    11.  
      internals = model.symbol.get_internals()
    12.  
      fc1 = internals['fc1_output']
    13.  
      conv = internals['stage4_unit3_conv1_output']
    14.  
      group = mx.symbol.Group([fc1, sym, conv]) #把需要输出的结果按group方式组合起来,这样就可以得到中间层的输出
    15.  
      #########################################################################
    16.  
      mod = mx.mod.Module(symbol=group,context=mx.gpu()) #创建Module
    17.  
      mod.bind(for_training=False,data_shapes=[('data',(1,3,224,224))]) #绑定,此代码为预测代码,所以training参数设为False
    18.  
      mod.set_params(arg_params,aux_params)
    19.  
      import numpy as np
    20.  
      import cv2
    21.  
      def get_image(filename):
    22.  
      img = cv2.imread(filename)
    23.  
      img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
    24.  
      img = cv2.resize(img,(224,224))
    25.  
      img = np.swapaxes(img,0,2)
    26.  
      img = np.swapaxes(img,1,2)
    27.  
      img = img[np.newaxis,:]
    28.  
      return img
    29.  
      from collections import namedtuple
    30.  
      Batch = namedtuple('Batch',['data'])
    31.  
      img = get_image('val_1000/0.jpg') #获取图片
    32.  
      mod.forward(Batch([mx.nd.array(img)])) #预测结果
    33.  
      prob = mod.get_outputs()[0].asnumpy()
    34.  
      y = np.argsort(np.squeeze(prob))[::-1]
    35.  
      print('truth label %d; top-1 predict label %d' % (val_label[0], y[0]))
  • 相关阅读:
    深入Spring之IOC之加载BeanDefinition
    Hexo+GitHub Actions 完美打造个人博客
    Spring中资源的加载原来是这么一回事啊!
    Web 跨域请求问题的解决方案- CORS 方案
    重新认识 Spring IOC
    Spring Data Jpa 入门学习
    前奏:Spring 源码环境搭建
    最短路径——floyd算法代码(c语言)
    leetcode 第184场周赛第一题(数组中的字符串匹配)
    如何用尾插法建立双链表(C语言,非循环)
  • 原文地址:https://www.cnblogs.com/jukan/p/10197604.html
Copyright © 2011-2022 走看看