zoukankan      html  css  js  c++  java
  • tensorflow和pytorch中的参数初始化调用方法

      神经网络中最重要的就是参数了,其中包括权重项$W$和偏置项$b$。 我们训练神经网络的最终目的就是得到最好的参数,使得目标函数取得最小值。参数的初始化也同样重要,因此微调受到很多人的重视,

      只列一些常用的!

    Tensorflow

    常数初始化

    tf.constant_initializer(value)

    value取0,则代表的是全0初始化,也可以表示为 tf.zeros_initializer() 

    value取1,则代表的是全1初始化,也可以表示为 tf.ones_initializer() 

    随机均匀初始化器

    tf.random_uniform_initializer(minval=0, maxval=None) 

    不需要指定最小值和最大值的均匀初始化: 

    tf.uniform_unit_scaling_initializer(factor=1.0) 

    随机正态初始化器

    (均值为0,方差为1)

    tf.random_normal_initializer(mean=0.0, stddev=1.0) 

    截断正态分布初始化器

    (均值为0,方差为1)

    tf.truncated_normal_initializer(mean=0.0, stddev=1.0) 

    正交矩阵初始化器

    tf.orthogonal_initializer() 

      生成正交矩阵的随机数。当需要生成的参数是2维时,这个正交矩阵是由均匀分布的随机数矩阵经过SVD分解而来。

    Xavier uniform 初始化器

    tf.glorot_uniform_initializer() 

      初始化为与输入输出节点数相关的均匀分布随机数,和xavier_initializer()是一个东西

      假设均匀分布的区间是[-limit, limit],则

    $$limit=sqrt{frac{6}{fan_in + fan_out}}$$

    其中的fan_in和fan_out分别表示输入单元的结点数和输出单元的结点数。

    Xavier normal 初始化器

    tf.glorot_normal_initializer() 

      初始化为与输入输出节点数相关的截断正太分布随机数

    $$stddev = sqrt{frac{2}{fan\_in + fan\_out}}$$

    其中的fan_in和fan_out分别表示输入单元的结点数和输出单元的结点数。

    变尺度正态、均匀分布

    tf.variance_scaling_initializer(scale=1.0,mode="fan_in", distribution="truncated_normal")
    • scale: 缩放尺度
    • mode: 有3个值可选,分别是 “fan_in”, “fan_out” 和 “fan_avg”,用于控制计算标准差 stddev的值
    • distribution: 2个值可选,”normal”或“uniform”,定义生成的tensor的分布是截断正太分布还是均匀分布

    distribution选‘normal’的时候,生成的是截断正太分布,标准差 stddev = sqrt(scale / n), n的取值根据mode的不同设置而不同:

    • mode = "fan_in", n为输入单元的结点数;         
    • mode = "fan_out",n为输出单元的结点数;
    • mode = "fan_avg",n为输入和输出单元结点数的平均值;

    distribution选 ‘uniform’,生成均匀分布的随机数tensor,最大值 max_value和 最小值 min_value 的计算公式:

    • max_value = sqrt(3 * scale / n)
    • min_value = -max_value

    he初始化

      如果使用relu激活函数,最好使用He初始化,因为在ReLU网络中,假定每一层有一半的神经元被激活,另一半为0,所有要保持variance不变。

    tf.contrib.layers.variance_scaling_initializer()

    Xavier初始化

    如果激活函数用sigmoid和tanh,最好用xavier初始化器,

    Xavier初始化的基本思想是保持输入和输出的方差一致,这样就避免了所有输出值都趋向于0.

    from tensorflow.contrib.layers import xavier_initializer

    pytorch

    PyTorch 中参数的默认初始化在各个层的 reset_parameters() 方法中。例如:nn.Linear 和 nn.Conv2D,都是在 [-limit, limit] 之间的均匀分布(Uniform distribution),其中 limit 是$frac{1}{sqrt{fan\_in}}$ ,fan_in是指参数张量(tensor)的输入单元的数量

    下面是几种常见的初始化方式

    常数初始化

    nn.init.constant_(w, 0.3)

    均匀分布

    nn.init.uniform_(w)

    正态分布

    nn.init.normal_(w, mean=0, std=1)

    xavier_uniform 初始化

      Xavier初始化的基本思想是保持输入和输出的方差一致,这样就避免了所有输出值都趋向于0。这是通用的方法,适用于任何激活函数。

    # 默认方法
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.xavier_uniform_(m.weight)

     也可以使用 gain 参数来自定义初始化的标准差来匹配特定的激活函数:

    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.xavier_uniform_(m.weight(), gain=nn.init.calculate_gain('relu'))

    xavier_normal 初始化

    nn.init.xavier_normal_(w)

    kaiming_uniform 初始化

    nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')

    kaiming_normal 初始化

      He initialization的思想是:在ReLU网络中,假定每一层有一半的神经元被激活,另一半为0。推荐在ReLU网络中使用。

    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight, mode='fan_in')

    正交初始化(Orthogonal Initialization)

      主要用以解决深度网络下的梯度消失、梯度爆炸问题,在RNN中经常使用的参数初始化方法。

    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.orthogonal(m.weight)

    Batchnorm Initialization

      在非线性激活函数之前,我们想让输出值有比较好的分布(例如高斯分布),以便于计算梯度和更新参数。Batch Normalization 将输出值强行做一次 Gaussian Normalization 和线性变换:

    for m in model:
        if isinstance(m, nn.BatchNorm2d):
            nn.init.constant(m.weight, 1)
            nn.init.constant(m.bias, 0)

    参考

    tensorflow 学习笔记(九)- 参数初始化(initializer)

    pytorch中的参数初始化方法总结

  • 相关阅读:
    吴裕雄 Bootstrap 前端框架开发——Bootstrap 字体图标(Glyphicons)
    Logical partitioning and virtualization in a heterogeneous architecture
    十条实用的jQuery代码片段
    十条实用的jQuery代码片段
    十条实用的jQuery代码片段
    C#比较dynamic和Dictionary性能
    C#比较dynamic和Dictionary性能
    C#比较dynamic和Dictionary性能
    分别使用 XHR、jQuery 和 Fetch 实现 AJAX
    分别使用 XHR、jQuery 和 Fetch 实现 AJAX
  • 原文地址:https://www.cnblogs.com/LXP-Never/p/13189621.html
Copyright © 2011-2022 走看看