zoukankan      html  css  js  c++  java
  • nn.moduleList 和Sequential由来、用法和实例 —— 写网络模型

    对于cnn前馈神经网络如果前馈一次写一个forward函数会有些麻烦,在此就有两种简化方式,ModuleList和Sequential。其中Sequential是一个特殊的module,它包含几个子Module,前向传播时会将输入一层接一层的传递下去。ModuleList也是一个特殊的module,可以包含几个子module,可以像用list一样使用它,但不能直接把输入传给ModuleList。下面举例说明。

    目录

    一、nn.Sequential()对象

    1、模型建立方式

    第一种写法:

     第二种写法:

    第三种写法:

    2、检查以及调用模型

    查看模型

    根据名字或序号提取子Module对象

    调用模型

    二、nn.ModuleList()对象

    为什么有他?

    什么时候用?

    和list的区别?

    1. extend和append方法

    2. 建立以及使用方法

    3. yolo v3构建网络

    一、nn.Sequential()对象
    建立nn.Sequential()对象,必须小心确保一个块的输出大小与下一个块的输入大小匹配。基本上,它的行为就像一个nn.Module。

    1、模型建立方式

    第一种写法:
    nn.Sequential()对象.add_module(层名,层class的实例)

    1

    2

    3

    4

    net1 = nn.Sequential()

    net1.add_module('conv', nn.Conv2d(3, 3, 3))

    net1.add_module('batchnorm', nn.BatchNorm2d(3))

    net1.add_module('activation_layer', nn.ReLU())


     第二种写法:
    nn.Sequential(*多个层class的实例)

    1

    2

    3

    4

    5

    net2 = nn.Sequential(

            nn.Conv2d(3, 3, 3),

            nn.BatchNorm2d(3),

            nn.ReLU()

            )


    第三种写法:
    nn.Sequential(OrderedDict([*多个(层名,层class的实例)]))

    1

    2

    3

    4

    5

    6

    from collections import OrderedDict

    net3= nn.Sequential(OrderedDict([

              ('conv', nn.Conv2d(3, 3, 3)),

              ('batchnorm', nn.BatchNorm2d(3)),

              ('activation_layer', nn.ReLU())

            ]))

    2、检查以及调用模型

    查看模型
    print对象即可

    1

    2

    3

    print('net1:', net1)

    print('net2:', net2)

    print('net3:', net3)

    net1: Sequential(
    (conv): Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1))
    (batchnorm): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True)
    (activation_layer): ReLU()
    )
    net2: Sequential(
    (0): Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True)
    (2): ReLU()
    )
    net3: Sequential(
    (conv): Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1))
    (batchnorm): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True)
    (activation_layer): ReLU()
    )

    根据名字或序号提取子Module对象
    1

    2

    # 可根据名字或序号取出子module

    net1.conv, net2[0], net3.conv

    (Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1)),
    Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1)),
    Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1)))

    调用模型
    可以直接网络对象(输入数据),也可以使用上面的Module子对象分别传入(input)。

    1

    2

    3

    4

    5

    input = V(t.rand(1, 3, 4, 4))

    output = net1(input)

    output = net2(input)

    output = net3(input)

    output = net3.activation_layer(net1.batchnorm(net1.conv(input)))

    二、nn.ModuleList()对象
    为什么有他?
    写一个module然后就写foreword函数很麻烦,所以就有了这两个。它被设计用来存储任意数量的nn. module。

    什么时候用?
    如果在构造函数__init__中用到list、tuple、dict等对象时,一定要思考是否应该用ModuleList或ParameterList代替。

    如果你想设计一个神经网络的层数作为输入传递。

    和list的区别?
    ModuleList是Module的子类,当在Module中使用它的时候,就能自动识别为子module。

    当添加 nn.ModuleList 作为 nn.Module 对象的一个成员时(即当我们添加模块到我们的网络时),所有 nn.ModuleList 内部的 nn.Module 的 parameter 也被添加作为 我们的网络的 parameter。

    class MyModule(nn.Module):
    def __init__(self):
    super(MyModule, self).__init__()
    self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

    def forward(self, x):
    # ModuleList can act as an iterable, or be indexed using ints
    for i, l in enumerate(self.linears):
    x = self.linears[i // 2](x) + l(x)
    return x
    1. extend和append方法
    nn.moduleList定义对象后,有extend和append方法,用法和python中一样,extend是添加另一个modulelist  append是添加另一个module

    class LinearNet(nn.Module):
    def __init__(self, input_size, num_layers, layers_size, output_size):
    super(LinearNet, self).__init__()

    self.linears = nn.ModuleList([nn.Linear(input_size, layers_size)])
    self.linears.extend([nn.Linear(layers_size, layers_size) for i in range(1, self.num_layers-1)])
    self.linears.append(nn.Linear(layers_size, output_size)
    2. 建立以及使用方法
    建立以及使用方法如下,

    1

    2

    3

    4

    5

    6

    modellist = nn.ModuleList([nn.Linear(3,4), nn.ReLU(), nn.Linear(4,2)])

    input = V(t.randn(1, 3))

    for model in modellist:

        input = model(input)

    # 下面会报错,因为modellist没有实现forward方法

    # output = modelist(input)

    和普通list不一样,它和torch的其他机制结合紧密,继承了nn.Module的网络模型class可以使用nn.ModuleList并识别其中的parameters,当然这只是个list,不会自动实现forward方法,

    1

    2

    3

    4

    5

    6

    7

    8

    9

    class MyModule(nn.Module):

        def __init__(self):

            super(MyModule, self).__init__()

            self.list = [nn.Linear(3, 4), nn.ReLU()]

            self.module_list = nn.ModuleList([nn.Conv2d(3, 3, 3), nn.ReLU()])

        def forward(self):

            pass

    model = MyModule()

    print(model)

    MyModule(
    (module_list): ModuleList(
    (0): Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    )
    )
    1

    2

    for name, param in model.named_parameters():

        print(name, param.size())

    ('module_list.0.weight', torch.Size([3, 3, 3, 3]))
    ('module_list.0.bias', torch.Size([3]))
    可见,普通list中的子module并不能被主module所识别,而ModuleList中的子module能够被主module所识别。这意味着如果用list保存子module,将无法调整其参数,因其未加入到主module的参数中。

    除ModuleList之外还有ParameterList,其是一个可以包含多个parameter的类list对象。在实际应用中,使用方式与ModuleList类似。

    3. yolo v3构建网络
    首先module_list = nn.ModuleList()

    然后 

      for index, x in enumerate(blocks[1:]):#根据不同的block 遍历module

            module = nn.Sequential()

            然后根据cfg读进来的数据,

            module.add_module("batch_norm_{0}".format(index), bn)

            module.add_module("conv_{0}".format(index), conv)

             等等

             module_list.append(module)

     
    ---------------------
    作者:Snoopy_Dream
    来源:CSDN
    原文:https://blog.csdn.net/e01528/article/details/84397174
    版权声明:本文为博主原创文章,转载请附上博文链接!

  • 相关阅读:
    openSUSE 13.1 Milestone 4 发布
    Neo4j 2.0 M4 发布
    iBoxDB for .NET v1.5发布, 移动NoSQL数据库
    GNU libc (Glibc) 2.18 发布
    Android 开源项目维护者宣布退出
    Jeasyframe 开源框架 稳定版 V1.5 发布
    Spring Mobile 1.1.0.RC1 和 1.0.2 发布
    Deis logo 开源PaaS系统 Deis
    EasyCriteria 3.0 发布
    TypeScript 0.9.1 发布,新增 typeof 关键字
  • 原文地址:https://www.cnblogs.com/jfdwd/p/11269588.html
Copyright © 2011-2022 走看看