zoukankan      html  css  js  c++  java
  • Pytorch中的nn.Sequential

    A sequential container. Modules will be added to it in the order they are passed in the constructor. Alternatively, an ordered dict of modules can also be passed in.

    一个有序的容器,神经网络模块(module)将按照在传入构造器时的顺序依次被添加到计算图中执行,同时以神经网络模块为元素的有序字典(OrderedDict)也可以作为传入参数。

    # Example of using Sequential
            model = nn.Sequential(
                      nn.Conv2d(1,20,5),
                      nn.ReLU(),
                      nn.Conv2d(20,64,5),
                      nn.ReLU()
                    )
    
            # Example of using Sequential with OrderedDict
            model = nn.Sequential(OrderedDict([
                      ('conv1', nn.Conv2d(1,20,5)),
                      ('relu1', nn.ReLU()),
                      ('conv2', nn.Conv2d(20,64,5)),
                      ('relu2', nn.ReLU())
                    ]))

    接下来看一下Sequential源码,是如何实现的:
    https://pytorch.org/docs/stable/_modules/torch/nn/modules/container.html#Sequential
    先看一下初始化函数__init__,在初始化函数中,首先是if条件判断,如果传入的参数为1个,并且类型为OrderedDict,通过字典索引的方式将子模块添加到self._module中,否则,通过for循环遍历参数,将所有的子模块添加到self._module中。注意,Sequential模块的初始换函数没有异常处理,所以在写的时候要注意,注意,注意了

        def __init__(self, *args):
            super(Sequential, self).__init__()
            if len(args) == 1 and isinstance(args[0], OrderedDict):
                for key, module in args[0].items():
                    self.add_module(key, module)
            else:
                for idx, module in enumerate(args):
                    self.add_module(str(idx), module)

    接下来在看一下forward函数的实现:
    因为每一个module都继承于nn.Module,都会实现__call__forward函数,具体讲解点击这里,所以forward函数中通过for循环依次调用添加到self._module中的子模块,最后输出经过所有神经网络层的结果:

        def forward(self, input):
            for module in self:
                input = module(input)
            return input

    下面是简单的三层网络结构的例子:

    # hyper parameters
    in_dim=1
    n_hidden_1=1
    n_hidden_2=1
    out_dim=1
    
    class Net(nn.Module):
        def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
            super().__init__()
    
              self.layer = nn.Sequential(
                nn.Linear(in_dim, n_hidden_1), 
                nn.ReLU(True),
                nn.Linear(n_hidden_1, n_hidden_2),
                nn.ReLU(True),
                # 最后一层不需要添加激活函数
                nn.Linear(n_hidden_2, out_dim)
                 )
    
          def forward(self, x):
              x = self.layer(x)
              return x

    上面的代码就是通过Squential将网络层和激活函数结合起来,输出激活后的网络节点。


    原文链接:https://blog.csdn.net/dss_dssssd/java/article/details/82980222

  • 相关阅读:
    幸福婚姻的八个公式
    中国移动增值服务的现状及趋势
    移动通信与互联网融合已成为趋势
    项目管理入门
    4GMF论坛主席卢伟谈4G全球发展概况
    吴刚实践总结手机网游十大金科玉律
    项目管理五大过程组在通信工程中的运用实例
    报告称近半WAP用户低学历 学生工人上网最多
    手机杂志:成长的烦恼
    3G门户网
  • 原文地址:https://www.cnblogs.com/jiangkejie/p/13037242.html
Copyright © 2011-2022 走看看