zoukankan      html  css  js  c++  java
  • pytorch:修改预训练模型

    torchvision中提供了很多训练好的模型,这些模型是在1000类,224*224的imagenet中训练得到的,很多时候不适合我们自己的数据,可以根据需要进行修改。

    1、类别不同

        # coding=UTF-8  
        import torchvision.models as models  
          
        #调用模型  
        model = models.resnet50(pretrained=True)  
        #提取fc层中固定的参数  
        fc_features = model.fc.in_features  
        #修改类别为9  
        model.fc = nn.Linear(fc_features, 9)  

    2、添加层后,加载部分参数

    model = ...
    model_dict = model.state_dict()
    
    # 1. filter out unnecessary keys
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    model.load_state_dict(model_dict)

    参考:https://blog.csdn.net/u012494820/article/details/79068625

              https://blog.csdn.net/whut_ldz/article/details/78845947

  • 相关阅读:
    【每天一道PAT】1001 A+B Format
    C++ STL总结
    开篇
    happen-before原则
    java多线程的状态转换以及基本操作
    集合初始容量
    fail-fast机制
    Stack
    Iterator
    Vector
  • 原文地址:https://www.cnblogs.com/573177885qq/p/8877427.html
Copyright © 2011-2022 走看看