zoukankan      html  css  js  c++  java
  • pytorch模型加载方法汇总

    Pytorch有很多方便易用的包,今天要谈的是torchvision包,它包括3个子包,分别是: torchvison.datasets ,torchvision.models ,torchvision.transforms ,分别是预定义好的数据集(比如MNIST、CIFAR10等)、预定义好的经典网络结构(比如AlexNet、VGG、ResNet等)和预定义好的数据增强方法(比如Resize、ToTensor等)。这些方法可以直接调用,简化我们建模的过程,也可以作为我们学习或构建新的模型的参考。

    本文,我们讲述的是models,且只谈模型的加载。models这个包中包含alexnet、densenet、inception、resnet、squeezenet、vgg等常用的网络结构,并且提供了预训练模型,可以通过简单调用来读取网络结构和预训练模型。

    参考地址:https://blog.csdn.net/weixin_41519463/article/details/103205665?depth_1-utm_source=distribute.pc_relevant.none-task&utm_source=distribute.pc_relevant.none-task

    模型地址:https://github.com/pytorch/vision/tree/master/torchvision/models

    官方文档:https://pytorch.org/docs/master/torchvision/models.html

    我将加载的方法简单总结为以下四种:

    1.直接加载预训练模型

    1 import torchvision.models as models
    2  
    3 resnet50 = models.resnet50(pretrained=True)


    这样就导入了resnet50的预训练模型了。

    如果只需要网络结构,不需要用预训练模型的参数来初始化,那么就是:

    model =torchvision.models.resnet50(pretrained=False)

    或者把resnet复制到自己的目录下,新建个model文件夹

    可以参考下面的猫狗大战入门算法入门

    https://github.com/JackwithWilshere/Kaggle-Dogs_vs_Cats_PyTorch

    2.修改某一层
     以resnet为例,默认的是ImageNet的1000类,比如我们要做二分类,分类猫和狗

    1 resnet.fc = nn.Linear(2048, 2) #resnet 第一层卷积的卷积核是7,我们可能想改成5,那么可以通过以下方法修改:
    2 
    3 #未经试验,修改需要有理论依据,计算featuremap维度使之匹配。
    4 resnet.conv1 = nn.Conv2d(3, 64,kernel_size=5, stride=2, padding=3, bias=False)

    3.加载部分预训练模型
    对于具体的任务,很难保证模型和公开的模型完全一样,但是预训练模型的参数确实有助于提高训练的准确率,为了结合二者的优点,就需要我们加载部分预训练模型。

     1 #加载model,model是自己定义好的模型
     resnet50 = models.resnet50(pretrained=True) 
     pretrained_dict =resnet50.state_dict() 
     model =Net(...)  4  5 #读取参数   6  
    model_dict = model.state_dict()

    9 #将pretrained_dict里不属于model_dict的键剔除掉

    10 pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

    # 更新现有的model_dict 13 model_dict.update(pretrained_dict) #这一块更新的什么??

    # 加载我们真正需要的state_dict 16 model.load_state_dict(model_dict)


    4. 加载自己的模型
    其实这个是保存和恢复模型,比如我们训练好的模型保存,然后加载用于测试。

    方法一(推荐):

    第一种方法也是官方推荐的方法,只保存和恢复模型中的参数(权重数值)。

    使用这种方法,我们需要自己导入模型的结构信息。

    (1)保存

    1 torch.save(model.state_dict(), PATH)
    2  
    3 #example
    4 torch.save(resnet50.state_dict(),'ckp/model.pth')    

    (2)恢复

    1 model = ModelClass(*args, **kwargs)
    2 model.load_state_dict(torch.load(PATH))
    3  
    4 #example
    5 resnet=resnet50(pretrained=True)
    6 resnet.load_state_dict(torch.load('ckp/model.pth'))


    方法二:

    使用这种方法,将会同时保存模型的参数和结构信息到模型文件中。

    (1)保存

    torch.save (the_model, PATH)

    (2)恢复

    torch.load (the_model, PATH)


     

  • 相关阅读:
    CSS3中background-origin和background-clip的区别
    JavaScript的赋值是引用or复制,及参数传递
    写第一个jquery插件实录
    北大acm1008
    北大acm1007
    北大acm1006
    北大acm1005
    北大acm1004
    团队绩效评估
    第二阶段冲刺第十天
  • 原文地址:https://www.cnblogs.com/Henry-ZHAO/p/12725163.html
Copyright © 2011-2022 走看看