zoukankan      html  css  js  c++  java
  • Pytorch的默认初始化分布 nn.Embedding.weight初始化分布

    一、nn.Embedding.weight初始化分布

     

    nn.Embedding.weight随机初始化方式是标准正态分布 [公式] ,即均值$mu=0$,方差$sigma=1$的正态分布。

     

    论据1——查看源代码

     

    ## class Embedding具体实现(在此只展示部分代码)
    import torch
    from torch.nn.parameter import Parameter
    
    from .module import Module
    from .. import functional as F
    
    class Embedding(Module):
        def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
                     max_norm=None, norm_type=2, scale_grad_by_freq=False,
                     sparse=False, _weight=None):
            if _weight is None:
                self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))
                self.reset_parameters()
            else:
                assert list(_weight.shape) == [num_embeddings, embedding_dim], 
                    'Shape of weight does not match num_embeddings and embedding_dim'
                self.weight = Parameter(_weight)
    
        def reset_parameters(self):
            self.weight.data.normal_(0, 1)
            if self.padding_idx is not None:
                self.weight.data[self.padding_idx].fill_(0)
    

     

     

    Embedding这个类有个属性weight,它是torch.nn.parameter.Parameter类型的,作用就是存储真正的word embeddings。如果不给weight赋值,Embedding类会自动给他初始化,看上述代码第6~8行,如果属性weight没有手动赋值,则会定义一个torch.nn.parameter.Parameter对象,然后对该对象进行reset_parameters(),看第21行,对self.weight先转为Tensor在对其进行normal_(0, 1)(调整为$N(0, 1)$正态分布)。所以nn.Embeddig.weight默认初始化方式就是N(0, 1)分布,即均值$mu=0$,方差$sigma=1$的标准正态分布。

     

    论据2——简单验证nn.Embeddig.weight的分布

     

    下面将做的是验证nn.Embeddig.weight某一行词向量的均值和方差,以便验证是否为标准正态分布。
    注意:验证一行数字的均值为0,方差为1,显然不能说明该分布就是标准正态分布,只能是其必要条件,而不是充分条件,要想真正检测这行数字是不是正态分布,在概率论上有专门的较为复杂的方法,请查看概率论之假设检验。

     

    import torch.nn as nn
    
    # dim越大,均值、方差越接近0和1
    dim = 800000
    # 定义了一个(5, dim)的二维embdding
    # 对于NLP来说,相当于是5个词,每个词的词向量维数是dim
    # 每个词向量初始化为正态分布 N(0,1)(待验证)
    embd = nn.Embedding(5, dim)
    # type(embd.weight) is Parameter
    # type(embd.weight.data) is Tensor
    # embd.weight.data[0]是指(5, dim)的word embeddings中取第1个词的词向量,是dim维行向量
    weight = embd.weight.data[0].numpy()
    print("weight: {}".format(weight))
    
    weight_sum = 0
    for w in weight:
        weight_sum += w
    mean = weight_sum / dim
    print("均值: {}".format(mean))
    
    square_sum = 0
    for w in weight:
        square_sum += (mean - w) ** 2
    print("方差: {}".format(square_sum / dim))
    

     

     

    代码输出:

     

    weight: [-0.65507996  0.11627434 -1.6705967  ...  0.78397447  ...  -0.13477565]
    均值: 0.0006973597864689242
    方差: 1.0019535550544454
    

     

     

    可见,均值接近0,方差接近1,从这里也可以反映出nn.Embeddig.weight是标准正态分布$N(0, 1)$。

     

    二、torch.Tensortorch.tensortorch.randn初始化分布

     

    1、torch.rand

     

    返回$[0,1)$上的均匀分布(uniform distribution)。

     

    2、torch.randn

     

    返回$N(0, 1)$,即标准正态分布(standard normal distribution)。

     

    3、torch.Tensor

     

    torch.Tensor是Tensor class,torch.Tensor(2, 3)是调用Tensor的构造函数,构造了$2 imes3$矩阵,但是没有分配空间,未初始化。
    不推荐使用torch.Tensor创建Tensor,应使用torch.tenstortorch.onestorch.zerostorch.randtorch.randn等,原因:

     

    t = torch.Tensor(2,3)
    # 容易出现下述错误,因为t中的值取决当前内存中的随机值
    # 如果当前内存中随机值特别大会溢出
    RuntimeError: Overflow when unpacking long

    
    
  • 相关阅读:
    vmware导出为ovf
    华三接入交换机推荐
    mysql root情况
    ospf精确宣告地址
    kubernetes k8s yum localinstall
    js判断邮箱、用户名、手机号码和电话号码是否输入正确?
    如何修改Oracle中表的字段长度?
    mybatis与hibernate区别
    SSM框架的优势?
    SSH框架的优势?
  • 原文地址:https://www.cnblogs.com/jfdwd/p/11269601.html
Copyright © 2011-2022 走看看