zoukankan      html  css  js  c++  java
  • pytorch加载和保存模型

    在模型完成训练后,我们需要将训练好的模型保存为一个文件供测试使用,或者因为一些原因我们需要继续之前的状态训练之前保存的模型,那么如何在PyTorch中保存和恢复模型呢?

    方法一(推荐):

    第一种方法也是官方推荐的方法,只保存和恢复模型中的参数。

    保存    

    torch.save(the_model.state_dict(), PATH)

    恢复

    the_model = TheModelClass(*args, **kwargs)
    the_model.load_state_dict(torch.load(PATH))

    使用这种方法,我们需要自己导入模型的结构信息。

    方法二:

    使用这种方法,将会保存模型的参数和结构信息。

    保存

    torch.save(the_model, PATH)

    恢复

    the_model = torch.load(PATH)

    一个相对完整的例子

    saving

    torch.save({
    'epoch': epoch + 1,
    'arch': args.arch,
    'state_dict': model.state_dict(),
    'best_prec1': best_prec1,
    }, 'checkpoint.tar' )

    loading

    if args.resume:
    if os.path.isfile(args.resume):
    print("=> loading checkpoint '{}'".format(args.resume))
    checkpoint = torch.load(args.resume)
    args.start_epoch = checkpoint['epoch']
    best_prec1 = checkpoint['best_prec1']
    model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {})"
    .format(args.evaluate, checkpoint['epoch']))
     

    获取模型中某些层的参数

    对于恢复的模型,如果我们想查看某些层的参数,可以:

    # 定义一个网络
    from collections import OrderedDict
    model = nn.Sequential(OrderedDict([
    ('conv1', nn.Conv2d(1,20,5)),
    ('relu1', nn.ReLU()),
    ('conv2', nn.Conv2d(20,64,5)),
    ('relu2', nn.ReLU())
    ]))
    # 打印网络的结构
    print(model)
     
    OUT:
    Sequential (
    (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
    (relu1): ReLU ()
    (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
    (relu2): ReLU ()
    )
     
    如果我们想获取conv1的weight和bias:
     
    params=model.state_dict()
    for k,v in params.items():
    print(k) #打印网络中的变量名
    print(params['conv1.weight']) #打印conv1的weight
    print(params['conv1.bias']) #打印conv1的bias
     


  • 相关阅读:
    日常开发常用工具(持续更新中,欢迎小伙伴评论中分享自己认为好用的工具)
    使用 POJO 对象绑定请求参数
    Tomcat+Eclipse乱码问题解决方法
    微信客服接口发消息 -- 微信客服系列文章(一)
    @RequestParam--SpringMVC 注解系列文章(一)
    微信JS图片上传与下载功能--微信JS系列文章(三)
    微信JS分享功能--微信JS系列文章(二)
    微信JS初始化--微信JS系列文章(一)
    二十进制数的加法
    使用NuGet管理项目类库引用
  • 原文地址:https://www.cnblogs.com/nkh222/p/7656623.html
Copyright © 2011-2022 走看看