zoukankan      html  css  js  c++  java
  • [Pytorch]Pytorch 保存模型与加载模型(转)

    转自:知乎

    目录:

    • 保存模型与加载模型
    • 冻结一部分参数,训练另一部分参数
    • 采用不同的学习率进行训练

    1.保存模型与加载

    简单的保存与加载方法:

    # 保存整个网络
    torch.save(net, PATH) 
    # 保存网络中的参数, 速度快,占空间少
    torch.save(net.state_dict(),PATH)
    #--------------------------------------------------
    #针对上面一般的保存方法,加载的方法分别是:
    model_dict=torch.load(PATH)
    model_dict=model.load_state_dict(torch.load(PATH))
    


    然而,在实验中往往需要保存更多的信息,比如优化器的参数,那么可以采取下面的方法保存:

    torch.save({'epoch': epochID + 1, 'state_dict': model.state_dict(), 'best_loss': lossMIN,
                                'optimizer': optimizer.state_dict(),'alpha': loss.alpha, 'gamma': loss.gamma},
                               checkpoint_path + '/m-' + launchTimestamp + '-' + str("%.4f" % lossMIN) + '.pth.tar')
    

    以上包含的信息有,epochID, state_dict, min loss, optimizer, 自定义损失函数的两个参数;格式以字典的格式存储。

    加载的方式:

    def load_checkpoint(model, checkpoint_PATH, optimizer):
        if checkpoint != None:
            model_CKPT = torch.load(checkpoint_PATH)
            model.load_state_dict(model_CKPT['state_dict'])
            print('loading checkpoint!')
            optimizer.load_state_dict(model_CKPT['optimizer'])
        return model, optimizer
    

    其他的参数可以通过以字典的方式获得

    但是,但是,我们可能修改了一部分网络,比如加了一些,删除一些,等等,那么需要过滤这些参数,加载方式:

    def load_checkpoint(model, checkpoint, optimizer, loadOptimizer):
        if checkpoint != 'No':
            print("loading checkpoint...")
            model_dict = model.state_dict()
            modelCheckpoint = torch.load(checkpoint)
            pretrained_dict = modelCheckpoint['state_dict']
            # 过滤操作
            new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
            model_dict.update(new_dict)
            # 打印出来,更新了多少的参数
            print('Total : {}, update: {}'.format(len(pretrained_dict), len(new_dict)))
            model.load_state_dict(model_dict)
            print("loaded finished!")
            # 如果不需要更新优化器那么设置为false
            if loadOptimizer == True:
                optimizer.load_state_dict(modelCheckpoint['optimizer'])
                print('loaded! optimizer')
            else:
                print('not loaded optimizer')
        else:
            print('No checkpoint is included')
        return model, optimizer
    

    2.冻结部分参数,训练另一部分参数

    1)添加下面一句话到模型中

    for p in self.parameters():
        p.requires_grad = False
    

    比如加载了resnet预训练模型之后,在resenet的基础上连接了新的模快,resenet模块那部分可以先暂时冻结不更新,只更新其他部分的参数,那么可以在下面加入上面那句话

    class RESNET_MF(nn.Module):
        def __init__(self, model, pretrained):
            super(RESNET_MF, self).__init__()
            self.resnet = model(pretrained)
            for p in self.parameters():
                p.requires_grad = False
            self.f = SpectralNorm(nn.Conv2d(2048, 512, 1))
            self.g = SpectralNorm(nn.Conv2d(2048, 512, 1))
            self.h = SpectralNorm(nn.Conv2d(2048, 2048, 1))
            ...
    

    同时在优化器中添加:filter(lambda p: p.requires_grad, model.parameters())

    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001, betas=(0.9, 0.999),
                                   eps=1e-08, weight_decay=1e-5)
    

    2) 参数保存在有序的字典中,那么可以通过查找参数的名字对应的id值,进行冻结

    查找的代码:

        model_dict = torch.load('net.pth.tar').state_dict()
        dict_name = list(model_dict)
        for i, p in enumerate(dict_name):
            print(i, p)
    

    保存一下这个文件,可以看到大致是这个样子的:

    0 gamma
    1 resnet.conv1.weight
    2 resnet.bn1.weight
    3 resnet.bn1.bias
    4 resnet.bn1.running_mean
    5 resnet.bn1.running_var
    6 resnet.layer1.0.conv1.weight
    7 resnet.layer1.0.bn1.weight
    8 resnet.layer1.0.bn1.bias
    9 resnet.layer1.0.bn1.running_mean
    ....
    

    同样在模型中添加这样的代码:

    for i,p in enumerate(net.parameters()):
        if i < 165:
            p.requires_grad = False
    

    在优化器中添加上面的那句话可以实现参数的屏蔽

  • 相关阅读:
    Selenium简单测试页面加载速度的性能(Page loading performance)
    Selenium Page object Pattern usage
    Selenium如何支持测试Windows application
    UI Automation的两个成熟的框架(QTP 和Selenium)
    分享自己针对Automation做的两个成熟的框架(QTP 和Selenium)
    敏捷开发中的测试金字塔(转)
    Selenium 的基础框架类
    selenium2 run in Jenkins GUI testing not visible or browser not open but run in background浏览器后台运行不可见
    eclipse与SVN 结合(删除SVN中已经上传的问题)
    配置Jenkins的slave节点的详细步骤适合windows等其他平台
  • 原文地址:https://www.cnblogs.com/kk17/p/10074188.html
Copyright © 2011-2022 走看看