zoukankan      html  css  js  c++  java
  • Pytorch-修改预训练参数

    我自己改进的模型为model(model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)),原模型为resnet50。

    1.查看模型参数

    现模型:

    1 model_dict = model.state_dict()
    2 for k,v in model_dict.items():
    3     print(k)

    预训练模型参数

    1 pretrained_dict = model_zoo.load_url(model_urls['resnet50'])
    2 for k,v in pretrained_dict.items():
    3     print(k)

    2.将预训练参数赋给自己改进的模型

    改进的模型参数和原模型参数一致时:

    1 import torch.utils.model_zoo as model_zoo
    2 
    3 model_urls = {
    4     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth'
    5 }
    6 
    7 model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False)   

    Tip:如果两个模型参数完全一致的话,strict=True,如果两个模型参数不一致的话,当strict=False预训练模型会把具有相同参数名称的值赋给改进的参数,不相同的则不赋值。

    改进的模型参数和原模型参数不一致时,使用部分预训练模型参数初始化网络 :

    1 model_dict = model.state_dict()          #取出自己模型的网络参数 
    2 pretrained_dict = model_zoo.load_url(model_urls['resnet50'])
    3 
    4 model_dict['classifiers.3.fc.weight'] = pretrained_dict['fc.weight'][:2]
    5 model_dict['classifiers.3.fc.bias'] = pretrained_dict['fc.bias'][:2]
  • 相关阅读:
    c++之单链表
    c++之变量的生存期及可见性
    c++之结构体-结构数组排序
    c++之递归函数
    c++之指针练习
    C++之面向对象之构造函数和拷贝构造方法,以及析构方法
    C++之命名空间
    C++之面向对象之对象的使用
    Hadoop RPC实现
    BP(商业计划书写)
  • 原文地址:https://www.cnblogs.com/cxq1126/p/13893065.html
Copyright © 2011-2022 走看看