zoukankan      html  css  js  c++  java
  • 迁移pytorch工程至matlab

    好久没写博客了,险些以为自己找不到密码了。

    最近抽空参与了个小项目,很惭愧,只做了三件小事


    1. 基于PyTorch训练了一系列单图像超分辨神经网络

    基于PyTorch训练了一系列单图像超分辨神经网络,超分辨系数从2-10。
    该部分的实现参考了pytorch官方repo中的SR例程,训练程序包含于`./train`文件夹。该项目
    基于高效子像素卷积层[1]进行空间分辨率提升操作,训练速度极快。

    [1] ["Shi W, Caballero J, Huszar F, et al. Real-Time Single Image and
        Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
        Neural Network[J]. 2016:1874-1883.](https://arxiv.org/abs/1609.05158)

    2. 把训练好的模型权值转存为MATLAB文件。


    简单粗暴,异常直接,只要把对应卷积层的权值全部提取出来就可以了。

    提取的时候注意一点,要把pytorch中的Variable格式转换为Tensor,再转换为CPU模式,最终转换为numpy数组。

    这一系列过程合并起来就是:

    Var.data.cpu().numpy()

    具体实现如下:

     1 from __future__ import print_function
     2 
     3 import torch
     4 import numpy as np
     5 import scipy.io as sio
     6 
     7 for i in [2, 3, 4, 5, 6, 7, 8, 9, 10]:
     8 
     9     model_name = 'model_upscale_{}_epoch_101.pth'.format(i)
    10     model = torch.load(model_name)
    11     print(model._modules)
    12 
    13     weight = dict()
    14     weight['conv1_w'] = model._modules['conv1']._parameters['weight'].data.cpu().numpy()
    15     weight['conv2_w'] = model._modules['conv2']._parameters['weight'].data.cpu().numpy()
    16     weight['conv3_w'] = model._modules['conv3']._parameters['weight'].data.cpu().numpy()
    17     weight['conv4_w'] = model._modules['conv4']._parameters['weight'].data.cpu().numpy()
    18 
    19     weight['conv1_b'] = model._modules['conv1']._parameters['bias'].data.cpu().numpy()
    20     weight['conv2_b'] = model._modules['conv2']._parameters['bias'].data.cpu().numpy()
    21     weight['conv3_b'] = model._modules['conv3']._parameters['bias'].data.cpu().numpy()
    22     weight['conv4_b'] = model._modules['conv4']._parameters['bias'].data.cpu().numpy()
    23 
    24     sio.savemat('model_upscale_{}.mat'.format(i), mdict=weight)

    3. 把网络的test过程移植到了MATLAB平台,并撰写了测试代码。

    把卷积层和pixelshuffle层用matlab重写了一下。

    复现pixelshuffle层的时候遇到了一些麻烦,又回头看了下pytorch里的测试代码

    `https://github.com/pytorch/pytorch/blob/master/test/test_nn.py `

    # https://github.com/pytorch/pytorch/blob/master/test/test_nn.py
    def _verify_pixel_shuffle(self, input, output, upscale_factor):
        for c in range(output.size(1)):
            for h in range(output.size(2)):
                for w in range(output.size(3)):
                    height_idx = h // upscale_factor
                    weight_idx = w // upscale_factor
                    channel_idx = (upscale_factor * (h % upscale_factor)) + (w % upscale_factor) + 
                                  (c * upscale_factor ** 2)
        self.assertEqual(output[:, c, h, w], input[:, channel_idx, height_idx, weight_idx])

    理了理思路,改写成MATLAB代码:

     1 function [ outputs ] = PixelShuffle( inputs, upscale_factor )
     2 %    PixelShuffle :
     3 %
     4 %   input : N, upscale_factor ** 2, H, W
     5 %   output : N, 1, H*upscale_factor, W*upscale_factor
     6 
     7 [N, ~, H, W] = size(inputs);
     8 H_out = H*upscale_factor;
     9 W_out = W*upscale_factor;
    10 outputs = zeros([N, 1, H_out, W_out]);
    11 for i = 1:N
    12     for h = 1: H_out
    13         for w = 1:W_out
    14             height_idx = floor(h / upscale_factor+0.5);
    15             weight_idx = floor(w / upscale_factor+0.5);
    16             channel_idx = (upscale_factor * mod(h-1, upscale_factor)) + mod(w-1, upscale_factor)+1;
    17             outputs(i, 1, h, w) = inputs(i, channel_idx, height_idx, weight_idx);
    18         end
    19     end
    20 end
    21 end

    4. 完整工程github链接。

    https://github.com/JiJingYu/super-resolution-by-subpixel-convolution

    模型权值已保存为matlab权值,直接在matlab中运行`demo.m`文件即可验证

  • 相关阅读:
    Azure PowerShell (7) 使用CSV文件批量设置Virtual Machine Endpoint
    Windows Azure Cloud Service (39) 如何将现有Web应用迁移到Azure PaaS平台
    Azure China (7) 使用WebMetrix将Web Site发布至Azure China
    Microsoft Azure News(4) Azure新D系列虚拟机上线
    Windows Azure Cloud Service (38) 微软IaaS与PaaS比较
    Windows Azure Cloud Service (37) 浅谈Cloud Service
    Azure PowerShell (6) 设置单个Virtual Machine Endpoint
    Azure PowerShell (5) 使用Azure PowerShell创建简单的Azure虚拟机和Linux虚拟机
    功能代码(1)---通过Jquery来处理复选框
    案例1.用Ajax实现用户名的校验
  • 原文地址:https://www.cnblogs.com/nwpuxuezha/p/7834344.html
Copyright © 2011-2022 走看看