zoukankan      html  css  js  c++  java
  • (原)torch模型转pytorch模型

    转载请注明出处:

    http://www.cnblogs.com/darkknightzh/p/7839263.html

    目前使用的torch模型转pytorch模型的程序为:

    https://github.com/clcarwin/convert_torch_to_pytorch

    该程序中,常见的模型都可以转换,但是对于torch中为BatchNormalization的则会提示出错:

    Not Implement BatchNormalization

    torch中的SpatialBatchNormalization对应于输入为4d的特征(batchsize*featdim*featHeight*featWidth),对应于pytorch中的nn.BatchNorm2d

    而torch中的BatchNormalization对应于输入为2d的特征(batchsize*featdim),对应于pytorch中的nn.BatchNorm1d

    因而修改方法很简单:

    1. 在convert_torch.py的74行(elif name == 'ReLU':)之前添加:

    elif name == 'BatchNormalization':
        n = nn.BatchNorm1d(m.running_mean.size(0), m.eps, m.momentum, m.affine)
        copy_param(m,n)
        add_submodule(seq,n)

    2. 在convert_torch.py的(未修改前的)175行(elif name == 'ReLU':)之前添加:

    elif name == 'BatchNormalization':
        s += ['nn.BatchNorm1d({},{},{},{}),#BatchNorm1d'.format(m.running_mean.size(0), m.eps, m.momentum, m.affine)]

    3. 在convert_torch.py的(未修改前的)245行(s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#MaxPool2d',')'),s))之前添加:

    s = map(lambda x: x.replace(',1e-05,0.1,True),#BatchNorm1d',')'),s)
    s = map(lambda x: x.replace('),#BatchNorm1d',')'),s)

    经过上述修改后,torch模型中含有BatchNormalization,转换到pytorch后的模型性能和转换前的模型性能一致。

    顺便说一下,2天前更新的该程序,添加了BatchNorm3d的支持,但是在243、244行之后,并没有增加BatchNorm3d的相关代码,不清楚是否会有问题。我这边没有用到BatchNorm3d,因而没有测试。

    另一方面,上面的3步中,我是根据BatchNorm2d去修改,没有测试如果不修改某一步(如第3步),程序是否会有问题。反正都改了,模型没有问题。。。

  • 相关阅读:
    Matlab中save与load函数的使用
    bsxfun函数
    matlab中nargin函数的用法
    Leetcode 188. Best Time to Buy and Sell Stock IV
    Leetcode 123. Best Time to Buy and Sell Stock III
    leetcode 347. Top K Frequent Elements
    Leetcode 224. Basic Calculator
    Leetcode 241. Different Ways to Add Parentheses
    Leetcode 95. Unique Binary Search Trees II
    Leetcode 96. Unique Binary Search Trees
  • 原文地址:https://www.cnblogs.com/darkknightzh/p/7839263.html
Copyright © 2011-2022 走看看