zoukankan      html  css  js  c++  java
  • 『PyTorch』第九弹_前馈网络简化写法

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下

    在前面的例子中,基本上都是将每一层的输出直接作为下一层的输入,这种网络称为前馈传播网络(feedforward neural network)。对于此类网络如果每次都写复杂的forward函数会有些麻烦,在此就有两种简化方式,ModuleList和Sequential。其中Sequential是一个特殊的module,它包含几个子Module,前向传播时会将输入一层接一层的传递下去。ModuleList也是一个特殊的module,可以包含几个子module,可以像用list一样使用它,但不能直接把输入传给ModuleList。下面举例说明。

    一、nn.Sequential()对象

    nn.Sequential()对象是类似keras的前馈模型的对象,可以为之添加层实现前馈神经网络。

    1、模型建立方式

    第一种写法:

    nn.Sequential()对象.add_module(层名,层class的实例)

    net1 = nn.Sequential()
    net1.add_module('conv', nn.Conv2d(3, 3, 3))
    net1.add_module('batchnorm', nn.BatchNorm2d(3))
    net1.add_module('activation_layer', nn.ReLU())
    

     第二种写法:

    nn.Sequential(*多个层class的实例)

    net2 = nn.Sequential(
            nn.Conv2d(3, 3, 3),
            nn.BatchNorm2d(3),
            nn.ReLU()
            )
    

    第三种写法:

    nn.Sequential(OrderedDict([*多个(层名,层class的实例)]))

    from collections import OrderedDict
    net3= nn.Sequential(OrderedDict([
              ('conv', nn.Conv2d(3, 3, 3)),
              ('batchnorm', nn.BatchNorm2d(3)),
              ('activation_layer', nn.ReLU())
            ]))
    

    2、检查以及调用模型

    查看模型

    print对象即可

    print('net1:', net1)
    print('net2:', net2)
    print('net3:', net3)
    
    net1: Sequential(
      (conv): Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1))
      (batchnorm): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True)
      (activation_layer): ReLU()
    )
    net2: Sequential(
      (0): Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1))
      (1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True)
      (2): ReLU()
    )
    net3: Sequential(
      (conv): Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1))
      (batchnorm): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True)
      (activation_layer): ReLU()
    )

    提取子Module对象

    # 可根据名字或序号取出子module
    net1.conv, net2[0], net3.conv
    
    (Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1)),
     Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1)),
     Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1)))

    调用模型

    可以直接网络对象(输入数据),也可以使用上面的Module子对象分别传入(input)。

    input = V(t.rand(1, 3, 4, 4))
    output = net1(input)
    output = net2(input)
    output = net3(input)
    output = net3.activation_layer(net1.batchnorm(net1.conv(input)))
    

    二、nn.ModuleList()对象

    ModuleListModule的子类,当在Module中使用它的时候,就能自动识别为子module。

    建立以及使用方法如下,

    modellist = nn.ModuleList([nn.Linear(3,4), nn.ReLU(), nn.Linear(4,2)])
    input = V(t.randn(1, 3))
    for model in modellist:
        input = model(input)
    # 下面会报错,因为modellist没有实现forward方法
    # output = modelist(input)
    

    和普通list不一样,它和torch的其他机制结合紧密,继承了nn.Module的网络模型class可以使用nn.ModuleList并识别其中的parameters,当然这只是个list,不会自动实现forward方法,

    class MyModule(nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            self.list = [nn.Linear(3, 4), nn.ReLU()]
            self.module_list = nn.ModuleList([nn.Conv2d(3, 3, 3), nn.ReLU()])
        def forward(self):
            pass
    model = MyModule()
    print(model)
    
    MyModule(
      (module_list): ModuleList(
        (0): Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1))
        (1): ReLU()
      )
    )
    for name, param in model.named_parameters():
        print(name, param.size())
    
    ('module_list.0.weight', torch.Size([3, 3, 3, 3]))
    ('module_list.0.bias', torch.Size([3]))
    

    可见,list中的子module并不能被主module所识别,而ModuleList中的子module能够被主module所识别。这意味着如果用list保存子module,将无法调整其参数,因其未加入到主module的参数中。

    除ModuleList之外还有ParameterList,其是一个可以包含多个parameter的类list对象。在实际应用中,使用方式与ModuleList类似。如果在构造函数__init__中用到list、tuple、dict等对象时,一定要思考是否应该用ModuleList或ParameterList代替。

  • 相关阅读:
    百度Hi之CSRF蠕虫攻击
    Portlet之讲解
    try-catch语句讲解
    unset之讲解
    MySQL bin-log 日志清理方式
    php数组array_push()和array_pop()以及array_shift()函数
    php中的func_num_args、func_get_arg与func_get_args函数
    PHP is_callable 方法
    如何实现php异步处理
    Mysql并发时经典常见的死锁原因及解决方法
  • 原文地址:https://www.cnblogs.com/hellcat/p/8477195.html
Copyright © 2011-2022 走看看