zoukankan      html  css  js  c++  java
  • pytorch---初始化

    pytorch---初始化

    在深度学习中参数的初始化十分重要,良好的初始化能让模型更快收敛,并达到更高水平,而糟糕的初始化则可能使得模型迅速瘫痪。PyTorch中nn.Module的模块参数都采取了较为合理的初始化策略,因此一般不用我们考虑,当然我们也可以用自定义初始化去代替系统的默认初始化。而当我们在使用Parameter时,自定义初始化则尤为重要,因t.Tensor()返回的是内存中的随机数,很可能会有极大值,这在实际训练网络中会造成溢出或者梯度消失。PyTorch中nn.init模块就是专门为初始化而设计,如果某种初始化策略nn.init不提供,用户也可以自己直接初始化。

    # 利用nn.init初始化
    from torch.nn import init
    linear = nn.Linear(3, 4)

    t.manual_seed(1)
    # 等价于 linear.weight.data.normal_(0, std)
    init.xavier_normal_(linear.weight)

    # 直接初始化
    import math
    t.manual_seed(1)

    # xavier初始化的计算公式
    std = math.sqrt(2)/math.sqrt(7.)
    linear.weight.data.normal_(0,std)
    # 对模型的所有参数进行初始化
    for name, params in net.named_parameters():
      if name.find('linear') != -1:
          # init linear
          params[0] # weight
          params[1] # bias
      elif name.find('conv') != -1:
          pass
      elif name.find('norm') != -1:
          pass

    补充

    xavier初始化

    torch.nn.init.xavier_uniform(tensor, gain=1)

    对于输入的tensor或者变量,通过论文Understanding the difficulty of training deep feedforward neural networks” - Glorot, X. & Bengio, Y. (2010)的方法初始化数据。

    初始化服从均匀分布U(−a,a)U(−a,a),其中a=gain×2/(fan_in+fan_out)−−−−−−−−−−−−−−−−−−√×3–√a=gain×2/(fan_in+fan_out)×3,该初始化方法也称Glorot initialisation。

    参数:

          tensor:n维的 torch.Tensor 或者 autograd.Variable类型的数据

          a:可选择的缩放参数

    例如:

    w = torch.Tensor(3, 5)
    nn.init.xavier_uniform(w, gain=nn.init.calculate_gain('relu'))

    torch.nn.init.xavier_normal(tensor, gain=1)

    对于输入的tensor或者变量,通过论文Understanding the difficulty of training deep feedforward neural networks” - Glorot, X. & Bengio, Y. (2010)的方法初始化数据。初始化服从高斯分布N(0,std)N(0,std),其中std=gain×2/(fan_in+fan_out)−−−−−−−−−−−−−−−−−−√std=gain×2/(fan_in+fan_out),该初始化方法也称Glorot initialisation。

    参数:

          tensor:n维的 torch.Tensor 或者 autograd.Variable类型的数据

          a:可选择的缩放参数

    例如:


    w = torch.Tensor(3, 5)
    nn.init.xavier_normal(w)

    另外在torch.Tensor下还定义了一些in-place的函数:

  • 相关阅读:
    python-进程池实例
    python-进程通过队列模拟数据的下载
    python-多进程模板
    python-多线程同步中创建互斥锁解决资源竞争的问题
    CentOS6.5配置网络
    解决CentOS系统Yum出现"Cannot find a valid baseurl for repo"问题
    CentOS 6.5安装图形界面
    Centos安装git
    Web前端优化,提高加载速度
    谁说写代码的不懂生活
  • 原文地址:https://www.cnblogs.com/zmmz/p/9847876.html
Copyright © 2011-2022 走看看