  • PyTorch: modifying a pretrained model

        import torch.nn as nn
        import torchvision.models as models


        class Net(nn.Module):
            def __init__(self, model):
                super(Net, self).__init__()
                # Drop the model's last two layers (for ResNet-50: avgpool and fc)
                self.resnet_layer = nn.Sequential(*list(model.children())[:-2])
                self.transion_layer = nn.ConvTranspose2d(2048, 2048, kernel_size=14, stride=3)
                self.pool_layer = nn.MaxPool2d(32)
                self.Linear_layer = nn.Linear(2048, 8)

            def forward(self, x):
                x = self.resnet_layer(x)
                x = self.transion_layer(x)
                x = self.pool_layer(x)
                x = x.view(x.size(0), -1)  # flatten everything except the batch dimension
                x = self.Linear_layer(x)
                return x


        # pretrained=True is the older torchvision API; newer releases take a weights= argument instead
        resnet = models.resnet50(pretrained=True)

        model = Net(resnet)
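The fixed sizes in Net (kernel_size=14, stride=3, MaxPool2d(32), Linear(2048, 8)) only line up for certain input resolutions. Below is a minimal sketch of the size arithmetic. It assumes a 224x224 input (the original does not state the input size) and that the truncated ResNet-50 downsamples each side by a factor of 32. The helper names are hypothetical; the formulas are the usual output-size formulas for a transposed convolution and for max pooling with default dilation.

```python
def conv_transpose_out(size, kernel_size, stride=1, padding=0,
                       output_padding=0, dilation=1):
    """Output side length of a transposed convolution along one dimension."""
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


def max_pool_out(size, kernel_size, stride=None, padding=0):
    """Output side length of max pooling (floor mode, dilation 1)."""
    stride = kernel_size if stride is None else stride
    return (size + 2 * padding - kernel_size) // stride + 1


input_side = 224                       # assumed input resolution
backbone_side = input_side // 32       # side length after the truncated ResNet-50
upsampled_side = conv_transpose_out(backbone_side, kernel_size=14, stride=3)
pooled_side = max_pool_out(upsampled_side, kernel_size=32)
flattened_features = 2048 * pooled_side * pooled_side

print(backbone_side, upsampled_side, pooled_side, flattened_features)
```

If pooled_side comes out larger than 1 for some other input size, the flattened vector no longer has 2048 features and Linear(2048, 8) would reject it.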
    

    Training specific layers while freezing the others

    The basic idea is that every model has a method model.children() that returns its top-level layers. Each layer holds parameters (weights), which you can get by calling .parameters() on any child (i.e. layer). Every parameter has an attribute called requires_grad, which is True by default. True means a gradient is computed for that parameter during backpropagation, so to freeze a layer you set requires_grad to False for all of its parameters.

    import torchvision.models as models
    resnet = models.resnet18(pretrained=True)
    ct = 0
    # This freezes layers 1-6 of the 10 top-level layers of ResNet-18.
    for child in resnet.children():
        ct += 1
        if ct < 7:
            for param in child.parameters():
                param.requires_grad = False
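To check that freezing did what was intended, it can help to count frozen and trainable parameters and to hand the optimizer only the trainable ones. The following is a minimal sketch rather than the original author's code: it uses a hypothetical stand-in model of 10 small Linear layers in place of ResNet-18, so it runs without downloading weights, and the choice of SGD and the learning rate are arbitrary assumptions.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a pretrained network: 10 small child modules.
model = nn.Sequential(*[nn.Linear(4, 4) for _ in range(10)])

# Freeze the first 6 children, mirroring the ResNet-18 loop.
for index, child in enumerate(model.children()):
    if index < 6:
        for param in child.parameters():
            param.requires_grad = False

trainable = [p for p in model.parameters() if p.requires_grad]
frozen = [p for p in model.parameters() if not p.requires_grad]
print(len(frozen), "frozen tensors,", len(trainable), "trainable tensors")

# Pass only the trainable parameters to the optimizer.
optimizer = torch.optim.SGD(trainable, lr=0.01)

# One backward pass: frozen parameters should receive no gradient.
loss = model(torch.randn(2, 4)).sum()
loss.backward()
frozen_have_grads = any(p.grad is not None for p in frozen)
trainable_have_grads = all(p.grad is not None for p in trainable)
optimizer.step()
```

Each Linear layer contributes a weight and a bias tensor, so the counts are in tensors, not layers.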

      

  • Original article: https://www.cnblogs.com/ylHe/p/12916055.html