zoukankan      html  css  js  c++  java
  • pytorch加载和保存模型

    在模型完成训练后,我们需要将训练好的模型保存为一个文件供测试使用,或者因为一些原因我们需要继续之前的状态训练之前保存的模型,那么如何在PyTorch中保存和恢复模型呢?

    方法一(推荐):

    第一种方法也是官方推荐的方法,只保存和恢复模型中的参数。

    保存    

    torch.save(the_model.state_dict(), PATH)

    恢复

    the_model = TheModelClass(*args, **kwargs)
    the_model.load_state_dict(torch.load(PATH))

    使用这种方法,我们需要自己导入模型的结构信息。

    方法二:

    使用这种方法,将会保存模型的参数和结构信息。

    保存

    torch.save(the_model, PATH)

    恢复

    the_model = torch.load(PATH)

    一个相对完整的例子

    saving

    torch.save({
    'epoch': epoch + 1,
    'arch': args.arch,
    'state_dict': model.state_dict(),
    'best_prec1': best_prec1,
    }, 'checkpoint.tar' )

    loading

    if args.resume:
    if os.path.isfile(args.resume):
    print("=> loading checkpoint '{}'".format(args.resume))
    checkpoint = torch.load(args.resume)
    args.start_epoch = checkpoint['epoch']
    best_prec1 = checkpoint['best_prec1']
    model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {})"
    .format(args.evaluate, checkpoint['epoch']))
     

    获取模型中某些层的参数

    对于恢复的模型,如果我们想查看某些层的参数,可以:

    # 定义一个网络
    from collections import OrderedDict
    model = nn.Sequential(OrderedDict([
    ('conv1', nn.Conv2d(1,20,5)),
    ('relu1', nn.ReLU()),
    ('conv2', nn.Conv2d(20,64,5)),
    ('relu2', nn.ReLU())
    ]))
    # 打印网络的结构
    print(model)
     
    OUT:
    Sequential (
    (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
    (relu1): ReLU ()
    (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
    (relu2): ReLU ()
    )
     
    如果我们想获取conv1的weight和bias:
     
    params=model.state_dict()
    for k,v in params.items():
    print(k) #打印网络中的变量名
    print(params['conv1.weight']) #打印conv1的weight
    print(params['conv1.bias']) #打印conv1的bias
     


  • 相关阅读:
    Hello,world的几种写法!
    浮动与清除浮动
    css中表格的table-layout属性特殊用法
    CSS之照片集效果
    CSS之transition过渡练习
    CSS之过渡简单应用—日落西山
    CSS之立方体绘画步骤
    CSS之立体球体
    transform
    Vue.sync修饰符与this.$emit('update:xxx', newXXX)
  • 原文地址:https://www.cnblogs.com/nkh222/p/7656623.html
Copyright © 2011-2022 走看看