
    Understanding the difference and connection between torch.nn and torch.nn.functional through the various implementations of relu

    The relationship among the various relu implementations

    The relu function appears in PyTorch in three places:

    1. torch.nn.ReLU()
    2. torch.nn.functional.relu() and torch.nn.functional.relu_()
    3. torch.relu() and torch.relu_()

    These three implementations stand in a fixed wrapping relationship: going from top to bottom moves from the surface toward the core.

    The last pair is not covered by the official PyTorch documentation, and no corresponding Python code can be found for it; it appears only in __init__.pyi, because these functions come from the THNN library, which is written in C++.
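    Before digging into the source, here is a minimal sanity-check sketch showing that all three layers of this stack compute the same thing:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    x = torch.randn(4)

    out_module = nn.ReLU()(x)    # class wrapper in torch.nn
    out_functional = F.relu(x)   # functional wrapper in torch.nn.functional
    out_builtin = torch.relu(x)  # entry point backed by C++

    assert torch.equal(out_module, out_functional)
    assert torch.equal(out_functional, out_builtin)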

    Let us now analyze the source code in detail:

    1. torch.nn.ReLU()
      Classes in torch.nn represent neural-network layers. Here, ReLU(), which appears as a class, in fact just calls the relu / relu_ implementations in torch.nn.functional.
    class ReLU(Module):
        r"""Applies the rectified linear unit function element-wise:
    
        :math:`\text{ReLU}(x) = \max(0, x)`
    
        Args:
            inplace: can optionally do the operation in-place. Default: ``False``
    
        Shape:
            - Input: :math:`(N, *)` where `*` means, any number of additional
              dimensions
            - Output: :math:`(N, *)`, same shape as the input
    
        .. image:: scripts/activation_images/ReLU.png
    
        Examples::
    
            >>> m = nn.ReLU()
            >>> input = torch.randn(2)
            >>> output = m(input)
    
    
          An implementation of CReLU - https://arxiv.org/abs/1603.05201
    
            >>> m = nn.ReLU()
            >>> input = torch.randn(2).unsqueeze(0)
            >>> output = torch.cat((m(input),m(-input)))
        """
        __constants__ = ['inplace']
    
        def __init__(self, inplace=False):
            super(ReLU, self).__init__()
            self.inplace = inplace
    
        @weak_script_method
        def forward(self, input):
            # F is torch.nn.functional, imported in this module as F
            return F.relu(input, inplace=self.inplace)
    
        def extra_repr(self):
            inplace_str = 'inplace' if self.inplace else ''
            return inplace_str
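    Why have a class wrapper at all? Because a Module composes: it can be dropped into containers next to layers that do hold parameters. A minimal sketch:

    import torch.nn as nn

    # nn.ReLU() holds no parameters of its own, but as a Module it
    # slots into containers such as nn.Sequential.
    model = nn.Sequential(
        nn.Linear(10, 20),
        nn.ReLU(),
        nn.Linear(20, 1),
    )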
    
    2. torch.nn.functional.relu() and torch.nn.functional.relu_()
      These two functions in turn call torch.relu() and torch.relu_().
    def relu(input, inplace=False):
        # type: (Tensor, bool) -> Tensor
        r"""relu(input, inplace=False) -> Tensor
    
        Applies the rectified linear unit function element-wise. See
        :class:`~torch.nn.ReLU` for more details.
        """
        if inplace:
            result = torch.relu_(input)
        else:
            result = torch.relu(input)
        return result
    
    
    relu_ = _add_docstr(torch.relu_, r"""
    relu_(input) -> Tensor
    
    In-place version of :func:`~relu`.
    """)
    

    At this point we have a thorough picture of where the relu function lives in torch. Fundamentally, the relationship between the two base packages, torch.nn and torch.nn.functional, is one of reference and wrapping.

    The difference and connection between torch.nn and torch.nn.functional

    With the above analysis of relu in hand, we can see the connection between the two packages more clearly.

    Generally speaking, torch.nn.functional calls into the THNN library and implements the core computation, but it does not manage learnable parameters such as weight and bias, which makes it inconvenient for building models. The modules implemented in torch.nn wrap torch.nn.functional; in essence they are the official usage templates for torch.nn.functional. Calling these templates directly lets us use PyTorch quickly and conveniently, but no set of templates can cover every user's needs, so torch.nn.functional is kept available to give such users the flexibility to assemble the models they need themselves. This is how PyTorch strikes a balance between flexibility and ease of use.
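    To make the trade-off concrete, here is a minimal sketch contrasting the two styles (shapes chosen arbitrarily):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    x = torch.randn(128, 20)

    # Module style: nn.Linear creates and tracks weight and bias for us.
    layer = nn.Linear(20, 30)
    y1 = layer(x)

    # Functional style: we create and manage the parameters ourselves
    # and hand them to the stateless computation on every call.
    weight = nn.Parameter(torch.randn(30, 20))
    bias = nn.Parameter(torch.zeros(30))
    y2 = F.linear(x, weight, bias)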

    Note in particular that torch.nn is not entirely a set of templates over torch.nn.functional; some modules call functions from other libraries. For example, the commonly used RNN family of networks does not appear in torch.nn.functional at all.

    With this in mind, let us close with one more example:

    For Linear, pay attention ⚠️ to how the implementations in the two packages differ in:

    1. management of learnable parameters
    2. how they call each other
    3. the initialization process
    class Linear(Module):
        r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
    
        Args:
            in_features: size of each input sample
            out_features: size of each output sample
            bias: If set to ``False``, the layer will not learn an additive bias.
                Default: ``True``
    
        Shape:
        - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of
          additional dimensions and :math:`H_{in} = \text{in\_features}`
        - Output: :math:`(N, *, H_{out})` where all but the last dimension
          are the same shape as the input and :math:`H_{out} = \text{out\_features}`.
    
        Attributes:
            weight: the learnable weights of the module of shape
                :math:`(\text{out\_features}, \text{in\_features})`. The values are
                initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
                :math:`k = \frac{1}{\text{in\_features}}`
            bias:   the learnable bias of the module of shape :math:`(\text{out\_features})`.
                    If :attr:`bias` is ``True``, the values are initialized from
                    :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                    :math:`k = \frac{1}{\text{in\_features}}`
    
        Examples::
    
            >>> m = nn.Linear(20, 30)
            >>> input = torch.randn(128, 20)
            >>> output = m(input)
            >>> print(output.size())
            torch.Size([128, 30])
        """
        __constants__ = ['bias']
    
        def __init__(self, in_features, out_features, bias=True):
            super(Linear, self).__init__()
            self.in_features = in_features
            self.out_features = out_features
            self.weight = Parameter(torch.Tensor(out_features, in_features))
            if bias:
                self.bias = Parameter(torch.Tensor(out_features))
            else:
                self.register_parameter('bias', None)
            self.reset_parameters()
    
        def reset_parameters(self):
            init.kaiming_uniform_(self.weight, a=math.sqrt(5))
            if self.bias is not None:
                fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
                bound = 1 / math.sqrt(fan_in)
                init.uniform_(self.bias, -bound, bound)
    
        @weak_script_method
        def forward(self, input):
            return F.linear(input, self.weight, self.bias)
    
        def extra_repr(self):
            return 'in_features={}, out_features={}, bias={}'.format(
                self.in_features, self.out_features, self.bias is not None
            )
    
    def linear(input, weight, bias=None):
        # type: (Tensor, Tensor, Optional[Tensor]) -> Tensor
        r"""
        Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.
    
        Shape:
    
            - Input: :math:`(N, *, in\_features)` where `*` means any number of
              additional dimensions
            - Weight: :math:`(out\_features, in\_features)`
            - Bias: :math:`(out\_features)`
            - Output: :math:`(N, *, out\_features)`
        """
        if input.dim() == 2 and bias is not None:
            # fused op is marginally faster
            ret = torch.addmm(bias, input, weight.t())
        else:
            output = input.matmul(weight.t())
            if bias is not None:
                output += bias
            ret = output
        return ret
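    The division of labor is easy to verify: the Module owns the parameters, and its forward is only a thin wrapper around the functional call. A minimal sketch:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    layer = nn.Linear(20, 30)
    print([name for name, _ in layer.named_parameters()])  # ['weight', 'bias']

    x = torch.randn(5, 20)
    # Calling the module is exactly the functional call with its parameters.
    assert torch.equal(layer(x), F.linear(x, layer.weight, layer.bias))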
    
  • Original article: https://www.cnblogs.com/gjl-blog/p/11376673.html