zoukankan      html  css  js  c++  java
  • 0607-参数初始化策略

    0607-参数初始化策略

    pytorch完整教程目录:https://www.cnblogs.com/nickchen121/p/14662511.html

    一、参数初始化策略概述

    深度学习中,一个好的参数初始化策略可以让模型更快地收敛,而一个差的参数初始化策略可能会让模型很难进行收敛,反复震荡甚至崩溃。

    nn.Module 中的参数一般都采取了比较合适的初始化策略,因此一般我们不需要考虑。不过我们也可以自定义一个参数初始化策略代替系统默认的,比如当我们使用 Parameter 时,由于 t.Tensor() 返回的是内存中的随机数,很可能会有极大值,这会时训练网络时造成溢出或者梯度小时,因此此时自定义一个参数的初始化策略尤为重要。

    torch 中的 nn.init 模块专门为初始化设计,实现了一些常用的初始化侧路了,而且就算如果某种初始化策略 nn.init 不提供,用户也可以自己直接初始化。

    二、利用 nn.init 初始化

    Glorot 正态分布初始化方法,也称作 Xavier 正态分布初始化,参数由 0 均值,标准差为 (sqrt{frac{2}{(fan_{in} + fan_{out}})}) 的正态分布产生,其中(fan_{in})(fan_{out}) 是分别权值张量的输入和输出元素数目。这种初始化同样是为了保证输入输出的方差不变,但是原论文中 [1] 是基于线性函数推导的,同时在 tanh 激活函数上有很好的效果,但不适用于ReLU激活函数。

    [std=gain×sqrt{frac{2}{fan_{in}+fan_{out}}} ]

    看不懂就别看了,我都没仔细看,百度 copy 来的。

    参考:[1] Understanding the difficulty of training deep feedforward neural networks — Glorot, X. & Bengio, Y. (2010)

    import torch as t
    from torch import nn
    from torch.nn import init
    linear = nn.Linear(3, 4)
    
    t.manual_seed(1)
    # 等价于 linear.weight.data.normal_(0, std)
    init.xavier_normal_(linear.weight)  #
    
    Parameter containing:
    tensor([[ 0.3535,  0.1427,  0.0330],
            [ 0.3321, -0.2416, -0.0888],
            [-0.8140,  0.2040, -0.5493],
            [-0.3010, -0.4769, -0.0311]], requires_grad=True)
    

    三、直接初始化

    import math
    t.manual_seed(1)
    
    # xavier初始化的计算公式
    std = math.sqrt(2) / math.sqrt(7.)
    linear.weight.data.normal_(0, std)
    
    tensor([[ 0.3535,  0.1427,  0.0330],
            [ 0.3321, -0.2416, -0.0888],
            [-0.8140,  0.2040, -0.5493],
            [-0.3010, -0.4769, -0.0311]])
    
    # 对模型的所有参数进行初始化
    for name, params in net.named_parameters():
        if name.find('linear') != -1:  # 对所有全连接层的参数进行初始化
            # init linear
            params[0]  # weight
            params[1]  # bias
        elif name.find('conv') != -1:
            pass
        elif name.find('norm') != -1:
            pass
    
    ---------------------------------------------------------------------------
    
    NameError                                 Traceback (most recent call last)
    
    <ipython-input-3-78d2673ab1d6> in <module>
          1 # 对模型的所有参数进行初始化
    ----> 2 for name, params in net.named_parameters():
          3     if name.find('linear') != -1:  # 对所有全连接层的参数进行初始化
          4         # init linear
          5         params[0]  # weight
    
    
    NameError: name 'net' is not defined
  • 相关阅读:
    [kuangbin带你飞]专题十六 KMP & 扩展KMP & ManacherK
    [kuangbin带你飞]专题十六 KMP & 扩展KMP & Manacher J
    [kuangbin带你飞]专题十六 KMP & 扩展KMP & Manacher I
    pat 1065 A+B and C (64bit)(20 分)(大数, Java)
    pat 1069 The Black Hole of Numbers(20 分)
    pat 1077 Kuchiguse(20 分) (字典树)
    pat 1084 Broken Keyboard(20 分)
    pat 1092 To Buy or Not to Buy(20 分)
    pat 1046 Shortest Distance(20 分) (线段树)
    pat 1042 Shuffling Machine(20 分)
  • 原文地址:https://www.cnblogs.com/nickchen121/p/14701862.html
Copyright © 2011-2022 走看看