zoukankan      html  css  js  c++  java
  • pytorch加载和保存模型

    在模型完成训练后,我们需要将训练好的模型保存为一个文件供测试使用,或者因为一些原因我们需要继续之前的状态训练之前保存的模型,那么如何在PyTorch中保存和恢复模型呢?

    方法一(推荐):

    第一种方法也是官方推荐的方法,只保存和恢复模型中的参数。

    保存    

    torch.save(the_model.state_dict(), PATH)

    恢复

    the_model = TheModelClass(*args, **kwargs)
    the_model.load_state_dict(torch.load(PATH))

    使用这种方法,我们需要自己导入模型的结构信息。

    方法二:

    使用这种方法,将会保存模型的参数和结构信息。

    保存

    torch.save(the_model, PATH)

    恢复

    the_model = torch.load(PATH)

    一个相对完整的例子

    saving

    torch.save({
    'epoch': epoch + 1,
    'arch': args.arch,
    'state_dict': model.state_dict(),
    'best_prec1': best_prec1,
    }, 'checkpoint.tar' )

    loading

    if args.resume:
    if os.path.isfile(args.resume):
    print("=> loading checkpoint '{}'".format(args.resume))
    checkpoint = torch.load(args.resume)
    args.start_epoch = checkpoint['epoch']
    best_prec1 = checkpoint['best_prec1']
    model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {})"
    .format(args.evaluate, checkpoint['epoch']))
     

    获取模型中某些层的参数

    对于恢复的模型,如果我们想查看某些层的参数,可以:

    # 定义一个网络
    from collections import OrderedDict
    model = nn.Sequential(OrderedDict([
    ('conv1', nn.Conv2d(1,20,5)),
    ('relu1', nn.ReLU()),
    ('conv2', nn.Conv2d(20,64,5)),
    ('relu2', nn.ReLU())
    ]))
    # 打印网络的结构
    print(model)
     
    OUT:
    Sequential (
    (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
    (relu1): ReLU ()
    (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
    (relu2): ReLU ()
    )
     
    如果我们想获取conv1的weight和bias:
     
    params=model.state_dict()
    for k,v in params.items():
    print(k) #打印网络中的变量名
    print(params['conv1.weight']) #打印conv1的weight
    print(params['conv1.bias']) #打印conv1的bias
     


  • 相关阅读:
    启动 YARN 并运行 MapReduce 程序(伪分布式模式)
    启动 HDFS 并运行 MapReduce 程序(伪分布式模式)
    简单计算器(stack)
    Linux定时发邮件脚本
    HttpClient接口调用-客户端
    获取时间字符串
    Visual Assist代码高亮突然失效
    批量快速生成员工文件夹工具
    日语学习笔记整理(汉译日)
    有关使用PLSQL Developer时出现报错ora-12514解决的方法
  • 原文地址:https://www.cnblogs.com/nkh222/p/7656623.html
Copyright © 2011-2022 走看看