  • PyTorch: modifying a pretrained model

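        import torch.nn as nn
        import torchvision.models as models

        class Net(nn.Module):
            def __init__(self, model):
                super(Net, self).__init__()
                # Drop the last two children of model (avgpool and fc for a ResNet)
                self.resnet_layer = nn.Sequential(*list(model.children())[:-2])
                # Upsample the feature map: 7x7 -> (7 - 1) * 3 + 14 = 32x32 for a 224x224 input
                self.transition_layer = nn.ConvTranspose2d(2048, 2048, kernel_size=14, stride=3)
                # 32x32 -> 1x1
                self.pool_layer = nn.MaxPool2d(32)
                # New classifier head with 8 output classes
                self.Linear_layer = nn.Linear(2048, 8)

            def forward(self, x):
                # Shapes below assume a 224x224 input
                x = self.resnet_layer(x)      # (N, 2048, 7, 7)
                x = self.transition_layer(x)  # (N, 2048, 32, 32)
                x = self.pool_layer(x)        # (N, 2048, 1, 1)
                x = x.view(x.size(0), -1)     # (N, 2048)
                x = self.Linear_layer(x)      # (N, 8)
                return x


        # torchvision >= 0.13 prefers weights=models.ResNet50_Weights.DEFAULT
        resnet = models.resnet50(pretrained=True)

        model = Net(resnet)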
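The layer sizes in Net only fit together for a particular input size. The pure-Python sketch below assumes a 224x224 input and torchvision's standard ResNet-50 stem and strides, and replays PyTorch's documented output-size formulas for Conv2d/MaxPool2d and ConvTranspose2d to trace the spatial size through the modified head. The helper names are hypothetical and no PyTorch is needed to run it.

```python
def conv_out(size, kernel, stride=1, padding=0, dilation=1):
    """Output length of Conv2d / MaxPool2d along one spatial dim (floor mode)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv_transpose_out(size, kernel, stride=1, padding=0, output_padding=0, dilation=1):
    """Output length of ConvTranspose2d along one spatial dim."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


def resnet_trunk_out(size):
    """Spatial size after list(resnet.children())[:-2], i.e. stem + layer1..layer4."""
    size = conv_out(size, kernel=7, stride=2, padding=3)  # conv1
    size = conv_out(size, kernel=3, stride=2, padding=1)  # maxpool
    # layer1 keeps the size; layer2, layer3 and layer4 each halve it with a stride-2 3x3 conv
    for _ in range(3):
        size = conv_out(size, kernel=3, stride=2, padding=1)
    return size


def head_out(size):
    """Spatial size after the trunk, ConvTranspose2d(k=14, s=3) and MaxPool2d(32)."""
    size = resnet_trunk_out(size)
    size = conv_transpose_out(size, kernel=14, stride=3)
    # MaxPool2d(32): stride defaults to kernel_size
    return conv_out(size, kernel=32, stride=32)


print(resnet_trunk_out(224), conv_transpose_out(7, kernel=14, stride=3), head_out(224))
```

For 224x224 the trunk gives 7x7, the transposed convolution 32x32 and MaxPool2d(32) 1x1, so view yields exactly 2048 features for Linear(2048, 8). With a 576x576 input the pooled map is already 2x2, and Linear(2048, 8) would then receive 4 * 2048 = 8192 features and fail with a shape error.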
    

    Train specific layers, freeze the others

    The basic idea is that every model has a method model.children() that returns its immediate child modules (its top-level layers). Each child holds parameters (weights), which can be obtained by calling .parameters() on that child. Every parameter has an attribute requires_grad, which is True by default; True means gradients are computed for it during backpropagation. To freeze a layer, set requires_grad to False for all of that layer's parameters.

    import torchvision.models as models
    resnet = models.resnet18(pretrained=True)
    ct = 0
    # Freeze the first 6 of ResNet-18's 10 top-level children (conv1 through layer2)
    for child in resnet.children():
        ct += 1
        if ct < 7:
            for param in child.parameters():
                param.requires_grad = False
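To see which layers the ct < 7 loop actually touches, the pure-Python sketch below replays the counter over ResNet-18's ten top-level children. The child names are written out by hand (assumed to match the order torchvision's ResNet definition yields from children(), not queried from a model), and split_by_counter is a hypothetical helper; no PyTorch is needed to run it.

```python
# Top-level children of torchvision's resnet18, in children() order
# (listed by hand here as an assumption, not read from a model).
RESNET18_CHILDREN = [
    "conv1", "bn1", "relu", "maxpool",
    "layer1", "layer2", "layer3", "layer4",
    "avgpool", "fc",
]

# Children that own no parameters, so the inner requires_grad loop does nothing for them.
PARAMETERLESS = {"relu", "maxpool", "avgpool"}


def split_by_counter(names, limit=7):
    """Mimic the ct loop: ct counts from 1, and children with ct < limit are frozen."""
    frozen, trainable = [], []
    ct = 0
    for name in names:
        ct += 1
        if ct < limit:
            frozen.append(name)
        else:
            trainable.append(name)
    return frozen, trainable


frozen, trainable = split_by_counter(RESNET18_CHILDREN)
print("frozen:   ", frozen)
print("trainable:", trainable)
print("frozen with weights:", [n for n in frozen if n not in PARAMETERLESS])
```

Because relu and maxpool have no parameters, the loop in practice freezes only conv1, bn1, layer1 and layer2, while layer3, layer4 and fc keep training.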

      

  • Original article: https://www.cnblogs.com/ylHe/p/12916055.html