zoukankan      html  css  js  c++  java
  • pytorch模型转caffe模型

    Pytorch模型转换Caffe模型踩坑指南,代码使用的是Github上的工程,地址:https://github.com/longcw/pytorch2caffe

    操作环境:ubuntu = 14.04
             miniconda 3
             caffe
             pytorch = 0.2.0    torchvision = 0.1.8
             python = 2.7 
    

      

    环境配置:
    第一步 : 在miniconda创建一个虚拟环境pytorch2caffe : conda create -n pytorch2caffe python=2.7
    第二步 : 激活虚拟环境 source activate pytorch2caffe
    第三步 : 在该虚拟环境下安装对应版本的pytorch和torchvision : conda install pytorch=0.2.0 torchvision=0.1.8 安装完成后conda list看一看有没有安装成功对应的版本
    第四步 : 在该虚拟环境下编译caffe,官网指南链接 : http://caffe.berkeleyvision.org/install_apt.html
    第五步 : 编译配置caffe的python接口pycaffe,操作指南链接 : https://www.cnblogs.com/lyyang/p/6573846.html

    配置环境的时候因Github上没有关于这个项目的环境介绍,所以我以前用的是习惯的python 3.6和pytorch 0.4.0 ,而这个项目是用python 2.7写的,所以在创建虚拟环境的时候使用python 2.7的环境,要不然之后也会有很多问题,pytorch和torchvision版本很重要。

    博主调代码时老是遇到:KeyError: ‘ExpandBackward’这个错误,在改变pytorch和torchvision版本后这个问题解决了,这个Github的项目比较久了,并不支持新版本的Pytorch,这就是我在配置操作环境下帮大家踩得坑啦。

    在转换自己的模型之前,先调通项目中的demo,验证配置环境已经配置成功。项目中的demo是转换Google的inception-v3模型,我们根据终端里提示的模型地址将模型下载到本地,然后将模型载入进行转换。转换成功之后在项目的demo文件夹中生成model.prototxt与model.caffemodel两个新文件。测试的话将demo中设置test_mod = True,给定同样的随机输入数据,测试两个模型得出的结果。
    遇到的问题:

    caffe支持的卷积和池化层操作都是2D的,我的这个模型所做的卷积和池化操作都是1D的,当时找这个问题也花了很久的时间,没想到caffe只支持2D的操作。我将原来的input_size=(1, 1, 1024)修改成了(1, 1, 1, 1024),然后做相应的2D卷积和池化操作。

    我还遇到过的一个Import问题是:Segmentation fault (core dumped)这个问题的原因我不是很清楚,我在查看是哪个inport出问题时发现import caffe在import torch之后是并不会报这个错误,但是import torch之后再import caffe就会报这个错误。

    在调试代码时遇到过这个问题:ValueError: could not broadcast input array from shape (3,128) into shape (3,512),这个问题和caffe的源码有关,需要在caffe的proto文件中修改pooling层的参数optional bool ceil_mode = 13 [default = true],而因为caffe版本的原因,我的caffe并没有这个参数,所以要往现在caffe的源码pooling层中将ceil_mode相关参数和代码添加进去,然后重新编译caffe与pycaffe。

    1、在pooling_layer.hpp中往PoolingLayer类中添加ceil_mode_这个参数:
        int height_, width_;
        int pooled_height_, pooled_width_;
        bool global_pooling_;
        bool ceil_mode_;   //添加的类成员变量
        Blob<Dtype> rand_idx_;
        Blob<int> max_idx_;
    
    2、修改pooling_layer.cpp文件中相关参数,主要涉及到LayerSetUp函数和Reshape函数。
    
    LayerSetUp函数修改如下:
    || (!pool_param.has_stride_h() && !pool_param.has_stride_w()))
        << "Stride is stride OR stride_h and stride_w are required.";
    global_pooling_ = pool_param.global_pooling();
         ceil_mode_ = pool_param.ceil_mode();    //添加的代码,主要作用是从参数文件中获取ceil_mode_的参数数值。
    
    Reshape函数修改如下:
    if (global_pooling_) {
      kernel_h_ = bottom[0]->height();
      kernel_w_ = bottom[0]->width();
    
     //删除下面四行代码--------------------------------
     pooled_height_ = static_cast<int>(ceil(static_cast<float>(
     height_ + 2 * pad_h_ - kernel_h_) / stride_h_)) + 1;      
     pooled_width_ = static_cast<int>(ceil(static_cast<float>(
     width_ + 2 * pad_w_ - kernel_w_) / stride_w_)) + 1;
     // Specify the structure by ceil or floor mode
    
      // 添加下面的代码------------------------
       if (ceil_mode_) {
     pooled_height_ = static_cast<int>(ceil(static_cast<float>(
         height_ + 2 * pad_h_ - kernel_h_) / stride_h_)) + 1;
     pooled_width_ = static_cast<int>(ceil(static_cast<float>(
         width_ + 2 * pad_w_ - kernel_w_) / stride_w_)) + 1;
       } else {
     pooled_height_ = static_cast<int>(floor(static_cast<float>(
         height_ + 2 * pad_h_ - kernel_h_) / stride_h_)) + 1;
     pooled_width_ = static_cast<int>(floor(static_cast<float>(
         width_ + 2 * pad_w_ - kernel_w_) / stride_w_)) + 1;
       }    
    
    3、修改caffe.proto文件中PoolingParameter的定义:
    // Specify floor/ceil mode
    optional bool ceil_mode = 13 [default = true];
    
    4、重新编译caffe与pycaffe
    

      


    原文:https://blog.csdn.net/weixin_38501242/article/details/82624071

  • 相关阅读:
    【WCF--初入江湖】04 WCF通信模式
    【WCF--初入江湖】03 配置服务
    c++输出左右对齐设置
    setw()函数
    clion更改大括号的位置
    emacs org-mode 中文手册精简版(纯小白)
    c++ string 类型 大小写转换 
    C++中string类型的find 函数
    string类型 C++
    统计单词数---单词与空格
  • 原文地址:https://www.cnblogs.com/qbdj/p/11024587.html
Copyright © 2011-2022 走看看