zoukankan      html  css  js  c++  java
  • [Pytorch]Pytorch 保存模型与加载模型(转)

    转自:知乎

    目录:

    • 保存模型与加载模型
    • 冻结一部分参数,训练另一部分参数
    • 采用不同的学习率进行训练

    1.保存模型与加载

    简单的保存与加载方法:

    # 保存整个网络
    torch.save(net, PATH) 
    # 保存网络中的参数, 速度快,占空间少
    torch.save(net.state_dict(),PATH)
    #--------------------------------------------------
    #针对上面一般的保存方法,加载的方法分别是:
    model_dict=torch.load(PATH)
    model_dict=model.load_state_dict(torch.load(PATH))
    


    然而,在实验中往往需要保存更多的信息,比如优化器的参数,那么可以采取下面的方法保存:

    torch.save({'epoch': epochID + 1, 'state_dict': model.state_dict(), 'best_loss': lossMIN,
                                'optimizer': optimizer.state_dict(),'alpha': loss.alpha, 'gamma': loss.gamma},
                               checkpoint_path + '/m-' + launchTimestamp + '-' + str("%.4f" % lossMIN) + '.pth.tar')
    

    以上包含的信息有,epochID, state_dict, min loss, optimizer, 自定义损失函数的两个参数;格式以字典的格式存储。

    加载的方式:

    def load_checkpoint(model, checkpoint_PATH, optimizer):
        if checkpoint != None:
            model_CKPT = torch.load(checkpoint_PATH)
            model.load_state_dict(model_CKPT['state_dict'])
            print('loading checkpoint!')
            optimizer.load_state_dict(model_CKPT['optimizer'])
        return model, optimizer
    

    其他的参数可以通过以字典的方式获得

    但是,但是,我们可能修改了一部分网络,比如加了一些,删除一些,等等,那么需要过滤这些参数,加载方式:

    def load_checkpoint(model, checkpoint, optimizer, loadOptimizer):
        if checkpoint != 'No':
            print("loading checkpoint...")
            model_dict = model.state_dict()
            modelCheckpoint = torch.load(checkpoint)
            pretrained_dict = modelCheckpoint['state_dict']
            # 过滤操作
            new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
            model_dict.update(new_dict)
            # 打印出来,更新了多少的参数
            print('Total : {}, update: {}'.format(len(pretrained_dict), len(new_dict)))
            model.load_state_dict(model_dict)
            print("loaded finished!")
            # 如果不需要更新优化器那么设置为false
            if loadOptimizer == True:
                optimizer.load_state_dict(modelCheckpoint['optimizer'])
                print('loaded! optimizer')
            else:
                print('not loaded optimizer')
        else:
            print('No checkpoint is included')
        return model, optimizer
    

    2.冻结部分参数,训练另一部分参数

    1)添加下面一句话到模型中

    for p in self.parameters():
        p.requires_grad = False
    

    比如加载了resnet预训练模型之后,在resenet的基础上连接了新的模快,resenet模块那部分可以先暂时冻结不更新,只更新其他部分的参数,那么可以在下面加入上面那句话

    class RESNET_MF(nn.Module):
        def __init__(self, model, pretrained):
            super(RESNET_MF, self).__init__()
            self.resnet = model(pretrained)
            for p in self.parameters():
                p.requires_grad = False
            self.f = SpectralNorm(nn.Conv2d(2048, 512, 1))
            self.g = SpectralNorm(nn.Conv2d(2048, 512, 1))
            self.h = SpectralNorm(nn.Conv2d(2048, 2048, 1))
            ...
    

    同时在优化器中添加:filter(lambda p: p.requires_grad, model.parameters())

    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001, betas=(0.9, 0.999),
                                   eps=1e-08, weight_decay=1e-5)
    

    2) 参数保存在有序的字典中,那么可以通过查找参数的名字对应的id值,进行冻结

    查找的代码:

        model_dict = torch.load('net.pth.tar').state_dict()
        dict_name = list(model_dict)
        for i, p in enumerate(dict_name):
            print(i, p)
    

    保存一下这个文件,可以看到大致是这个样子的:

    0 gamma
    1 resnet.conv1.weight
    2 resnet.bn1.weight
    3 resnet.bn1.bias
    4 resnet.bn1.running_mean
    5 resnet.bn1.running_var
    6 resnet.layer1.0.conv1.weight
    7 resnet.layer1.0.bn1.weight
    8 resnet.layer1.0.bn1.bias
    9 resnet.layer1.0.bn1.running_mean
    ....
    

    同样在模型中添加这样的代码:

    for i,p in enumerate(net.parameters()):
        if i < 165:
            p.requires_grad = False
    

    在优化器中添加上面的那句话可以实现参数的屏蔽

  • 相关阅读:
    Sum Root to Leaf Numbers
    Sum Root to Leaf Numbers
    Sort Colors
    Partition List
    Binary Tree Inorder Traversal
    Binary Tree Postorder Traversal
    Remove Duplicates from Sorted List II
    Remove Duplicates from Sorted List
    Search a 2D Matrix
    leetcode221
  • 原文地址:https://www.cnblogs.com/kk17/p/10074188.html
Copyright © 2011-2022 走看看