zoukankan      html  css  js  c++  java
  • pytorch显存越来越多的一个自己没注意的原因

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    train_loss += loss

    参考:https://blog.csdn.net/qq_27292549/article/details/80250031

    我和博主犯了一毛一样的低级错误。。。。

    下面是原博解释:

    运行着就发现显存炸了

    观察了一下发现随着每个batch显存消耗在不断增大..

    参考了别人的代码发现那句loss一般是这样写 

    loss_sum += loss.data[0]
    

    这是因为输出的loss的数据类型是Variable。

    而PyTorch的动态图机制就是通过Variable来构建图。主要是使用Variable计算的时候,会记录下新产生的Variable的运算符号,在反向传播求导的时候进行使用。

    如果这里直接将loss加起来,系统会认为这里也是计算图的一部分,也就是说网络会一直延伸变大~那么消耗的显存也就越来越大~~

    总之使用Variable的数据时候要非常小心。不是必要的话尽量使用Tensor来进行计算...

    补充:

    用Tensor计算也是有坑的,要写成:

     train_loss += loss.item()
    

    不然显存还是会炸。。。。。

  • 相关阅读:
    js计算两个时间相差天数
    享元模式
    外观模式
    组合模式
    装饰者模式
    桥接模式
    适配器模式
    元素量词 ? + *
    linux安装使用7zip
    linux shell使用别名,切换当前目录
  • 原文地址:https://www.cnblogs.com/Charlene-HRI/p/10234656.html
Copyright © 2011-2022 走看看