zoukankan      html  css  js  c++  java
  • 模型处理-08

      模型是神经网络训练优化后得到的成果, 包含了神经网络骨架及学习得到的参数。 PyTorch对于模型的处理提供了丰富的工具, 本节将从模型的生成、 预训练模型的加载和模型保存3个方面进行介绍。

    1. 网络模型库: torchvision.models
    对于深度学习, torchvision.models库提供了众多经典的网络结构与预训练模型, 例如VGGResNetInception等, 利用这些模型可以快速搭建物体检测网络, 不需要逐层手动实现。 torchvision包与PyTorch相独立, 需要通过pip指令进行安装, 如下:

    1 pip install torchvision # 适用于Python 2
    2 pip3 install torchvision # 适用于Python 3 
    View Code

    VGG模型为例, 在torchvision.models中, VGG模型的特征层与分类层分别用vgg.featuresvgg.classifier来表示, 每个部分是一个nn.Sequential结构, 可以方便地使用与修改。 下面讲解如何使用torchvision.model模块。

     1 from torch import nn
     2 from torchvision import models
     3 
     4 # 通过torchvision.model直接调用VGG16的网络结构
     5 vgg = models.vgg16()
     6 
     7 # VGG16的特征层包括13个卷积、 13个激活函数ReLU、 5个池化, 一共31层
     8 print(len(vgg.features))
     9 >> 31
    10 
    11 # VGG16的分类层包括3个全连接、 2个ReLU、 2个Dropout, 一共7层
    12 print(len(vgg.classifier))
    13 >> 7
    14 
    15 # 可以通过出现的顺序直接索引每一层
    16 print(vgg.classifier[-1])
    17 >> Linear(in_features=4096, out_features=1000, bias=True)
    18 
    19 # 也可以选取某一部分, 如下代表了特征网络的最后一个卷积模组
    20 print(vgg.features[24:])
    21 >> Sequential(
    22     (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    23     (25): ReLU(inplace)
    24     (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    25     (27): ReLU(inplace)
    26     (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    27     (29): ReLU(inplace)
    28     (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    29   )
    View Code

    2. 加载预训练模型
    对于计算机视觉的任务, 包括物体检测, 我们通常很难拿到很大的数据集, 在这种情况下重新训练一个新的模型是比较复杂的, 并且不容易调整, 因此, Fine-tune(微调) 是一个常用的选择。 所谓Fine-tune是指利用别人在一些数据集上训练好的预训练模型, 在自己的数据集上训练自己的模型。

    在具体使用时, 通常有两种情况, 第一种是直接利用torchvision.models中自带的预训练模型, 只需要在使用时赋予pretrained参数为True即可。

    1 from torch import nn
    2 from torchvision import models
    3 
    4 # 通过torchvision.model直接调用VGG16的网络结构
    5 vgg = models.vgg16(pretrained=True)
    View Code

    第二种是如果想要使用自己的本地预训练模型, 或者之前训练过的模型, 则可以通过model.load_state_dict()函数操作, 具体如下:

     1 import torch
     2 from torch import nn
     3 from torchvision import models
     4 
     5 # 通过torchvision.model直接调用VGG16的网络结构
     6 vgg = models.vgg16()
     7 static_dict = torch.load(" your model path")
     8 
     9 # 利用load_state_dict, 遍历预训练模型的关键字, 如果出现在了VGG中, 则加载预训练参数
    10 vgg.load_state_dict({k:v for k,v in state_dict_items() if k in vgg.state_dict()})
    View Code

    通常来讲, 对于不同的检测任务, 卷积网络的前两三层的作用是非常类似的, 都是提取图像的边缘信息等, 因此为了保证模型训练中能够更加稳定, 一般会固定预训练网络的前两三个卷积层而不进行参数的学习。 例如VGG模型, 可以设置前三个卷积模组不进行参数学习, 设置方式如下:

    1 for layer in range(10):
    2    for p in vgg[layer].parameters():
    3       p.requires_grad = False
    View Code

    3. 模型保存

    PyTorch中, 参数的保存通过torch.save()函数实现, 可保存对象包括网络模型、 优化器等, 而这些对象的当前状态数据可以通过自身的state_dict()函数获取。

    1 torch.save({
    2 ‘model’: model.state_dict(),
    3 'optimizer': optimizer.state_dict(),
    4 'model_path.pth')
    View Code


  • 相关阅读:
    10. Regular Expression Matching
    9. Palindrome Number
    6. ZigZag Conversion
    5. Longest Palindromic Substring
    4. Median of Two Sorted Arrays
    3. Longest Substring Without Repeating Characters
    2. Add Two Numbers
    链式表的按序号查找
    可持久化线段树——区间更新hdu4348
    主席树——树链上第k大spoj COT
  • 原文地址:https://www.cnblogs.com/zhaopengpeng/p/13641485.html
Copyright © 2011-2022 走看看