zoukankan      html  css  js  c++  java
  • Pytorch-修改预训练参数

    我自己改进的模型为model(model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)),原模型为resnet50。

    1.查看模型参数

    现模型:

    1 model_dict = model.state_dict()
    2 for k,v in model_dict.items():
    3     print(k)

    预训练模型参数

    1 pretrained_dict = model_zoo.load_url(model_urls['resnet50'])
    2 for k,v in pretrained_dict.items():
    3     print(k)

    2.将预训练参数赋给自己改进的模型

    改进的模型参数和原模型参数一致时:

    1 import torch.utils.model_zoo as model_zoo
    2 
    3 model_urls = {
    4     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth'
    5 }
    6 
    7 model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False)   

    Tip:如果两个模型参数完全一致的话,strict=True,如果两个模型参数不一致的话,当strict=False预训练模型会把具有相同参数名称的值赋给改进的参数,不相同的则不赋值。

    改进的模型参数和原模型参数不一致时,使用部分预训练模型参数初始化网络 :

    1 model_dict = model.state_dict()          #取出自己模型的网络参数 
    2 pretrained_dict = model_zoo.load_url(model_urls['resnet50'])
    3 
    4 model_dict['classifiers.3.fc.weight'] = pretrained_dict['fc.weight'][:2]
    5 model_dict['classifiers.3.fc.bias'] = pretrained_dict['fc.bias'][:2]
  • 相关阅读:
    SQL语句编写
    触发器
    plot函数中的type中的参数
    【转】R中read.table详解
    7月18日R笔记
    RMySQL在windows下的安装方法
    WinXP下面实现JAVA对R调用 (rJava包设置)
    用R进行文档层次聚类完整实例(tm包)
    R学习之R层次聚类方法(tm包)
    R对term进行层次聚类完整实例(tm包)
  • 原文地址:https://www.cnblogs.com/cxq1126/p/13893065.html
Copyright © 2011-2022 走看看