zoukankan      html  css  js  c++  java
  • [PyTorch 学习笔记] 7.1 模型保存与加载

    本章代码:

    这篇文章主要介绍了序列化与反序列化,以及 PyTorch 中的模型保存于加载的两种方式,模型的断点续训练。

    序列化与反序列化

    模型在内存中是以对象的逻辑结构保存的,但是在硬盘中是以二进制流的方式保存的。

    • 序列化是指将内存中的数据以二进制序列的方式保存到硬盘中。PyTorch 的模型保存就是序列化。

    • 反序列化是指将硬盘中的二进制序列加载到内存中,得到模型的对象。PyTorch 的模型加载就是反序列化。

    PyTorch 中的模型保存与加载

    torch.save

    torch.save(obj, f, pickle_module, pickle_protocol=2, _use_new_zipfile_serialization=False)
    

    主要参数:

    • obj:保存的对象,可以是模型。也可以是 dict。因为一般在保存模型时,不仅要保存模型,还需要保存优化器、此时对应的 epoch 等参数。这时就可以用 dict 包装起来。
    • f:输出路径

    其中模型保存还有两种方式:

    保存整个 Module

    这种方法比较耗时,保存的文件大

    torch.savev(net, path)
    

    只保存模型的参数

    推荐这种方法,运行比较快,保存的文件比较小

    state_sict = net.state_dict()
    torch.savev(state_sict, path)
    

    下面是保存 LeNet 的例子。在网络初始化中,把权值都设置为 2020,然后保存模型。

    import torch
    import numpy as np
    import torch.nn as nn
    from common_tools import set_seed
    
    
    class LeNet2(nn.Module):
        def __init__(self, classes):
            super(LeNet2, self).__init__()
            self.features = nn.Sequential(
                nn.Conv2d(3, 6, 5),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
                nn.Conv2d(6, 16, 5),
                nn.ReLU(),
                nn.MaxPool2d(2, 2)
            )
            self.classifier = nn.Sequential(
                nn.Linear(16*5*5, 120),
                nn.ReLU(),
                nn.Linear(120, 84),
                nn.ReLU(),
                nn.Linear(84, classes)
            )
    
        def forward(self, x):
            x = self.features(x)
            x = x.view(x.size()[0], -1)
            x = self.classifier(x)
            return x
    
        def initialize(self):
            for p in self.parameters():
                p.data.fill_(2020)
    
    
    net = LeNet2(classes=2019)
    
    # "训练"
    print("训练前: ", net.features[0].weight[0, ...])
    net.initialize()
    print("训练后: ", net.features[0].weight[0, ...])
    
    path_model = "./model.pkl"
    path_state_dict = "./model_state_dict.pkl"
    
    # 保存整个模型
    torch.save(net, path_model)
    
    # 保存模型参数
    net_state_dict = net.state_dict()
    torch.save(net_state_dict, path_state_dict)
    

    运行完之后,文件夹中生成了``model.pklmodel_state_dict.pkl`,分别保存了整个网络和网络的参数

    torch.load

    torch.load(f, map_location=None, pickle_module, **pickle_load_args)
    

    主要参数:

    • f:文件路径
    • map_location:指定存在 CPU 或者 GPU。

    加载模型也有两种方式

    加载整个 Module

    如果保存的时候,保存的是整个模型,那么加载时就加载整个模型。这种方法不需要事先创建一个模型对象,也不用知道模型的结构,代码如下:

    path_model = "./model.pkl"
    net_load = torch.load(path_model)
    
    print(net_load)
    

    输出如下:

    LeNet2(
      (features): Sequential(
        (0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
        (1): ReLU()
        (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
        (4): ReLU()
        (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (classifier): Sequential(
        (0): Linear(in_features=400, out_features=120, bias=True)
        (1): ReLU()
        (2): Linear(in_features=120, out_features=84, bias=True)
        (3): ReLU()
        (4): Linear(in_features=84, out_features=2019, bias=True)
      )
    )
    

    只加载模型的参数

    如果保存的时候,保存的是模型的参数,那么加载时就参数。这种方法需要事先创建一个模型对象,再使用模型的load_state_dict()方法把参数加载到模型中,代码如下:

    path_state_dict = "./model_state_dict.pkl"
    state_dict_load = torch.load(path_state_dict)
    net_new = LeNet2(classes=2019)
    
    print("加载前: ", net_new.features[0].weight[0, ...])
    net_new.load_state_dict(state_dict_load)
    print("加载后: ", net_new.features[0].weight[0, ...])
    

    模型的断点续训练

    在训练过程中,可能由于某种意外原因如断点等导致训练终止,这时需要重新开始训练。断点续练是在训练过程中每隔一定次数的 epoch 就保存模型的参数和优化器的参数,这样如果意外终止训练了,下次就可以重新加载最新的模型参数和优化器的参数,在这个基础上继续训练。

    下面的代码中,每隔 5 个 epoch 就保存一次,保存的是一个 dict,包括模型参数、优化器的参数、epoch。然后在 epoch 大于 5 时,就break模拟训练意外终止。关键代码如下:

        if (epoch+1) % checkpoint_interval == 0:
    
            checkpoint = {"model_state_dict": net.state_dict(),
                          "optimizer_state_dict": optimizer.state_dict(),
                          "epoch": epoch}
            path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
            torch.save(checkpoint, path_checkpoint)
    

    在 epoch 大于 5 时,就break模拟训练意外终止

        if epoch > 5:
            print("训练意外中断...")
            break
    

    断点续训练的恢复代码如下:

    path_checkpoint = "./checkpoint_4_epoch.pkl"
    checkpoint = torch.load(path_checkpoint)
    
    net.load_state_dict(checkpoint['model_state_dict'])
    
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    start_epoch = checkpoint['epoch']
    
    scheduler.last_epoch = start_epoch
    

    需要注意的是,还要设置scheduler.last_epoch参数为保存的 epoch。模型训练的起始 epoch 也要修改为保存的 epoch。

    参考资料


    如果你觉得这篇文章对你有帮助,不妨点个赞,让我有更多动力写出好文章。

  • 相关阅读:
    Linux常用命令(5)--SSH访问远程服务器、SCP服务器间文件拷贝
    【转载】善用工具(1)--Mac版UltraEdit编辑器破解方法
    Linux常用命令(4)--善用"help"、"man在线帮助文档",轻松搞定系统命令
    Linux常用命令(3)--文件管理(查看文件大小权限信息、修改文件所属用户和操作权限、压缩解压文件)
    Linux常用命令(2)--vi (vim)文本编辑工具
    Linux常用命令(1)--用户管理(添加用户、修改密码、授予root权限)
    30分钟掌握ES6/ES2015核心内容(下)
    30分钟掌握ES6/ES2015核心内容(上)
    99%的人都理解错了HTTP中GET与POST的区别
    js中const,var,let区别
  • 原文地址:https://www.cnblogs.com/zhangxiann/p/13673807.html
Copyright © 2011-2022 走看看