zoukankan      html  css  js  c++  java
  • [Pytorch]Pytorch加载预训练模型(转)

    转自:https://blog.csdn.net/Vivianyzw/article/details/81061765

    东风的地方

    1. 直接加载预训练模型

    在训练的时候可能需要中断一下,然后继续训练,也就是简单的从保存的模型中加载参数权重:

    1. net = SNet()
    2. net.load_state_dict(torch.load("model_1599.pkl"))

    这种方式是针对于之前保存模型时以保存参数的格式使用的:

    torch.save(net.state_dict(), "model/model_1599.pkl")
    

    pytorch官网更推荐上述模型保存方法,也据说这种方式比下一种更快一点。

    下面介绍第二种模型保存和加载的方式:

    1. net = SNet()
    2. torch.save(net, "model_1599.pkl")
    3. snet = torch.load("model_1599.pkl")

    这种方式会将整个网络保存下来,数据量会更大,会消耗更多的时间,占用内存也更高。

    2. 加载一部分预训练模型

    模型可能是一些经典的模型改掉一部分,比如一般算法中提取特征的网络常见的会直接使用vgg16的features extraction部分,也就是在训练的时候可以直接加载已经在imagenet上训练好的预训练参数,这种方式实现如下:

    1. net = SNet()
    2. model_dict = net.state_dict()
    3. vgg16 = models.vgg16(pretrained=True)
    4. pretrained_dict = vgg16.state_dict()
    5. pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    6. model_dict.update(pretrained_dict)
    7. net.load_state_dict(model_dict)
    也就是在网络中state_dict部分,属于vgg16的,替换成vgg16预训练模型里的参数(代码里的k:v for k,v in pretrained_dict.items() if k in model_dict),其他保持不变。

    3. 微调经典网络

    因为pytorch中的torchvision给出了很多经典常用模型,并附加了预训练模型。利用好这些训练好的基础网络可以加快不少自己的训练速度。

    首先比如加载vgg16(带有预训练参数的形式):

    1. import torchvision.models as models
    2. vgg16 = models.vgg16(pretrained=True)

    比如,网络第一层本来是Conv2d(3, 64, 3, 1, 1),想修改成Conv2d(4, 64, 3, 1 ,1),那直接赋值就可以了:

    1. import torch.nn as nn
    2. vgg16.features[0]=nn.Conv2d(4, 64, 3, 1, 1)

    4. 修改经典网络

    这个比上面微调修改的地方要多一些,但是想介绍一下这样的修改方式。

    先简单介绍一下我需要需改的部分,在vgg16的基础模型下,每一个卷积都要加一个dropout层,并将ReLU激活函数换成PReLU,最后两层的Pooling层stride改成1。直接上代码:

    1. def feature_layer():
    2. layers = []
    3. pool1 = ['4', '9', '16']
    4. pool2 = ['23', '30']
    5. vgg16 = models.vgg16(pretrained=True).features
    6. for name, layer in vgg16._modules.items():
    7. if isinstance(layer, nn.Conv2d):
    8. layers += [layer, nn.Dropout2d(0.5), nn.PReLU()]
    9. elif name in pool1:
    10. layers += [layer]
    11. elif name == pool2[0]:
    12. layers += [nn.MaxPool2d(2, 1, 1)]
    13. elif name == pool2[1]:
    14. layers += [nn.MaxPool2d(2, 1, 0)]
    15. else:
    16. continue
    17. features = nn.Sequential(*layers)
    18. #feat3 = features[0:24]
    19. return features

    大概的思路就是,创建一个新的网络(layers列表), 遍历vgg16里每一层,如果遇到卷积层(if isinstance(layer, nn.Conv2d)就先把该层(Conv2d)保持原样加进去,随后增加一个dropout层,再加一个PReLU层。然后如果遇到最后两层pool,就修改响应参数加进去,其他的pool正常加载。 最后将这个layers列表转成网络的nn.Sequential的形式,最后返回features。然后再你的新的网络层就可以用以下方式来加载:

    1. class SNet(nn.Module):
    2. def __init__(self):
    3. super(SNet, self).__init__()
    4. self.features = feature_layer()
    5. def forward(self, x):
    6. x = self.features(x)
    7. return x
  • 相关阅读:
    时间复杂度为O(1)的Iveely搜索缓存策略
    数据挖掘十大算法决策树的实现
    编写有效的C# 代码(一)
    数据挖掘十大算法Kmeans之图像区域选择
    asp.net 导出 Excel
    使用XMLSpyDocEditPlugIn2.dll,页面加载失败
    多线程Thread的使用,并使用Thread传参
    ajaxpro.2.dll的使用
    论道WP(二):如何学习WP开发?
    IList<T> 转换成 DataSet
  • 原文地址:https://www.cnblogs.com/kk17/p/10156160.html
Copyright © 2011-2022 走看看