1.三个核心函数
介绍一系列关于 PyTorch 模型保存与加载的应用场景,主要包括三个核心函数:
(1)torch.save
其中,应用了 Python 的 pickle 包,进行序列化,可适用于模型Models,张量Tensors,以及各种类型的字典对象的序列化保存.
(2)torch.load
采用 Python 的 pickle 的 unpickling 函数,对磁盘 pickled 的对象文件进行反序列化(deserialize),加载到内存.
(3)torch.nn.Module.load_state_dict
采用序列化的 state_dict
加载模型参数(字典).
2.state_dict介绍
PyTorch中,torch.nn.Module
模型中的可学习参数(learnable parameters)(如,weights 和 biases),包含在模型参数(model parameters)里(根据 model.parameters()
进行访问.)
state_dict
可以简单的理解为 Python 的字典对象,其将每一层映射到其参数张量.
注,只有包含待学习参数的网络层,如卷积层,线性连接层等,会在模型的 state_dict
中有元素值.
优化器对象(Optimizer object,torch.optim
) 也有 state_dict
,其包含了优化器的状态信息,以及所使用的超参数.
由于 state_dict
对象时 Python 字典的形式,因此,便于保存,更新,修改与恢复,有利于 PyTorch 模型和优化器的模块化.
例如,Training a classifier tutorial 中所使用的简单模型的 state_dict
:
# 模型定义 import torch.nn as nn import torch.nn.functional as F import torch.optim as optim class ModelNet(nn.Module): def __init__(self): super(ModelNet, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x # 模型初始化 model = ModelNet() # 优化器初始化 optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # 打印模型的 state_dict print("Model's state_dict:") for param_tensor in model.state_dict(): print(param_tensor, " ", model.state_dict()[param_tensor].size() ) # Print optimizer's state_dict print("Optimizer's state_dict:") for var_name in optimizer.state_dict(): print(var_name, " ", optimizer.state_dict()[var_name])
输出如下:
Model's state_dict: conv1.weight torch.Size([6, 3, 5, 5]) conv1.bias torch.Size([6]) conv2.weight torch.Size([16, 6, 5, 5]) conv2.bias torch.Size([16]) fc1.weight torch.Size([120, 400]) fc1.bias torch.Size([120]) fc2.weight torch.Size([84, 120]) fc2.bias torch.Size([84]) fc3.weight torch.Size([10, 84]) fc3.bias torch.Size([10]) Optimizer's state_dict: param_groups [{'weight_decay': 0, 'dampening': 0, 'params': [140448775121872, 140448775121728, 140448775121584, 140448775121440, 140448775121296, 140448775121152, 140448775121008, 140448775120864, 140448775120720, 140448775120576], 'nesterov': False, 'momentum': 0.9, 'lr': 0.001}] state {}
3.模型的保存与加载
(1)保存/加载 state_dict
(推荐)
# 模型保存 torch.save(model.state_dict(), PATH) # 模型加载 model = ModelNet(*args, **kwargs) model.load_state_dict(torch.load(PATH)) model.eval()
当保存模型,用于推断时,只有训练的模型可学习参数是有必要进行保存的.
采用 torch.save()
函数保存模型的 state_dict
,对于应用时,模型恢复具有最好的灵活性,因此推荐采用该方式进行模型保存.
PyTorch 通用模型保存格式为 .pt
和 .pth
文件扩展名形式.
需要注意的时,在运行推断前,需要调用 model.eval()
函数,以将 dropout
层 和 batch normalization
层设置为评估模式(非训练模式).
注意:
load_state_dict()
函数的输入是字典形式,而不是对象保存的文件路径.
也就是说,在将保存的模型文件送入 load_state_dict()
函数前,必须将保存的 state_dict
进行反序列化.
例如,不能直接应用 model.load_state_dict(PATH)
,而是,load_state_dict(torch.load(PATH))
.
(2)保存/加载全部模型信息
# 保存 torch.save(model, PATH) # 加载 model = ModelNet(*args, **kwargs) # 必须预先定义过模型. model = torch.load(PATH) model.eval()
这种方式是最直观的语法,包含最少的代码. 其会采用 Python 的pickle 模块保存全部的模型模块.
这种方式的缺点在于,序列化的数据受限于在模型保存时所采用的特定的类和准确的路径结构(specific classes and the exact directory structure). 其原因是,because pickle does not save the model class itself. Rather, it saves a path to the file containing the class, which is used during load time. 因此,再加载后经过许多重构后,或在其它项目中使用时,可能会被打乱.
参考文献:https://www.aiuai.cn/aifarm743.html