zoukankan      html  css  js  c++  java
  • Pytorch-修改预训练参数

    我自己改进的模型为model(model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)),原模型为resnet50。

    1.查看模型参数

    现模型:

    1 model_dict = model.state_dict()
    2 for k,v in model_dict.items():
    3     print(k)

    预训练模型参数

    1 pretrained_dict = model_zoo.load_url(model_urls['resnet50'])
    2 for k,v in pretrained_dict.items():
    3     print(k)

    2.将预训练参数赋给自己改进的模型

    改进的模型参数和原模型参数一致时:

    1 import torch.utils.model_zoo as model_zoo
    2 
    3 model_urls = {
    4     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth'
    5 }
    6 
    7 model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False)   

    Tip:如果两个模型参数完全一致的话,strict=True,如果两个模型参数不一致的话,当strict=False预训练模型会把具有相同参数名称的值赋给改进的参数,不相同的则不赋值。

    改进的模型参数和原模型参数不一致时,使用部分预训练模型参数初始化网络 :

    1 model_dict = model.state_dict()          #取出自己模型的网络参数 
    2 pretrained_dict = model_zoo.load_url(model_urls['resnet50'])
    3 
    4 model_dict['classifiers.3.fc.weight'] = pretrained_dict['fc.weight'][:2]
    5 model_dict['classifiers.3.fc.bias'] = pretrained_dict['fc.bias'][:2]
  • 相关阅读:
    c# 反射取其他项目的资源文件
    【分享】免费建立自己的站点
    c# 自定义类型的DataBindings
    ListView 多行拖拽排序
    linq to sql之组装where条件下的'或'语句
    dotfuscator使用方法
    orderBy 传入属性的字符串
    WCF数据交互时长度超过8192
    ASP.net中aspx与cs函数的互调
    c# 读取excel数据的两种方法
  • 原文地址:https://www.cnblogs.com/cxq1126/p/13893065.html
Copyright © 2011-2022 走看看