zoukankan      html  css  js  c++  java
  • 0607-参数初始化策略

    0607-参数初始化策略

    pytorch完整教程目录:https://www.cnblogs.com/nickchen121/p/14662511.html

    一、参数初始化策略概述

    深度学习中,一个好的参数初始化策略可以让模型更快地收敛,而一个差的参数初始化策略可能会让模型很难进行收敛,反复震荡甚至崩溃。

    nn.Module 中的参数一般都采取了比较合适的初始化策略,因此一般我们不需要考虑。不过我们也可以自定义一个参数初始化策略代替系统默认的,比如当我们使用 Parameter 时,由于 t.Tensor() 返回的是内存中的随机数,很可能会有极大值,这会时训练网络时造成溢出或者梯度小时,因此此时自定义一个参数的初始化策略尤为重要。

    torch 中的 nn.init 模块专门为初始化设计,实现了一些常用的初始化侧路了,而且就算如果某种初始化策略 nn.init 不提供,用户也可以自己直接初始化。

    二、利用 nn.init 初始化

    Glorot 正态分布初始化方法,也称作 Xavier 正态分布初始化,参数由 0 均值,标准差为 (sqrt{frac{2}{(fan_{in} + fan_{out}})}) 的正态分布产生,其中(fan_{in})(fan_{out}) 是分别权值张量的输入和输出元素数目。这种初始化同样是为了保证输入输出的方差不变,但是原论文中 [1] 是基于线性函数推导的,同时在 tanh 激活函数上有很好的效果,但不适用于ReLU激活函数。

    [std=gain×sqrt{frac{2}{fan_{in}+fan_{out}}} ]

    看不懂就别看了,我都没仔细看,百度 copy 来的。

    参考:[1] Understanding the difficulty of training deep feedforward neural networks — Glorot, X. & Bengio, Y. (2010)

    import torch as t
    from torch import nn
    from torch.nn import init
    linear = nn.Linear(3, 4)
    
    t.manual_seed(1)
    # 等价于 linear.weight.data.normal_(0, std)
    init.xavier_normal_(linear.weight)  #
    
    Parameter containing:
    tensor([[ 0.3535,  0.1427,  0.0330],
            [ 0.3321, -0.2416, -0.0888],
            [-0.8140,  0.2040, -0.5493],
            [-0.3010, -0.4769, -0.0311]], requires_grad=True)
    

    三、直接初始化

    import math
    t.manual_seed(1)
    
    # xavier初始化的计算公式
    std = math.sqrt(2) / math.sqrt(7.)
    linear.weight.data.normal_(0, std)
    
    tensor([[ 0.3535,  0.1427,  0.0330],
            [ 0.3321, -0.2416, -0.0888],
            [-0.8140,  0.2040, -0.5493],
            [-0.3010, -0.4769, -0.0311]])
    
    # 对模型的所有参数进行初始化
    for name, params in net.named_parameters():
        if name.find('linear') != -1:  # 对所有全连接层的参数进行初始化
            # init linear
            params[0]  # weight
            params[1]  # bias
        elif name.find('conv') != -1:
            pass
        elif name.find('norm') != -1:
            pass
    
    ---------------------------------------------------------------------------
    
    NameError                                 Traceback (most recent call last)
    
    <ipython-input-3-78d2673ab1d6> in <module>
          1 # 对模型的所有参数进行初始化
    ----> 2 for name, params in net.named_parameters():
          3     if name.find('linear') != -1:  # 对所有全连接层的参数进行初始化
          4         # init linear
          5         params[0]  # weight
    
    
    NameError: name 'net' is not defined
  • 相关阅读:
    调试JavaScript/VB Script脚本程序(ASP篇)
    成功接收来自Internet的邮件必须要做到的条件
    (转)Ext与.NET超完美整合 .NET开发者的超级优势
    如何防垃圾邮件用你的邮件服务器转发
    记录书籍名称
    GRE网站
    JAVA线程的缺陷
    【让这些电影给你“治病”】
    zoj题目分类
    Oracle to_char格式化函数
  • 原文地址:https://www.cnblogs.com/nickchen121/p/14701862.html
Copyright © 2011-2022 走看看