  • Modifying a pretrained model in PyTorch

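        import torch.nn as nn
        import torchvision.models as models

        class Net(nn.Module):
            def __init__(self, model):
                super(Net, self).__init__()
                # Drop the last two children of the model (avgpool and fc)
                self.resnet_layer = nn.Sequential(*list(model.children())[:-2])
                # For 224x224 input the trunk outputs 2048x7x7; this upsamples it to 2048x32x32
                self.transion_layer = nn.ConvTranspose2d(2048, 2048, kernel_size=14, stride=3)
                # A 32x32 max-pool collapses the 32x32 map to 2048x1x1
                self.pool_layer = nn.MaxPool2d(32)
                # New head with 8 outputs
                self.Linear_layer = nn.Linear(2048, 8)

            def forward(self, x):
                x = self.resnet_layer(x)
                x = self.transion_layer(x)
                x = self.pool_layer(x)
                x = x.view(x.size(0), -1)  # flatten to (batch, 2048)
                x = self.Linear_layer(x)
                return x


        resnet = models.resnet50(pretrained=True)

        model = Net(resnet)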
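
    The new head only fits together for certain input sizes. Below is a minimal pure-Python sketch of the spatial-size arithmetic, assuming torchvision's ResNet-50 layout (stride-2 conv1 and max-pool, then one stride-2 downsampling at the start of layer2, layer3 and layer4). The helper names conv_out, conv_transpose_out, resnet50_trunk_out and net_head_out are hypothetical, not part of PyTorch.

```python
def conv_out(size, kernel, stride=1, padding=0, dilation=1):
    """Output length of Conv2d / MaxPool2d along one spatial axis (ceil_mode=False)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv_transpose_out(size, kernel, stride=1, padding=0, output_padding=0, dilation=1):
    """Output length of ConvTranspose2d along one spatial axis."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


def resnet50_trunk_out(size):
    """Spatial size after ResNet-50 with avgpool and fc removed (assumed layout)."""
    size = conv_out(size, kernel=7, stride=2, padding=3)  # conv1
    size = conv_out(size, kernel=3, stride=2, padding=1)  # maxpool
    for _ in range(3):  # layer2, layer3, layer4 each halve the resolution once
        size = conv_out(size, kernel=3, stride=2, padding=1)
    return size


def net_head_out(size):
    """Spatial size right before the flatten in Net.forward."""
    size = resnet50_trunk_out(size)
    size = conv_transpose_out(size, kernel=14, stride=3)  # transion_layer
    # MaxPool2d(32): stride defaults to kernel_size; a result below 1 means
    # PyTorch would reject the pooling because the window exceeds the map.
    return conv_out(size, kernel=32, stride=32)


for size in (224, 640, 128):
    print(size, "->", net_head_out(size))  # prints 224 -> 1, 640 -> 2, 128 -> 0
```

    For 224x224 images the trunk yields 7x7, the transposed convolution expands it to (7 - 1) * 3 + 14 = 32, and MaxPool2d(32) reduces that to 1x1, so exactly 2048 features reach Linear(2048, 8). A 640x640 input leaves a 2x2 map (8192 features, a shape mismatch), and a 128x128 input makes the pooling window larger than the map.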
    

    Training specific layers and freezing the others

    The basic idea is that every model has a method model.children() that returns its immediate child modules (its layers). Each layer holds parameters (weights), which can be obtained by calling .parameters() on that child. Every parameter has an attribute requires_grad, which is True by default, meaning gradients are computed for it during backpropagation. To freeze a layer, set requires_grad to False for all of its parameters.

    import torchvision.models as models
    resnet = models.resnet18(pretrained=True)
    ct = 0
    # This freezes children 1-6 of the 10 top-level children of ResNet-18
    # (conv1, bn1, relu, maxpool, layer1, layer2).
    for child in resnet.children():
        ct += 1
        if ct < 7:
            for param in child.parameters():
                param.requires_grad = False
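
    After freezing, it is common to pass only the still-trainable parameters to the optimizer. A minimal sketch, assuming a small hypothetical nn.Sequential stand-in instead of the pretrained ResNet so nothing is downloaded; the freezing loop is the same pattern as above, here keeping only the last child trainable.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a pretrained backbone plus a classifier head.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 4),
)

# Freeze every child except the last one (the head).
children = list(model.children())
for child in children[:-1]:
    for param in child.parameters():
        param.requires_grad = False
# Note: requires_grad=False does not stop BatchNorm from updating its
# running statistics in train mode; call .eval() on frozen BN layers if needed.

# Give the optimizer only the parameters that are still trainable.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.01, momentum=0.9)

# One training step: gradients reach only the unfrozen head.
x = torch.randn(2, 3, 16, 16)
target = torch.tensor([0, 3])
loss = nn.functional.cross_entropy(model(x), target)
optimizer.zero_grad()
loss.backward()
optimizer.step()

frozen_grads = [p.grad for p in children[0].parameters()]
head_grads = [p.grad for p in children[-1].parameters()]
```

    Filtering on requires_grad keeps the optimizer from tracking momentum buffers for weights that never change; the frozen convolution's .grad stays None after backward().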

      

  • Original post: https://www.cnblogs.com/ylHe/p/12916055.html