zoukankan      html  css  js  c++  java
  • Pytorch学习笔记16----CNN或LSTM模型保存与加载

    1.三个核心函数

    介绍一系列关于 PyTorch 模型保存与加载的应用场景,主要包括三个核心函数:

    (1)torch.save

    其中,应用了 Python 的 pickle 包,进行序列化,可适用于模型Models,张量Tensors,以及各种类型的字典对象的序列化保存.

    (2)torch.load

    采用 Python 的 pickle 的 unpickling 函数,对磁盘 pickled 的对象文件进行反序列化(deserialize),加载到内存.

    (3)torch.nn.Module.load_state_dict

    采用序列化的 state_dict 加载模型参数(字典).

    2.state_dict介绍

    PyTorch中,torch.nn.Module 模型中的可学习参数(learnable parameters)(如,weights 和 biases),包含在模型参数(model parameters)里(根据 model.parameters() 进行访问.)

    state_dict可以简单的理解为 Python 的字典对象,其将每一层映射到其参数张量.

    注,只有包含待学习参数的网络层,如卷积层,线性连接层等,会在模型的 state_dict 中有元素值.

    优化器对象(Optimizer object,torch.optim) 也有 state_dict,其包含了优化器的状态信息,以及所使用的超参数.

    由于 state_dict 对象时 Python 字典的形式,因此,便于保存,更新,修改与恢复,有利于 PyTorch 模型和优化器的模块化.

    例如,Training a classifier tutorial 中所使用的简单模型的 state_dict

    # 模型定义
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    
    class ModelNet(nn.Module):
        def __init__(self):
            super(ModelNet, self).__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)
    
        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = x.view(-1, 16 * 5 * 5)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x
    
    # 模型初始化
    model = ModelNet()
    
    # 优化器初始化
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    
    # 打印模型的 state_dict
    print("Model's state_dict:")
    for param_tensor in model.state_dict():
        print(param_tensor, 
              "	", 
              model.state_dict()[param_tensor].size()
             )
    
    # Print optimizer's state_dict
    print("Optimizer's state_dict:")
    for var_name in optimizer.state_dict():
        print(var_name, "	", optimizer.state_dict()[var_name])

    输出如下:

    Model's state_dict:
    conv1.weight      torch.Size([6, 3, 5, 5])
    conv1.bias      torch.Size([6])
    conv2.weight      torch.Size([16, 6, 5, 5])
    conv2.bias      torch.Size([16])
    fc1.weight      torch.Size([120, 400])
    fc1.bias      torch.Size([120])
    fc2.weight      torch.Size([84, 120])
    fc2.bias      torch.Size([84])
    fc3.weight      torch.Size([10, 84])
    fc3.bias      torch.Size([10])
    
    Optimizer's state_dict:
    param_groups      [{'weight_decay': 0, 
                       'dampening': 0, 
                       'params': [140448775121872, 140448775121728,
                                  140448775121584, 140448775121440,
                                  140448775121296, 140448775121152,
                                  140448775121008, 140448775120864,
                                  140448775120720, 140448775120576],
                       'nesterov': False, 
                       'momentum': 0.9, 
                       'lr': 0.001}]
    state      {}

    3.模型的保存与加载

    (1)保存/加载 state_dict (推荐)

    # 模型保存
    torch.save(model.state_dict(), PATH)
    
    # 模型加载
    model = ModelNet(*args, **kwargs)
    model.load_state_dict(torch.load(PATH))
    model.eval()

    当保存模型,用于推断时,只有训练的模型可学习参数是有必要进行保存的.

    采用 torch.save() 函数保存模型的 state_dict,对于应用时,模型恢复具有最好的灵活性,因此推荐采用该方式进行模型保存.

    PyTorch 通用模型保存格式为 .pt 和 .pth 文件扩展名形式.

    需要注意的时,在运行推断前,需要调用 model.eval() 函数,以将 dropout 层 和 batch normalization 层设置为评估模式(非训练模式).

    注意:

    load_state_dict() 函数的输入是字典形式,而不是对象保存的文件路径.

    也就是说,在将保存的模型文件送入 load_state_dict() 函数前,必须将保存的 state_dict 进行反序列化.

    例如,不能直接应用 model.load_state_dict(PATH),而是,load_state_dict(torch.load(PATH)).

    (2)保存/加载全部模型信息

    # 保存
    torch.save(model, PATH)
    
    # 加载
    model = ModelNet(*args, **kwargs) # 必须预先定义过模型.
    model = torch.load(PATH)
    model.eval()

    这种方式是最直观的语法,包含最少的代码. 其会采用 Python 的pickle 模块保存全部的模型模块.

    这种方式的缺点在于,序列化的数据受限于在模型保存时所采用的特定的类和准确的路径结构(specific classes and the exact directory structure). 其原因是,because pickle does not save the model class itself. Rather, it saves a path to the file containing the class, which is used during load time. 因此,再加载后经过许多重构后,或在其它项目中使用时,可能会被打乱.

    参考文献:https://www.aiuai.cn/aifarm743.html

  • 相关阅读:
    ubuntu 14.04搭建PHP项目基本流程
    linux下 lvm 磁盘扩容
    LVM基本介绍与常用命令
    Linux LVM逻辑卷配置过程详解
    mysql 5.7中的用户权限分配相关解读!
    linux系统维护时的一些小技巧,包括系统挂载新磁盘的方法!可收藏!
    linux系统内存爆满的解决办法!~
    源、更新源时容易出现的问题解决方法
    NV显卡Ubuntu14.04更新软件导致登录死循环,不过可以进入tty模式
    一些要注意的地方
  • 原文地址:https://www.cnblogs.com/luckyplj/p/13530908.html
Copyright © 2011-2022 走看看