  • PyTorch: modifying a pretrained model

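        import torch.nn as nn
        import torchvision.models as models

        class Net(nn.Module):
            def __init__(self, model):
                super(Net, self).__init__()
                # Drop the last two children of the model (avgpool and fc)
                self.resnet_layer = nn.Sequential(*list(model.children())[:-2])
                # For a 224x224 input the backbone outputs 2048x7x7;
                # (7 - 1) * 3 + 14 = 32, so this upsamples to 2048x32x32
                self.transition_layer = nn.ConvTranspose2d(2048, 2048, kernel_size=14, stride=3)
                # 32x32 max pooling reduces each feature map to 1x1
                self.pool_layer = nn.MaxPool2d(32)
                # 8-way classifier on the 2048-dim feature vector
                self.Linear_layer = nn.Linear(2048, 8)

            def forward(self, x):
                x = self.resnet_layer(x)
                x = self.transition_layer(x)
                x = self.pool_layer(x)
                x = x.view(x.size(0), -1)  # flatten to (N, 2048)
                x = self.Linear_layer(x)
                return x


        # torchvision >= 0.13 prefers weights=models.ResNet50_Weights.DEFAULT
        resnet = models.resnet50(pretrained=True)

        model = Net(resnet)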
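    The layer sizes in Net only line up for a particular input resolution. As a sketch, assuming the usual 224x224 ImageNet input (the post does not state the input size), the helpers below apply the output-size formulas from the PyTorch documentation for ConvTranspose2d and MaxPool2d. The function names are illustrative only, not part of any library.

```python
# Pure-Python shape arithmetic for the custom head in Net, assuming a
# 224x224 input. Formulas follow the PyTorch docs for ConvTranspose2d
# and MaxPool2d (here padding=0, dilation=1, output_padding=0).

def conv_transpose2d_out(h_in, kernel_size, stride, padding=0,
                         dilation=1, output_padding=0):
    # H_out = (H_in - 1) * stride - 2 * padding
    #         + dilation * (kernel_size - 1) + output_padding + 1
    return ((h_in - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


def max_pool2d_out(h_in, kernel_size, stride=None, padding=0, dilation=1):
    # MaxPool2d uses stride = kernel_size when stride is not given.
    stride = kernel_size if stride is None else stride
    return (h_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def head_shapes(input_size=224):
    # ResNet-50 downsamples by a total factor of 32 before avgpool/fc.
    h = input_size // 32
    h_up = conv_transpose2d_out(h, kernel_size=14, stride=3)
    h_pool = max_pool2d_out(h_up, kernel_size=32)
    return h, h_up, h_pool


print(head_shapes(224))  # (7, 32, 1)
```

    With these settings the pooled map is 1x1, so flattening yields exactly the 2048 features that nn.Linear(2048, 8) expects. Note that the transposed convolution is large: its weight tensor alone has 2048 × 2048 × 14 × 14 = 822,083,584 entries (about 3.3 GB in float32).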
    

    Training specific layers and freezing the rest

    The basic idea is that every model has a method model.children() that returns its top-level layers (child modules). Each layer holds parameters (weights), which can be obtained by calling .parameters() on that child. Every parameter has an attribute requires_grad, which is True by default; True means gradients are computed for it during backpropagation, so the optimizer can update it. To freeze a layer, set requires_grad to False for all of its parameters.

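    import torchvision.models as models

    # torchvision >= 0.13 prefers weights=models.ResNet18_Weights.DEFAULT
    resnet = models.resnet18(pretrained=True)

    # resnet18 has 10 top-level children: conv1, bn1, relu, maxpool,
    # layer1, layer2, layer3, layer4, avgpool, fc.
    # This freezes children 1-6 (conv1 through layer2).
    ct = 0
    for child in resnet.children():
        ct += 1
        if ct < 7:
            for param in child.parameters():
                param.requires_grad = False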

      

  • Original post: https://www.cnblogs.com/ylHe/p/12916055.html