zoukankan      html  css  js  c++  java
  • pytorch加载模型

    1.加载全部模型:

    net.load_state_dict(torch.load(net_para_pth))

    2.加载部分模型

    net_para_pth = './result/5826.pth'
    pretrained_dict = torch.load(net_para_pth)
    model_dict = net.state_dict()
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    model_dict.update(pretrained_dict)
    net.load_state_dict(model_dict)

    3.改变某一层参数后加载

    将该层名称改一下,然后用2中方法加载,比如,要将conv5的out_channels由256改为512。

    将conv_5改为conv_5_chg,就可以顺利加载了,不改会报错哟

    算是权宜之计了,还有什么好方法,希望多多指教

    
    
  • 相关阅读:
    linux笔记
    ui转化为py
    stl学习
    React第一课
    React 第一课
    创建一个类
    nodejs基本语法
    let和const
    qml_status笔记
    controller层的单元测试
  • 原文地址:https://www.cnblogs.com/jiangnanyanyuchen/p/11796856.html
Copyright © 2011-2022 走看看