zoukankan      html  css  js  c++  java
  • Pytorch半精度浮点型网络训练问题

    用Pytorch1.0进行半精度浮点型网络训练需要注意下问题:

    1、网络要在GPU上跑,模型和输入样本数据都要cuda().half()

    2、模型参数转换为half型,不必索引到每层,直接model.cuda().half()即可

    3、对于半精度模型,优化算法,Adam我在使用过程中,在某些参数的梯度为0的时候,更新权重后,梯度为零的权重变成了NAN,这非常奇怪,但是Adam算法对于全精度数据类型却没有这个问题。

      另外,SGD算法对于半精度和全精度计算均没有问题。

    还有一个问题是不知道是不是网络结构比较小的原因,使用半精度的训练速度还没有全精度快。这个值得后续进一步探索。

    对于上面的这个问题,的确是网络很小的情况下,在1080Ti上半精度浮点型没有很明显的优势,但是当网络变大之后,半精度浮点型要比全精度浮点型要快。但具体快多少和模型的大小以及输入样本大小有关系,我测试的是要快1/6,同时,半精度浮点型在占用内存上比较有优势,对于精度的影响尚未探究。

     将网络再变大些,epoch的次数也增大,半精度和全精度的时间差就表现出来了,在训练的时候。

  • 相关阅读:
    [LeetCode] Rotate Image
    [LeetCode] Generate Parentheses
    pandas 使用总结
    ConfigParser 读写配置文件
    Cheat Sheet pyspark RDD(PySpark 速查表)
    python随机生成字符
    grep 命令
    hadoop 日常使用记录
    python 2 计算字符串 余弦相似度
    screen命令
  • 原文地址:https://www.cnblogs.com/yanxingang/p/10148712.html
Copyright © 2011-2022 走看看