zoukankan      html  css  js  c++  java
  • 0602-nn.Module

    0602-nn.Module

    pytorch完整教程目录:https://www.cnblogs.com/nickchen121/p/14662511.html

    一、nn.Module

    torch.nn 的核心数据结构就是 Module,它可以看做是某一层,也可以看做是整个神经网络。最常见的做法就是直接继承 nn.module,然后构建自己的网络模型结构。

    1.1 构建一层网络——全连接层

    接下来我们通过使用 nn.Module 实现一个全连接层(仿射层),输出 y 和输入 x 满足 (y=Wx+b),其中 w 和 b 是可学习参数。

    import torch as t
    from torch import nn
    from torch.autograd import Variable as V
    
    class Linear(nn.Module):
        def __init__(self, in_features, out_features):  # 输入的数据维度,输出的数据维度
            super(Linear,
                  self).__init__()  # 等价于 nn.Module.__init__(self),继承父类的init构造函数
            self.w = nn.Parameter(t.randn(in_features, out_features))
            self.b = nn.Parameter(t.randn(out_features))
    
        def forward(self, x):
            x = x.mm(self.w)
            return x + self.b.expand_as(x)
    
    layer = Linear(4, 3)
    input = V(t.randn(2, 4))
    output = layer(input)  # y = Wx + b 的形状是(2,3) = (2,4)*(4*3)+(1,3).expanda_as(x)
    output
    
    tensor([[ 1.1407, -0.1323,  0.3659],
            [ 2.4265, -1.2330, -0.9984]], grad_fn=<AddBackward0>)
    
    for name, parameter in layer.named_parameters():
        print(name, parameter)
    
    w Parameter containing:
    tensor([[-1.3990, -1.9669, -0.0430],
            [ 0.8150,  0.8829, -1.0932],
            [-0.3793,  0.2708,  0.9691],
            [-0.9613, -0.3259,  0.5103]], requires_grad=True)
    b Parameter containing:
    tensor([ 0.9333, -0.7481, -0.6074], requires_grad=True)
    

    从上述代码可以看出实现一个全连接层非常简单,但是需要注意以下几点:

    • 自定义 Linear 必须要继承 nn.Module,并且自定义类的构造函数需要继承 nn.Module 的构造函数
    • 在构造函数中必须自己定义可学习的参数,并且要封装为 Parameter,上述代码则是把 w 和 b 封装成 Parameter,并且可以发现 Parameter 这种数据结构默认 requires_grad=True
    • forward 函数的作用是实现前向传播过程,其输入可以是一个或多个 variable,对 x 的任何操作也必须是 variable 支持的操作
    • 不需要自己写一个反向传播函数,因为它的前向传播都是对 variable 进行操作,nn.Module 能够利用 autograd 自动进行反向传播
    • 调用 layer(input) 时就能得到 input 的结果,其实它的内部是做了 layer.__call__(input) 操作,在 call 函数中,主要调用了 layer.forward(x),另外还对钩子做了一定的处理,因此直接使用 layer(x),而不是使用 layer.forward(x),钩子的具体内容会在接下来讲解。对于 __call__的作用,可以参考这篇文章:详解__call__

    1.2 构建多层网络——多层感知机

    上述只是实现了一个一层网络结构的模型,下面我们通过更复杂的网络——多层感知机,来感受下 Module 的模块真正强大的地方。多层感知机的网络结构如下图所示:

    从多层感知机的网络结构,我们可以看出它由两个全连接层组成,并且它采用 sigimoid 函数作为激活函数。

    class Perceptron(nn.Module):
        def __init__(self, in_features, hidden_features, out_features):
            nn.Module.__init__(self)
            self.layer1 = Linear(in_features,
                                 hidden_features)  # 此处的 Linear 是前面定义的全连接层
            self.layer2 = Linear(hidden_features, out_features)
    
        def forward(self, x):
            x = self.layer1(x)
            x = t.sigimoid(x)
            return self.layer2(x)
    
    
    perceptron = Perceptron(3, 4, 1)
    for name, param in perceptron.named_parameters():
        print(name, param.size())
    
    layer1.w torch.Size([3, 4])
    layer1.b torch.Size([4])
    layer2.w torch.Size([4, 1])
    layer2.b torch.Size([1])
    

    从上述代码中,可以看出多层感知机也非常容易,但是也要注意以下两点:

    • 构造函数中,可以利用前面自定义的 Linear 层作为当前 module 对象的一个子 module,并且它的可学习参数也会称为当前 module 的可学习参数,也就是说主 module 可以递归查找子 module 中的 parameter
    • 在前向传播过程中,我们将输出变量都命名为 x,是为了让 Python 回收一些中间层的输出,从而节省内存,但是有些 variable 虽然名字被覆盖,但是由于它在反向传播过程中仍然需要用到,此时 Python 不会回收这部分数据

    对于 parameter的命名有如下规范:

    • 如果没有子模块,parameter 直接命名。例如 self.param_name = nn.Parameter(t.randn(3,4)),则会命名称为 param_name
    • 对于子模块的 parameter,会在它的名字前面加上当前 module 的名字。例如 self.sub_module = SubModel(),SubModel 中也有个名字叫做 param_name 的 parameter,则它的实际名字为 sub_module.param_name

    虽然我们自己定义神经网络的层(layer)看起来不是特别费力,但是 torch 为了让用书使用起来更方便,它对绝大多数的 layer 都做了封装,此处不做延伸,有兴趣的可以去参照官方文档,或者参考这一篇文章:0802_转载-nn模块中的网络层介绍

    阅读上述介绍的文章时,需要注意下面三点:

    • 构造函数的参数,如 nn.Linear(in_features, out_features, bias),需要关注这三个参数的作用
    • 属性、可学习参数和子 module。例如 nn.Linear 中有 weight 和 bias 两个可学习参数,不包含子 module
    • 输入输出的形状,如 nn.linear 的输入形状是 (N,input_features),输出是 (N, output_features),其中 N 是 batch_size

    注:这些自定义的 layer 对输入性状都有一定的假设:输入的不是一个数据,而是一个 batch。如果想要输入一个数据,必须调用 unsqueeze(0) 函数将数据伪装成 batch_size=1 的batch

  • 相关阅读:
    python02
    使用tableau去将存入mysql都地区点击率进行了展示 感觉很好用
    java使用ssh远程操作linux 提交spark jar
    java操作linux 提交spark jar
    spark与kafka集成进行实时 nginx代理 这种sdk埋点 原生日志实时解析 处理
    github开源的一些ip解析 ,运营商信息,经纬度 地址 后续开发使用
    Oracle 并行执行SQL
    Oracle 序列
    Oracle dblink创建
    Oracle Job维护
  • 原文地址:https://www.cnblogs.com/nickchen121/p/14697548.html
Copyright © 2011-2022 走看看