  • Implementation of Position Embedding in the PyTorch Transformer

    The positional encoding in the Transformer is a special component: it is not a learnable layer of the network. It is added to the word embeddings after the embedding layer, and its values come from a fixed calculation, so it contributes no parameters to the module. (It is registered as a buffer, though, and with the default settings a buffer is still saved in the module's `state_dict`.)

    Positional Encoding

    In the paper, the positional encodings are added to the input embeddings at the bottoms of the encoder and decoder stacks. In PyTorch, the fixed table is stored with a special `nn.Module` method, `register_buffer`. In the positional encoding module's `__init__`, we first call:

    self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))
    

    The source code of `register_buffer()` (from the PyTorch release of the time) is:

    def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None:
            r"""Adds a buffer to the module.
    
            This is typically used to register a buffer that should not to be
            considered a model parameter. For example, BatchNorm's ``running_mean``
            is not a parameter, but is part of the module's state. Buffers, by
            default, are persistent and will be saved alongside parameters. This
            behavior can be changed by setting :attr:`persistent` to ``False``. The
            only difference between a persistent buffer and a non-persistent buffer
            is that the latter will not be a part of this module's
            :attr:`state_dict`.
    
            Buffers can be accessed as attributes using given names.
    
            Args:
                name (string): name of the buffer. The buffer can be accessed
                    from this module using the given name
                tensor (Tensor): buffer to be registered.
                persistent (bool): whether the buffer is part of this module's
                    :attr:`state_dict`.
    
            Example::
    
                >>> self.register_buffer('running_mean', torch.zeros(num_features))
    
            """
            if persistent is False and isinstance(self, torch.jit.ScriptModule):
                raise RuntimeError("ScriptModule does not support non-persistent buffers")
    
            if '_buffers' not in self.__dict__:
                raise AttributeError(
                    "cannot assign buffer before Module.__init__() call")
            elif not isinstance(name, torch._six.string_classes):
                raise TypeError("buffer name should be a string. "
                                "Got {}".format(torch.typename(name)))
            elif '.' in name:
                raise KeyError("buffer name can't contain \".\"")
            elif name == '':
                raise KeyError("buffer name can't be empty string \"\"")
            elif hasattr(self, name) and name not in self._buffers:
                raise KeyError("attribute '{}' already exists".format(name))
            elif tensor is not None and not isinstance(tensor, torch.Tensor):
                raise TypeError("cannot assign '{}' object to buffer '{}' "
                                "(torch Tensor or None required)"
                                .format(torch.typename(tensor), name))
            else:
                self._buffers[name] = tensor
                if persistent:
                    self._non_persistent_buffers_set.discard(name)
                else:
                    self._non_persistent_buffers_set.add(name)
    

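    `register_buffer` is a method of `nn.Module`. A buffer is module state that is not a parameter: it is returned by `buffers()` but not by `parameters()`, so the optimizer never updates it, and it is moved together with the module by `.to(device)`. Whether it is saved depends on `persistent`: a buffer registered with `persistent=False` has its name added to `_non_persistent_buffers_set` and is left out of `state_dict`. The call above uses the default `persistent=True`, so `pos_table` is in fact saved by `torch.save(model.state_dict(), ...)` alongside the parameters; pass `persistent=False` to keep it out of the checkpoint and recompute it at construction time instead.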
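    A minimal sketch for checking which names end up in `state_dict`, using a hypothetical toy module `BufferDemo` (not part of the original code) with one parameter, one default buffer and one `persistent=False` buffer. It assumes PyTorch 1.6 or newer, where the `persistent` argument exists:

```python
import torch
import torch.nn as nn


class BufferDemo(nn.Module):
    """Hypothetical toy module: one parameter, one persistent and one non-persistent buffer."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(3))
        # default persistent=True: saved in state_dict, but not a parameter
        self.register_buffer('pos_table', torch.zeros(1, 4, 3))
        # persistent=False: the name goes into _non_persistent_buffers_set
        self.register_buffer('scratch', torch.zeros(3), persistent=False)


m = BufferDemo()
print([name for name, _ in m.named_parameters()])  # ['weight']
print(list(m.state_dict().keys()))                  # ['weight', 'pos_table']
print([name for name, _ in m.named_buffers()])      # ['pos_table', 'scratch']
```

    Both buffers are module state, but only the persistent one is written to a checkpoint.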

    Calculation

    The paper uses sine and cosine functions of different frequencies:

    \[ \begin{aligned} PE_{(pos,2i)} &= \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right) \\ PE_{(pos,2i+1)} &= \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right) \end{aligned} \]
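    As a small worked example (the helper `pe` and the choice d_model = 4 are only for illustration), both dimensions of a pair (2i, 2i+1) share the same angle. For pos = 1 and d_model = 4 the row is sin(1), cos(1), sin(0.01), cos(0.01), because 10000^(2/4) = 100; for pos = 0 every row is 0, 1, 0, 1, ...:

```python
import math


def pe(pos: int, k: int, d_model: int) -> float:
    """Hypothetical helper: PE value at position `pos`, dimension k (k = 2i or 2i + 1)."""
    i = k // 2  # both dimensions of a pair share the same i
    angle = pos / (10000 ** (2 * i / d_model))
    return math.sin(angle) if k % 2 == 0 else math.cos(angle)


d_model = 4
row_pos1 = [pe(1, k, d_model) for k in range(d_model)]
# equals [sin(1), cos(1), sin(0.01), cos(0.01)] since 10000 ** (2 / 4) == 100
print(row_pos1)
```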

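    Here $pos$ is the position, $i$ indexes the pair of dimensions $(2i, 2i+1)$, and $d_{model}$ is the embedding size (`d_hid` in the code). Using PyTorch, with NumPy to build the table, we could use: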

    # module-level imports used by this method
    import numpy as np
    import torch

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        ''' Sinusoid position encoding table '''
        # TODO: make it with torch instead of numpy

        def get_position_angle_vec(position):
            # angle pos / 10000^(2i / d_hid) for every dimension hid_j;
            # hid_j // 2 gives the same i for the pair (2i, 2i+1)
            return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

        sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
        # [:, 0::2] selects the even indices, i.e. dimensions 2i
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

        # shape (1, n_position, d_hid); the leading 1 broadcasts over the batch
        return torch.FloatTensor(sinusoid_table).unsqueeze(0)
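    The TODO in this method asks for a torch-only version. A possible sketch, assuming float32 precision is acceptable (the NumPy version computes in float64 and only then casts to float32); `sinusoid_table_torch` is a hypothetical name, not part of the original code:

```python
import torch


def sinusoid_table_torch(n_position: int, d_hid: int) -> torch.Tensor:
    """Hypothetical torch-only counterpart of _get_sinusoid_encoding_table."""
    # positions 0..n_position-1 as a column, shape (n_position, 1)
    position = torch.arange(n_position, dtype=torch.float32).unsqueeze(1)
    # i = j // 2 for every dimension j, so dims 2i and 2i+1 share a frequency
    pair_index = torch.div(torch.arange(d_hid), 2, rounding_mode='floor')
    exponent = (2 * pair_index).to(torch.float32) / d_hid
    # broadcasting (n_position, 1) / (d_hid,) -> (n_position, d_hid)
    angles = position / (10000.0 ** exponent)
    table = torch.empty_like(angles)
    table[:, 0::2] = torch.sin(angles[:, 0::2])  # dim 2i
    table[:, 1::2] = torch.cos(angles[:, 1::2])  # dim 2i+1
    return table.unsqueeze(0)  # (1, n_position, d_hid)


table = sinusoid_table_torch(n_position=50, d_hid=16)
print(table.shape)  # torch.Size([1, 50, 16])
```

    It keeps the same (1, n_position, d_hid) layout as the NumPy version, so it could be passed to `register_buffer` in the same way.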
    
    

    The last part is the forward function:

    def forward(self, x):
        # x: output of the word embedding, shape (batch, seq_len, d_hid)
        return x + self.pos_table[:, :x.size(1)].clone().detach()
    

    The forward pass adds the positional encoding, sliced to the input's sequence length, to the output of the word embedding.
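    Putting the pieces together, a minimal runnable sketch of the whole module might look like this. The class name `PositionalEncoding` and the default `n_position=200` are assumptions for illustration; `forward` assumes its input is the word-embedding output of shape (batch, seq_len, d_hid) with seq_len <= n_position:

```python
import numpy as np
import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    """Sketch assembling the snippets above (class name and defaults are assumptions)."""

    def __init__(self, d_hid: int, n_position: int = 200):
        super().__init__()
        # fixed sinusoid table: a persistent buffer, not a parameter
        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position: int, d_hid: int) -> torch.Tensor:
        def get_position_angle_vec(position):
            return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

        table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
        table[:, 0::2] = np.sin(table[:, 0::2])  # dim 2i
        table[:, 1::2] = np.cos(table[:, 1::2])  # dim 2i+1
        return torch.FloatTensor(table).unsqueeze(0)  # (1, n_position, d_hid)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # add the first seq_len rows of the table to every sequence in the batch
        return x + self.pos_table[:, :x.size(1)].clone().detach()


pe = PositionalEncoding(d_hid=16, n_position=50)
emb = torch.zeros(2, 10, 16)          # stand-in for a word-embedding output
out = pe(emb)
print(out.shape)                      # torch.Size([2, 10, 16])
print(list(pe.state_dict().keys()))   # ['pos_table']
print(len(list(pe.parameters())))     # 0
```

    The module has no parameters, yet `pos_table` still appears in `state_dict` because it was registered with the default `persistent=True`.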

  • Original article: https://www.cnblogs.com/wevolf/p/15188846.html