zoukankan      html  css  js  c++  java
  • 7. pytorch 现有网络模型的使用与修改和模型的保存与加载

      PyTorch是一个开源的Python机器学习库,基于Torch,用于自然语言处理等应用程序。他提供了大量的模型供我们所使用,如下图所示:


    下面,我们选择其中一个网络进行使用,介绍如何使用、并修改 pytorch 本身为我们提供的现有网络。最后介绍一下模型的保存和修改。

    pytorch 现有网络的使用与修改

      下面我们以 VGG(Very Deep Convolutional Networks for Large-Scale Image Recognition)的使用为例,进行介绍该网络。

         VGG 16 简介

      VGG16网络是14年牛津大学计算机视觉组和Google DeepMind公司研究员一起研发的深度网络模型。该网络一共有16个训练参数的网络,该网络的具体网络结构如下所示:


      不难看出,该网络主要用于对 224 x 224 的图像进行 1000 分类。下面我们查看 VGG 在 pytorch 上的官方文档。

         VGG 16 doc

      从帮助文档中,我们可以清楚的看到 pytorch 为我们提供了各种版本的 VGG,我们选择 VGG 16 进行查看。


         VGG16 的简单使用

      从 vgg 16的帮助文档可以得知,该模型训练的数据是 ImageNet,我们进入 torchvision.datasets 查看 ImageNet


    但是该数据集实在是太大了,根本下不了,还是不搞了。建立一个该网络的模型查看参数: ```python import torch import torchvision import torch.nn as nn # import torchvision.models

    vgg_model_pretrained = torchvision.models.vgg16(pretrained=True, progress=True)
    vgg_model_original = torchvision.models.vgg16(pretrained=False, progress=True)

    print(vgg_model_original)
    print(vgg_model_pretrained)

    vgg_model_pretrained.add_module()

    
    
    <p align="center">
    	<img src="https://img2020.cnblogs.com/blog/1772262/202111/1772262-20211113091424359-823205952.png" style="zoom:100%"/>
    </p>
    <br/>
    
    
    <p align="center">
    	<img src="https://img2020.cnblogs.com/blog/1772262/202111/1772262-20211113091441147-673036772.png" style="zoom:100%"/>
    </p>
    <br/>
    
    仔细查看这个网络的组成,你可以发现,组成该网络的一个个小 module 就是我们之前所介绍过的`Conv2d`, `ReLU`, `MaxPool2d`, `Linear`, `Dropout` 等等函数,
    
    
    ### &nbsp;&nbsp;&nbsp;&nbsp; VGG16 模型修改
    &nbsp;&nbsp;经过上面的代码,我们可以较为轻松的看到 VGG16 神经网络的结构框架,那么我们如何修改别人已经写好的模型呢?
    &nbsp;&nbsp;想要修改别人写好的模型,主要有一下这几种操作
    
    
    <p align="center">
    	<img src="https://img2020.cnblogs.com/blog/1772262/202111/1772262-20211113092149496-1776250931.png" style="zoom:100%"/>
    </p>
    <br/>
    
    选中模型,进行 add_module() 或者是直接对模型进行修改
    
    <p align="center">
    	<img src="https://img2020.cnblogs.com/blog/1772262/202111/1772262-20211113094536565-291308585.png" style="zoom:100%"/>
    </p>
    <br/>
    
    
    <p align="center">
    	<img src="https://img2020.cnblogs.com/blog/1772262/202111/1772262-20211113094752626-1549811574.png" style="zoom:100%"/>
    </p>
    <br/>
    
    ```python
    import torch
    import torchvision
    import torch.nn as nn
    # import torchvision.models
    
    vgg_model_pretrained = torchvision.models.vgg16(pretrained=True, progress=True)
    vgg_model_original = torchvision.models.vgg16(pretrained=False, progress=True)
    
    print(vgg_model_original)
    print(vgg_model_pretrained)
    # vgg_model_pretrained.add_module()
    vgg_model_original.classifier.add_module('15', nn.Linear(in_features=1000, out_features=10, bias=True))
    print(vgg_model_original)
    vgg_model_original.classifier[7] = nn.Linear(in_features=1000, out_features=15, bias=True)
    print(vgg_model_original)
    

    根据上诉代码,我们就将 1000 分类问题的网络修改成了 10 分类或者是 15 分类问题的网络了。

    模型的保存和加载

      当我们利用数据将模型训练好之后,往往需要保存模型。同时,当我们创建模型的时候,也可能需要加载我们之前已经训练好的参数,下面我来介绍一下操作方法。

         保留模型结构和模型参数

    通过 torch.save() 和 torch.load() 进行保存模型和参数

    import torch
    import torchvision
    import torch.nn as nn
    # import torchvision.models
    
    vgg_model_pretrained = torchvision.models.vgg16(pretrained=True, progress=True)
    
    torch.save(vgg_model_pretrained, "../../models_param/vgg_model_pretrained.pth")
    
    vgg_model_load = torch.load(f="../../models_param/vgg_model_pretrained.pth")
    print(111)
    

    打一个断点,查看保存模型和加载模型的参数情况



         仅保留模型参数

      同样是使用 save 和 load 参数,但是用法有所不同,他所保存的是一个模型参数,以字典dict 的形式保存

    import torch
    import torchvision
    import torch.nn as nn
    # import torchvision.models
    
    vgg_model_pretrained = torchvision.models.vgg16(pretrained=True, progress=True)
    
    torch.save(vgg_model_pretrained.state_dict(), "../../models_param/vgg_model_pretrained_method2.pth")
    
    vgg_model_load_method2 = torchvision.models.vgg16()
    vgg_model_load_method2.load_state_dict(torch.load("../../models_param/vgg_model_pretrained_method2.pth"))
    print("this is a breakpoint!")
    

    断点查看 save 和 load 模型的参数情况



    一模一样,没有问题。

    Author:luckylight(xyg)
    Date:2021/11/13
  • 相关阅读:
    神经网络
    密度峰值聚类
    kylin从入门到实战:实际案例
    [时间序列分析][3]--自相关系数和偏自相关系数
    时间序列分析之指数平滑法(holt-winters及代码)
    时间序列模型
    python3.5如何安装statsmodels包?
    时间序列分析和预测
    Xshell6和Xftp6 破解免安装版,无窗口多开限制
    优化问题
  • 原文地址:https://www.cnblogs.com/lucky-light/p/15547358.html
Copyright © 2011-2022 走看看