zoukankan      html  css  js  c++  java
  • pytorch中词向量生成的原理

    pytorch中的词向量的使用

    在pytorch我们使用nn.embedding进行词嵌入的工作。

    具体用法就是:

    import torch
    word_to_ix={'hello':0,'world':1}
    embeds = torch.nn.Embedding(2,5)
    hello_idx=torch.LongTensor([word_to_ix['hello']])
    hello_embed = embeds(hello_idx)
    print(hello_embed)
    print(embeds.weight)
    
    
    tensor([[ 0.6584,  0.2991, -1.2654,  0.9369,  0.6088]], grad_fn=<EmbeddingBackward>)
    
    Parameter containing:
    tensor([[ 0.6584,  0.2991, -1.2654,  0.9369,  0.6088],
            [ 0.1922,  1.5374,  0.5737, -0.8007, -0.4896]], requires_grad=True)
    
    

    在torch.nn.Embedding的源代码中,它是这么解释,
    This module is often used to store word embeddings and retrieve them using indices.
    The input to the module is a list of indices, and the output is the corresponding
    word embeddings.

    对于这个,我的理解是这样的torch.nn.Embedding 是一个矩阵类,当我传入参数之后,我可以得到一个矩阵对象,比如上面代码中的
    embeds = torch.nn.Embedding(2,5) 通过这个代码,我就获得了一个两行三列的矩阵对象embeds。这个时候,矩阵对象embeds的输入就是一个索引列表(当然这个列表
    应该是longtensor格式,得到的结果就是对应索引的词向量)

    我们这里有一点需要格外注意,在上面的结果中,有个这个东西 requires_grad=True

    我在开始接触pytorch的时候,对embedding的一个疑惑就是它是如何定义自动更新的。因为现在我们得到的这个词向量是随机初始化的结果,
    在后续神经网络反向传递过程中,这个参数是需要更新的。

    这里我想要点出一点来,就是词向量在这里是使用标准正态分布进行的初始化。我们可以通过查看源代码来进行验证。
    在源代码中

    if _weight is None:
                self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim)) ##定义一个Parameter对象
                self.reset_parameters() #随后对这个对象进行初始化
    ...
    ...
    
    def reset_parameters(self): #标准正态进行初始化
            init.normal_(self.weight)
            if self.padding_idx is not None:
                with torch.no_grad():
                    self.weight[self.padding_idx].fill_(0) 
    
  • 相关阅读:
    hdu 2147 kiki's game
    HDU 1846 Brave Game
    NYOJ 239 月老的难题
    NYOJ 170 网络的可靠性
    NYOJ 120 校园网络
    xtu字符串 B. Power Strings
    xtu字符串 A. Babelfish
    图论trainning-part-1 D. Going in Cycle!!
    XTU 二分图和网络流 练习题 J. Drainage Ditches
    XTU 二分图和网络流 练习题 B. Uncle Tom's Inherited Land*
  • 原文地址:https://www.cnblogs.com/lzida9223/p/10536177.html
Copyright © 2011-2022 走看看