PyTorch是一个开源的Python机器学习库,基于Torch,用于自然语言处理等应用程序。他提供了大量的模型供我们所使用,如下图所示:
下面,我们选择其中一个网络进行使用,介绍如何使用、并修改 pytorch 本身为我们提供的现有网络。最后介绍一下模型的保存和修改。
pytorch 现有网络的使用与修改
下面我们以 VGG(Very Deep Convolutional Networks for Large-Scale Image Recognition)的使用为例,进行介绍该网络。
VGG 16 简介
VGG16网络是14年牛津大学计算机视觉组和Google DeepMind公司研究员一起研发的深度网络模型。该网络一共有16个训练参数的网络,该网络的具体网络结构如下所示:
不难看出,该网络主要用于对 224 x 224 的图像进行 1000 分类。下面我们查看 VGG 在 pytorch 上的官方文档。
VGG 16 doc
从帮助文档中,我们可以清楚的看到 pytorch 为我们提供了各种版本的 VGG,我们选择 VGG 16 进行查看。
VGG16 的简单使用
从 vgg 16的帮助文档可以得知,该模型训练的数据是 ImageNet
,我们进入 torchvision.datasets 查看 ImageNet
但是该数据集实在是太大了,根本下不了,还是不搞了。建立一个该网络的模型查看参数: ```python import torch import torchvision import torch.nn as nn # import torchvision.models
vgg_model_pretrained = torchvision.models.vgg16(pretrained=True, progress=True)
vgg_model_original = torchvision.models.vgg16(pretrained=False, progress=True)
print(vgg_model_original)
print(vgg_model_pretrained)
vgg_model_pretrained.add_module()
<p align="center">
<img src="https://img2020.cnblogs.com/blog/1772262/202111/1772262-20211113091424359-823205952.png" style="zoom:100%"/>
</p>
<br/>
<p align="center">
<img src="https://img2020.cnblogs.com/blog/1772262/202111/1772262-20211113091441147-673036772.png" style="zoom:100%"/>
</p>
<br/>
仔细查看这个网络的组成,你可以发现,组成该网络的一个个小 module 就是我们之前所介绍过的`Conv2d`, `ReLU`, `MaxPool2d`, `Linear`, `Dropout` 等等函数,
### VGG16 模型修改
经过上面的代码,我们可以较为轻松的看到 VGG16 神经网络的结构框架,那么我们如何修改别人已经写好的模型呢?
想要修改别人写好的模型,主要有一下这几种操作
<p align="center">
<img src="https://img2020.cnblogs.com/blog/1772262/202111/1772262-20211113092149496-1776250931.png" style="zoom:100%"/>
</p>
<br/>
选中模型,进行 add_module() 或者是直接对模型进行修改
<p align="center">
<img src="https://img2020.cnblogs.com/blog/1772262/202111/1772262-20211113094536565-291308585.png" style="zoom:100%"/>
</p>
<br/>
<p align="center">
<img src="https://img2020.cnblogs.com/blog/1772262/202111/1772262-20211113094752626-1549811574.png" style="zoom:100%"/>
</p>
<br/>
```python
import torch
import torchvision
import torch.nn as nn
# import torchvision.models
vgg_model_pretrained = torchvision.models.vgg16(pretrained=True, progress=True)
vgg_model_original = torchvision.models.vgg16(pretrained=False, progress=True)
print(vgg_model_original)
print(vgg_model_pretrained)
# vgg_model_pretrained.add_module()
vgg_model_original.classifier.add_module('15', nn.Linear(in_features=1000, out_features=10, bias=True))
print(vgg_model_original)
vgg_model_original.classifier[7] = nn.Linear(in_features=1000, out_features=15, bias=True)
print(vgg_model_original)
根据上诉代码,我们就将 1000 分类问题的网络修改成了 10 分类或者是 15 分类问题的网络了。
模型的保存和加载
当我们利用数据将模型训练好之后,往往需要保存模型。同时,当我们创建模型的时候,也可能需要加载我们之前已经训练好的参数,下面我来介绍一下操作方法。
保留模型结构和模型参数
通过 torch.save() 和 torch.load() 进行保存模型和参数
import torch
import torchvision
import torch.nn as nn
# import torchvision.models
vgg_model_pretrained = torchvision.models.vgg16(pretrained=True, progress=True)
torch.save(vgg_model_pretrained, "../../models_param/vgg_model_pretrained.pth")
vgg_model_load = torch.load(f="../../models_param/vgg_model_pretrained.pth")
print(111)
打一个断点,查看保存模型和加载模型的参数情况
仅保留模型参数
同样是使用 save 和 load 参数,但是用法有所不同,他所保存的是一个模型参数,以字典dict 的形式保存
import torch
import torchvision
import torch.nn as nn
# import torchvision.models
vgg_model_pretrained = torchvision.models.vgg16(pretrained=True, progress=True)
torch.save(vgg_model_pretrained.state_dict(), "../../models_param/vgg_model_pretrained_method2.pth")
vgg_model_load_method2 = torchvision.models.vgg16()
vgg_model_load_method2.load_state_dict(torch.load("../../models_param/vgg_model_pretrained_method2.pth"))
print("this is a breakpoint!")
断点查看 save 和 load 模型的参数情况
一模一样,没有问题。
Date:2021/11/13