zoukankan      html  css  js  c++  java
  • Pytorch 模型的存储与加载

    Pytorch 模型的存储与加载

    本文主要内容来自Pytorch官方文档推荐的一篇英文博客, 本文主要介绍了在Pytorch中模型的存储方法, 以及存储形式, 以及Pytorch存储模型正真存储的是模型的什么结构. 以及加载模型的时候, 模型的哪些数据会被加载. 以及加载后的形式.

    首先大致讲下三个最主要的函数的功能:

    torch.save: 将序列化的对象存储到硬盘中.此函数使用Python的pickle实用程序进行序列化. 对于数据类型都可以进行序列化存储, 模型, 张量, 以及字典, 等各种数据对象都可以使用该函数存储.

    torch.load: 该函数使用的是 pickle 的阶序列化过程, 并将结果存如内存中, 该函数也促进设备加载数据.

    torch.nn.Module.load_state_dict: 使用反序列化的 state_dict 加载模型的参数字典

    模型的加载

    state_dict 是什么

    在一个Pytorch模型中, 通常是 torch.nn.module , 模型中可学习的参数被包含在模型的参数中, 通常是可以使用 model.parameters() 函数访问, 通常都是使用该方法访问的. state_dict只是一个Python字典对象,它将每个图层映射到其参数张量, 这个字典的 key 是图层的 'name', 注意, 只有该层有可学习的参数的层, 也就是可以通过反向传播优化的层, 以及 registered buffers (batchnorm’s running_mean) 才会在 state_dict 中有存储条目. 优化器对象(torch.optim)也具有state_dict,其中包含有关优化器状态以及所用超参数的信息. state_dict 的本质是对模型进行了字典化.

    state_dict的字典形式使得对模型的操作更加的灵活, 例如直接导出模型, 修改其中的参数信息, 或者对层数进行修改等, 然后继续将模型保留. 还是使用一个简单的模型举个例子:

    class TheModelClass(nn.Module):
        def __init__(self):
            super(TheModelClass, self).__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            # 这里卷积核的大小是 5, 个数是 6, 输入的 width 是 3
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            # 两次卷积的结果应该是 5x5x16 的矩阵
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)
            # 可以看出网络层的结构, 两个卷积层, 其余还有全连接层
    
        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = x.view(-1, 16 * 5 * 5)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x
    
    # Initialize model
    model = TheModelClass()
    
    # Initialize optimizer
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    
    # Print model's state_dict
    print("Model's state_dict:")
    for param_tensor in model.state_dict():
        print(param_tensor, "	", model.state_dict()[param_tensor].size())
    
    # Print optimizer's state_dict
    print("Optimizer's state_dict:")
    for var_name in optimizer.state_dict():
        print(var_name, "	", optimizer.state_dict()[var_name])
    

    可以得到模型的输出为:

    Model's state_dict:
    conv1.weight     torch.Size([6, 3, 5, 5])
    conv1.bias   torch.Size([6])
    conv2.weight     torch.Size([16, 6, 5, 5])
    conv2.bias   torch.Size([16])
    fc1.weight   torch.Size([120, 400])
    fc1.bias     torch.Size([120])
    fc2.weight   torch.Size([84, 120])
    fc2.bias     torch.Size([84])
    fc3.weight   torch.Size([10, 84])
    fc3.bias     torch.Size([10])
    
    Optimizer's state_dict:
    state    {}
    param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]
    

    模型的参数的输出是字典的键值对, 后面是优化参数的输出, 也是键值对

    存储与加载模型对应的形式

    使用 state_dict 存储与加载模型

    save:

    torch.save(model.state_dict(), PATH)
    

    Load 模型:

    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH))
    model.eval()
    

    从模型存储的角度, 存储模型的时候, 唯一需要存储的是该模型训练的参数, torch.save() 函数也可以存储模型的 state_dict. 使用该方法进行存储, 模型被看做字典形式, 所以对模型的操作更加灵活. 在这种形式下常见的PyTorch约定是使用.pt或.pth文件扩展名保存模型.

    注意, 加载模型之后, 并不能直接运行, 需要使用 model.eval() 函数设置 Dropout 与层间正则化. 另一方面, 该方法在存储模型的时候是以字典的形式存储的, 也就是存储的是模型的字典数据, Pytorch 不能直接将模型读取为该形式, 必须先 torch.load() 该模型, 然后再使用 load_state_dict().

    将模型作为整体存储与加载

    Save:

    torch.save(model, PATH)
    

    Load:

    # Model class must be defined somewhere
    model = torch.load(PATH)
    model.eval()
    

    使用该方法相当于跳过了对模型的 state_dict 描述的过程, 而是直接使用 python 的 pickle 包, 这种方法的缺点是, 模型的存储形式与加载形式十分固定, 这样做的原因是因为pickle不会保存模型类本身. 而是存出来包含该文件的路径,该路径在加载时使用. 因此,在其他项目中使用或重构后,代码可能会以各种方式中断. 但是这种方法存储的文件的类型与前面的方法一样. 同样, 以该方法加载模型运行之前需要调用 model.eval() .

    存储与加载一般的 Checkpoint

    Save:

    torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                ...
                }, PATH)
    

    Load:

    model = TheModelClass(*args, **kwargs)
    optimizer = TheOptimizerClass(*args, **kwargs)
    
    checkpoint = torch.load(PATH)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    
    model.eval()
    # - or -
    model.train()
    

    可以看出 checkpoints 是模型主要内容的一个字典, 基本包含了模型各种数据, 例如上面的例子模型的参数使用的是 optimizer.state_dict().

    存储 checkpoints 主要目的是为了方便加载模型继续训练, 将所有的信息存储, 加载模型继续训练的时候就会更加方便. 为了存储一个训练过程的多种信息, 最好的方式是使用 dictionary 进行序列化, 这样存储一个训练模型的形式是 .tar, 要加载项目,首先初始化模型和优化器,然后使用torch.load() 在本地加载字典.从这里开始, 只需按期望查询字典即可轻松访问已保存的项目. 请记住,在运行推理之前,必须调用model.eval() 来将 Dropout 和 Batch 正则化设置为评估模式, 不这样做将产生不一致的推断结果. 如果恢复训练,那么调用model.train() 以确保这些层处于训练模式.

    在一个文件中存储多个模型

    save:

    torch.save({
                'modelA_state_dict': modelA.state_dict(),
                'modelB_state_dict': modelB.state_dict(),
                'optimizerA_state_dict': optimizerA.state_dict(),
                'optimizerB_state_dict': optimizerB.state_dict(),
                ...
                }, PATH)
    

    Load:

    modelA = TheModelAClass(*args, **kwargs)
    modelB = TheModelBClass(*args, **kwargs)
    optimizerA = TheOptimizerAClass(*args, **kwargs)
    optimizerB = TheOptimizerBClass(*args, **kwargs)
    
    checkpoint = torch.load(PATH)
    modelA.load_state_dict(checkpoint['modelA_state_dict'])
    modelB.load_state_dict(checkpoint['modelB_state_dict'])
    optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
    optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
    
    modelA.eval()
    modelB.eval()
    # - or -
    modelA.train()
    modelB.train()
    

    保存包含多个 torch.nn.Modules 的模型(例如GAN,序列到序列模型或模型集合)时,将采用与保存常规检查点相同的方法。 换句话说,保存每个模型的state_dict和相应的优化器的字典. 如前所述,您可以保存任何其他可以帮助您恢复培训的项目,只需将它们添加到字典中即可. 使用该方法存储的文件也是 .tar 形式的, 要加载模型,请首先初始化模型和优化器,然后使用torch.load()在本地加载字典。 从这里,您只需按期望查询字典即可轻松访问已保存的项目.

    跨平台模型保存与加载

    GPU 到 CPU

    Save:

    torch.save(model.state_dict(), PATH)
    

    Load:

    device = torch.device('cpu')
    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH, map_location=device))
    
    Save on GPU, Load on GPU

    Save:

    torch.save(model.state_dict(), PATH)
    

    Load:

    device = torch.device("cuda")
    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH))
    model.to(device)
    # Make sure to call input = input.to(device) on any input tensors that you feed to the model
    
    Save on CPU, Load on GPU

    Save:

    torch.save(model.state_dict(), PATH)
    

    Load:

    device = torch.device("cuda")
    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
    model.to(device)
    # Make sure to call input = input.to(device) on any input tensors that you feed to the model
    

    加载部分模型

    举个例子:

    # build
            encoder = TransformerModel(params, dico, is_encoder=True, with_output=False)  # TODO: only output when necessary - len(params.clm_steps + params.mlm_steps) > 0
            decoder = TransformerModel(params, dico, is_encoder=False, with_output=True)
            
    # reload pretrained word embeddings
            if params.reload_emb != '':
            	# 表示加载预训练模型
                word2id, embeddings = load_embeddings(params.reload_emb, params)
                set_pretrain_emb(encoder, dico, word2id, embeddings)
                set_pretrain_emb(decoder, dico, word2id, embeddings)
                set_pretrain_emb(model2, dico, word2id, embeddings)
    
            # reload a pretrained model
            if params.reload_model != '':
                enc_path, dec_path = params.reload_model.split(',')
                assert not (enc_path == '' and dec_path == '')
    
                # reload encoder
                if enc_path != '':
                    enc_reload = torch.load(enc_path, map_location=lambda storage, loc: storage.cuda(params.local_rank))
                    # 预训练模型是在 GPU 上训练的
                    enc_reload = enc_reload['model' if 'model' in enc_reload else 'encoder']
                    # 导入存储的文件的模型
                    if all([k.startswith('module.') for k in enc_reload.keys()]):
                        enc_reload = {k[len('module.'):]: v for k, v in enc_reload.items()}
                    # 这个过程相当于将model 反序列化为 state_dict的形式
                    encoder.load_state_dict(enc_reload, strict=False)
                    # 这个后面的 strict=False 就是对 encoder 与 enc_reload.state_dict之间差异进行处理, 如果encoder 的模型结构与 enc_reload模型结构
                    # 不一样的时候, 就会向 encoder 转化, 也就是 encoder 不包含的层就不会导入, 例如这里 enc_reload 就是一个完整的 Transformer 模型, 但是
                    # encoder 是不包含输出部分的, 所以就不会加载这部分
    

    对于该部分, 本文只是做了个简单的例子介绍, 更详细的内容参见传送门 . 对于这个传送门的例子, 如果我们先存储一个大模型, 将大模型加载到小模型的时候, 使用:

    path = 'xxx.pth'
    model = Net()
    model.load_state_dict(t.load(path), strict=False)
    for module in model.named_modules():
        print(module)
    for name, param in model.named_parameters():
        print(name, param)
    

    从输出可以看出模型向下兼容,

  • 相关阅读:
    Android源码分析(二)-----如何编译修改后的framework资源文件
    Android源码分析(一)-----如何快速掌握Android编译文件
    AI2(App Inventor 2)离线版服务器网络版
    AI2(App Inventor 2)离线版服务器单机版
    AI2(App Inventor 2)离线版服务器(2019.04.28更新)
    解释器模式
    迭代器模式
    备忘录模式
    访问者模式
    命令模式
  • 原文地址:https://www.cnblogs.com/wevolf/p/12918217.html
Copyright © 2011-2022 走看看