zoukankan      html  css  js  c++  java
  • Pytorch-修改预训练参数

    我自己改进的模型为model(model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)),原模型为resnet50。

    1.查看模型参数

    现模型:

    1 model_dict = model.state_dict()
    2 for k,v in model_dict.items():
    3     print(k)

    预训练模型参数

    1 pretrained_dict = model_zoo.load_url(model_urls['resnet50'])
    2 for k,v in pretrained_dict.items():
    3     print(k)

    2.将预训练参数赋给自己改进的模型

    改进的模型参数和原模型参数一致时:

    1 import torch.utils.model_zoo as model_zoo
    2 
    3 model_urls = {
    4     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth'
    5 }
    6 
    7 model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False)   

    Tip:如果两个模型参数完全一致的话,strict=True,如果两个模型参数不一致的话,当strict=False预训练模型会把具有相同参数名称的值赋给改进的参数,不相同的则不赋值。

    改进的模型参数和原模型参数不一致时,使用部分预训练模型参数初始化网络 :

    1 model_dict = model.state_dict()          #取出自己模型的网络参数 
    2 pretrained_dict = model_zoo.load_url(model_urls['resnet50'])
    3 
    4 model_dict['classifiers.3.fc.weight'] = pretrained_dict['fc.weight'][:2]
    5 model_dict['classifiers.3.fc.bias'] = pretrained_dict['fc.bias'][:2]
  • 相关阅读:
    复利计算
    实验四 主存空间的分配和回收
    0526 Sprint1个人总结 & 《构建之法》第八、九、十章
    实验三 进程调度模拟程序
    0427 scrum & 读后感
    0415 评论
    0414 结对2.0
    汉堡包
    0406 结对编程总结
    读《构建之法》第四章有感
  • 原文地址:https://www.cnblogs.com/cxq1126/p/13893065.html
Copyright © 2011-2022 走看看