zoukankan      html  css  js  c++  java
  • Pytorch——torch.nn.Sequential()详解

    参考:官方文档    源码

    官方文档

    nn.Sequential

      A sequential container. Modules will be added to it in the order they are passed in the constructor. Alternatively, an ordered dict of modules can also be passed in.

      翻译:一个有序的容器,神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行,同时以神经网络模块为元素的有序字典也可以作为传入参数。

    使用方法一:作为一个有顺序的容器

      作为一个有顺序的容器,将特定神经网络模块按照在传入构造器的顺序依次被添加到计算图中执行。

    官方 Example:

    model = nn.Sequential(
              nn.Conv2d(1,20,5),
              nn.ReLU(),
              nn.Conv2d(20,64,5),
              nn.ReLU()
            )
    # When `model` is run,input will first be passed to `Conv2d(1,20,5)`.
    # The output of `Conv2d(1,20,5)` will be used as the input to the first `ReLU`;
    # the output of the first `ReLU` will become the input for `Conv2d(20,64,5)`.
    # Finally, the output of `Conv2d(20,64,5)` will be used as input to the second `ReLU`

    例子:

    net = nn.Sequential(
        nn.Linear(num_inputs, num_hidden)
        # 传入其他层
        )

    使用方法二:作为一个有序字典

      将以特定神经网络模块为元素的有序字典(OrderedDict)为参数传入。

    官方 Example :

    model = nn.Sequential(OrderedDict([
              ('conv1', nn.Conv2d(1,20,5)),
              ('relu1', nn.ReLU()),
              ('conv2', nn.Conv2d(20,64,5)),
              ('relu2', nn.ReLU())
            ]))

    例子:

    net = nn.Sequential()
    net.add_module('linear1', nn.Linear(num_inputs, num_hiddens))
    net.add_module('linear2', nn.Linear(num_hiddens, num_ouputs))
    # net.add_module ......

    源码分析

    初始化函数 init 

        def __init__(self, *args):
            super(Sequential, self).__init__()
            if len(args) == 1 and isinstance(args[0], OrderedDict):
                for key, module in args[0].items():
                    self.add_module(key, module)
            else:
                for idx, module in enumerate(args):
                    self.add_module(str(idx), module)

      __init__ 首先使用 if 条件进行判断,若传入的参数为 1 个,且类型为 OrderedDict,则通过字典索引的方式利用 add_module 函数 将子模块添加到现有模块中。否则,通过 for 循环遍历参数,将所有的子模块添加到现有模块中。 这里需要注意,Sequential 模块的初始换函数没有异常处理。

    forward 函数

        def forward(self, input):
            for module in self:
                input = module(input)
            return input

      因为每一个 module 都继承于 nn.Module,都会实现 __call__ 与 forward 函数,所以 forward 函数中通过 for 循环依次调用添加到 self._module 中的子模块,最后输出经过所有神经网络层的结果。

    因上求缘,果上努力~~~~ 作者:每天卷学习,转载请注明原文链接:https://www.cnblogs.com/BlairGrowing/p/15427842.html

  • 相关阅读:
    数据库分区与分表
    Paxos算法简单介绍
    Zookeeper实现分布式锁服务(Chubby)
    java.lang.OutOfMemoryError: Java heap space错误及处理办
    关于分布式事务、两阶段提交协议、三阶提交协议
    Volatile
    寻找数组中只出现一次的数
    堆排序
    二叉树遍历 递归非递归
    redis 数据类型
  • 原文地址:https://www.cnblogs.com/BlairGrowing/p/15427842.html
Copyright © 2011-2022 走看看