zoukankan      html  css  js  c++  java
  • pytorch中的nn.Embedding

    直接看代码:

    import torch
    import torch.nn as nn
    embedding=nn.Embedding(10,3)
    input=torch.LongTensor([[1,2,4,5],[4,3,2,9]])
    embedding(input)
    tensor([[[ 0.8052, -0.1044, -0.6971],
             [ 1.3792, -0.1265, -1.1444],
             [ 1.4152, -0.1551, -1.2433],
             [ 0.7060, -1.0585,  0.5130]],
    
            [[ 1.4152, -0.1551, -1.2433],
             [-0.9881, -0.1601,  0.6339],
             [ 1.3792, -0.1265, -1.1444],
             [-1.1703,  1.8496,  0.8113]]], grad_fn=<EmbeddingBackward>)

    第一个参数是字的总数,第二个参数是字的向量表示的维度。

    我们的输入input是两个句子,每个句子都是由四个字组成的,使用每个字的索引来表示,于是使用nn.Embedding对输入进行编码,每个字都会编码成长度为3的向量。

    再看看下个例子:

    embedding = nn.Embedding(10, 3, padding_idx=0)
    input=torch.LongTensor([[0,2,0,5]])
    embedding(input)
    tensor([[[ 0.0000,  0.0000,  0.0000],
             [ 0.0829,  1.4141,  0.0277],
             [ 0.0000,  0.0000,  0.0000],
             [ 0.1337, -1.1472,  0.2182]]], grad_fn=<EmbeddingBackward>)

    transformer中的字的编码就可以这么表示:

    class Embeddings(nn.Module):
      def __init__(self,d_model,vocab):
        #d_model=512, vocab=当前语言的词表大小
        super(Embeddings,self).__init__()
        self.lut=nn.Embedding(vocab,d_model) 
        # one-hot转词嵌入,这里有一个待训练的矩阵E,大小是vocab*d_model
        self.d_model=d_model # 512
      def forward(self,x): 
         # x ~ (batch.size, sequence.length, one-hot), 
         #one-hot大小=vocab,当前语言的词表大小
         return self.lut(x)*math.sqrt(self.d_model) 
         # 得到的10*512词嵌入矩阵,主动乘以sqrt(512)=22.6,
         #这里我做了一些对比,感觉这个乘以sqrt(512)没啥用… 求反驳。
         #这里的输出的tensor大小类似于(batch.size, sequence.length, 512)

    参考:

    https://zhuanlan.zhihu.com/p/107889011

    https://blog.csdn.net/qq_38883844/article/details/104331382

  • 相关阅读:
    JAVA动态代理学习
    .Netcore Swagger
    无废话,用.net core mvc 开发一个虽小但五脏俱全的网站
    专为开发者开发的导航网站
    利用webbrowser自动查取地点坐标
    帮你理解学习lambda式
    activeX 打包
    activeX 开发
    提取验证码到winform上webbroswer和axwebbroswer
    存储过程分页的注入问题以及解决
  • 原文地址:https://www.cnblogs.com/xiximayou/p/13343608.html
Copyright © 2011-2022 走看看