zoukankan      html  css  js  c++  java
  • Pytorch: parameters(),children(),modules(),named_*区别

    nn.Module vs nn.functional

    前者会保存权重等信息,后者只是做运算

    parameters()

    返回可训练参数

    nn.ModuleList vs. nn.ParameterList vs. nn.Sequential

    layer_list = [nn.Conv2d(5,5,3), nn.BatchNorm2d(5), nn.Linear(5,2)]
    
    class myNet(nn.Module):
      def __init__(self):
        super().__init__()
        self.layers = layer_list
      
      def forward(x):
        for layer in self.layers:
          x = layer(x)
    
    net = myNet()
    
    print(list(net.parameters()))  # Parameters of modules in the layer_list don't show up.
    

    nn.ModuleList的作用就是wrap pthon list,这样其中的参数会被注册,因此可以返回可训练参数(ParameterList)。

    nn.Sequential的作用如下:

    class myNet(nn.Module):
      def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Relu(inplace=True),
            nn.Linear(10, 10)
        )
      
      def forward(x):
        x = layer(x)
    
    x = torch.rand(10)
    net = myNet()
    print(net(x).shape)
    

    可以看到Sequential的作用就是按照指定的顺序构建网络结构,得到一个完整的模块,而ModuleList则只是像list那样把元素集合起来而已。

    nn.modules vs. nn.children

    class myNet(nn.Module):
      def __init__(self):
        super().__init__()
        self.convBN =  nn.Sequential(nn.Conv2d(10,10,3), nn.BatchNorm2d(10))
        self.linear =  nn.Linear(10,2)
        
      def forward(self, x):
        pass
      
    
    Net = myNet()
    
    print("Printing children
    ------------------------------")
    print(list(Net.children()))
    print("
    
    Printing Modules
    ------------------------------")
    print(list(Net.modules()))
    

    输出信息如下:

    Printing children
    ------------------------------
    [Sequential(
      (0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
      (1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), Linear(in_features=10, out_features=2, bias=True)]
    
    
    Printing Modules
    ------------------------------
    [myNet(
      (convBN1): Sequential(
        (0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
        (1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (linear): Linear(in_features=10, out_features=2, bias=True)
    ), Sequential(
      (0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
      (1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1)), BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), Linear(in_features=10, out_features=2, bias=True)]
    

    可以看到children只会返回子元素,子元素可能是单个操作,如Linear,也可能是Sequential。 而modules()返回的信息更加详细,不仅会返回children一样的信息,同时还会递归地返回,例如modules()会迭代地返回Sequential中包含的若干个子元素。

    named_*

    • named_parameters: 返回一个iterator,每次它会提供包含参数名的元组。
    In [27]: x = torch.nn.Linear(2,3)
    
    In [28]: x_name_params = x.named_parameters()
    
    In [29]: next(x_name_params)
    Out[29]:
    ('weight', Parameter containing:
     tensor([[-0.5262,  0.3480],
             [-0.6416, -0.1956],
             [ 0.5042,  0.6732]], requires_grad=True))
    
    In [30]: next(x_name_params)
    Out[30]:
    ('bias', Parameter containing:
     tensor([ 0.0595, -0.0386,  0.0975], requires_grad=True))
    
    • named_modules
      这个其实就是把上面提到的nn.modulesiterator的形式返回,每次读取和上面一样也是用next(),示例如下:
    In [46]:  class myNet(nn.Module):                                                          
        ...:    def __init__(self):                                                            
        ...:      super().__init__()                                                           
        ...:      self.convBN1 =  nn.Sequential(nn.Conv2d(10,10,3), nn.BatchNorm2d(10))        
        ...:      self.linear =  nn.Linear(10,2)                                               
        ...:                                                                                   
        ...:    def forward(self, x):                                                          
        ...:      pass                                                                         
        ...:                                                                                   
                                                                                               
    In [47]: net = myNet()                                                                     
                                                                                               
    In [48]: net_named_modules = net.named_modules()                                           
                                                                                               
    In [49]: next(net_named_modules)                                                           
    Out[49]:                                                                                   
    ('', myNet(                                                                                
       (convBN1): Sequential(                                                                  
         (0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))                                
         (1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)  
       )                                                                                       
       (linear): Linear(in_features=10, out_features=2, bias=True)                             
     ))                                                                                        
                                                                                               
    In [50]: next(net_named_modules)                                                           
    Out[50]:                                                                                   
    ('convBN1', Sequential(                                                                    
       (0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))                                  
       (1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)    
     ))                                                                                        
                                                                                               
    In [51]: next(net_named_modules)                                                           
    Out[51]: ('convBN1.0', Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1)))                  
                                                                                               
    In [52]: next(net_named_modules)                                                           
    Out[52]:                                                                                   
    ('convBN1.1',                                                                              
     BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))          
                                                                                               
    In [53]: next(net_named_modules)                                                           
    Out[53]: ('linear', Linear(in_features=10, out_features=2, bias=True))                     
                                                                                               
    In [54]: next(net_named_modules)                                                           
    ---------------------------------------------------------------------------                
    StopIteration                             Traceback (most recent call last)                
    <ipython-input-54-05e848b071b8> in <module>                                                
    ----> 1 next(net_named_modules)               
    StopIteration:                                                                             
    
    • named_children

    named_modules

    参考

    https://blog.paperspace.com/pytorch-101-advanced/

  • 相关阅读:
    死锁是什么?如何避免死锁?
    HTTP协议 (二) 基本认证
    HTTP协议
    Fiddler 教程
    Wireshark基本介绍和学习TCP三次握手
    洛谷.4512.[模板]多项式除法(NTT)
    洛谷.4238.[模板]多项式求逆(NTT)
    洛谷.3803.[模板]多项式乘法(NTT)
    UOJ.87.mx的仙人掌(圆方树 虚树)(未AC)
    BZOJ.3991.[SDOI2015]寻宝游戏(思路 set)
  • 原文地址:https://www.cnblogs.com/marsggbo/p/11512242.html
Copyright © 2011-2022 走看看