zoukankan      html  css  js  c++  java
  • pytorch中的add_module函数

    现只讲在自定义网络中add_module的作用。

    总结:

    在自定义网络的时候,由于自定义变量不是Module类型(例如,我们用List封装了几个网络),所以pytorch不会自动注册网络模块add_module函数用来为网络添加模块的,所以我们可以使用这个函数手动添加自定义的网络模块。当然,这种情况,我们也可以使用ModuleList来封装自定义模块,pytorch就会自动注册了。

    Let't start!

    add_module函数是在自定义网络添加子模块,例如,当我们自定义一个网络肤过程中,我们既可以

    (1)通过self.module=xxx_module的方式(如下面第3行代码),添加网络模块;

    (2)通过add_module函数对网络中添加模块。

    (3)通过用nn.Sequential对模块进行封装等等。

     1 class NeuralNetwork(nn.Module):
     2     def __init__(self):
     3         super(NeuralNetwork, self).__init__()
     4         self.layers = nn.Linear(28*28,28*28)
     5 #         self.add_module('layers',nn.Linear(28*28,28*28))  #  跟上面的方式等价
     6         self.linear_relu_stack = nn.Sequential(
     7             nn.Linear(28*28, 512),
     8             nn.ReLU()
     9         )
    10 
    11     def forward(self, x):
    12         for layer in layers:
    13             x = layer(x)
    14         logits = self.linear_relu_stack(x)
    15         return logits

    我们实例化类,然后输出网络的模块看一下:

    1 0 Linear(in_features=784, out_features=784, bias=True)
    2 1 Sequential(
    3   (0): Linear(in_features=784, out_features=512, bias=True)
    4   (1): ReLU()
    5 )

    会发现,上面定义的网络子模块都有:Linear和Sequential。

    但是,有时候pytorch不会自动给我们注册模块,我们需要根据传进来的参数对网络进行初始化,例如:

     1 class NeuralNetwork(nn.Module):
     2     def __init__(self, layer_num):
     3         super(NeuralNetwork, self).__init__()
     4         self.layers = [nn.Linear(28*28,28*28) for _ in range(layer_num)]
     5         self.linear_relu_stack = nn.Sequential(
     6             nn.Linear(28*28, 512),
     7             nn.ReLU()
     8         )
     9 
    10     def forward(self, x):
    11         for layer in layers:
    12             x = layer(x)
    13         logits = self.linear_relu_stack(x)
    14         return logits

    对此我们再初始化一个实例,然后看下网络中的模块:

    1 model = NeuralNetwork(2)
    2 for index,item in enumerate(model.children()):
    3     print(index,item)

    输出结果就是:

    0 Sequential(
      (0): Linear(in_features=784, out_features=512, bias=True)
      (1): ReLU()
    ) 

    你会发现定义的Linear模块都不见了,而上面定义的时候,明明都制订了。这是因为pytorch在注册模块的时候,会查看成员的类型,如果成员变量类型是Module的子类,那么pytorch就会注册这个模块,否则就不会。

    这里的self.layers是python中的List类型,所以不会自动注册,那么就需要我们再定义后,手动注册(下图黄色标注部分):

     1 class NeuralNetwork(nn.Module):
     2     def __init__(self, layer_num):
     3         super(NeuralNetwork, self).__init__()
     4         self.layers = [nn.Linear(28*28,28*28) for _ in range(layer_num)]
     5         for i,layer in enumerate(self.layers):
     6             self.add_module('layer_{}'.format(i),layer)
     7         self.linear_relu_stack = nn.Sequential(
     8             nn.Linear(28*28, 512),
     9             nn.ReLU()
    10         )
    11 
    12     def forward(self, x):
    13         for layer in layers:
    14             x = layer(x)
    15         logits = self.linear_relu_stack(x)
    16         return logits

    这样我们再输出模型的子模块的时候,就会得到:

    model = NeuralNetwork(4)
    for index,item in enumerate(model.children()):
        print(index,item)
    
    # output
    #0 Linear(in_features=784, out_features=784, bias=True)
    #1 Linear(in_features=784, out_features=784, bias=True)
    #2 Linear(in_features=784, out_features=784, bias=True)
    #3 Linear(in_features=784, out_features=784, bias=True)
    #4 Sequential(
    #  (0): Linear(in_features=784, out_features=512, bias=True)
    #  (1): ReLU()
    #)

    就会看到,已经有了自己注册的模块。

    当然,也可能觉得这种方式比较麻烦,每次都要自己注册下,那能不能有一个类似List的类,在定义的时候就封装一下呢? 

    可以,使用nn.ModuleList封装一下即可达到相同的效果。

    class NeuralNetwork(nn.Module):
        def __init__(self, layer_num):
            super(NeuralNetwork, self).__init__()
            self.layers = nn.ModuleList([nn.Linear(28*28,28*28) for _ in range(layer_num)])
            self.linear_relu_stack = nn.Sequential(
                nn.Linear(28*28, 512),
                nn.ReLU()
            )
    
        def forward(self, x):
            for layer in layers:
                x = layer(x)
            logits = self.linear_relu_stack(x)
            return logits

    参考:
    1. 博客THE PYTORCH ADD_MODULE() FUNCTION link
    2. pytorch 官方文档 中文链接 English version

    如果你喜欢的话...

    如果读完我写的笔记有疑问或者想法,欢迎留下您的评论,我们一起交流、共同讨论、相互学习。如果这篇笔记让您有收获,愿您不吝打赏,您的鼓励是对我最大的肯定,也督促我记录更多质量更好的笔记。

    打赏码
  • 相关阅读:
    内置函数——filter和map
    递归函数
    内置函数和匿名函数
    迭代器和生成器
    装饰器函数
    函数进阶
    COGS 2533. [HZOI 2016]小鱼之美
    COGS 2532. [HZOI 2016]树之美 树形dp
    COGS2531. [HZOI 2016]函数的美 打表+欧拉函数
    bzoj1303: [CQOI2009]中位数图
  • 原文地址:https://www.cnblogs.com/datasnail/p/14903643.html
Copyright © 2011-2022 走看看