zoukankan      html  css  js  c++  java
  • Pytorch GAN训练时多次backward时出错问题

    转载自https://www.daimajiaoliu.com/daima/479755892900406

    和 https://oldpan.me/archives/pytorch-retain_graph-work

    从一个错误说起:RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed

    在深度学习中,有些场景需要进行两次反向,比如Gan网络,需要对D进行一次,还要对G进行一次,很多人都会遇到上面这个错误,这个错误的意思就是尝试对一个计算图进行第二次反向,但是计算图已经释放了。其实看简单点和我们之前的backward一样,当图进行了一次梯度更新,就会把一些梯度的缓存给清空,为了避免下次叠加,但在Gan这种情形下,我们必须要二次更新,那怎么办呢。有两种方案:

    方案一:

    这是网上大多数给出的解决方案,在第一次反向时候加入一个 l2.backward(retain_graph=True) ,这样就能避免释放掉了。

    这个参数的作用是什么,官方定义为:

    retain_graph (bool, optional) – If False, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. Defaults to the value of create_graph.

    大意是如果设置为False,计算图中的中间变量在计算完后就会被释放。但是在平时的使用中这个参数默认都为False从而提高效率,和creat_graph的值一样。

     也就相当于,假如你有两个Loss:

    # 假如你有两个Loss,先执行第一个的backward,再执行第二个backward
    loss1.backward(retain_graph=True)
    loss2.backward() # 执行完这个后,所有中间变量都会被释放,以便下一次的循环
    optimizer.step() # 更新参数

    方案二:

    上面的方案虽然解决了问题,但是并不优美,因为我们用Gan的时候,D和G两者的更新并无联系,二者的联系仅仅是D里面用到了G的输出,而这个输出一般我们都是直接拿来用的,而问题就出现在这里。下面给一个模拟:

    data = torch.randn(4,10)
    
    model1 = torch.nn.Linear(10,2)
    model2 = torch.nn.Linear(2,2)
    
    optimizer1 = torch.optim.Adam(model1.parameters(), lr=0.001,betas=(0.5, 0.999))
    optimizer2 = torch.optim.Adam(model2.parameters(), lr=0.001,betas=(0.5, 0.999))
    
    loss = torch.nn.CrossEntropyLoss()
    data = torch.randn(4,10)
    label = torch.Tensor([0,1,1,0]).long()
    for i in range(20):
        a = model1(data)
        b = model2(a)
        l1 = loss(a,label)
        l2 = loss(b,label)
        optimizer2.zero_grad()
        l2.backward()
        optimizer2.step()
    
        optimizer1.zero_grad()
        l1.backward()
        optimizer1.step()

    解决方案可以是l2.backward(retain_graph=True)。除此之外我们还可以是 b = model2(a.detach()) ,这个就优美一点,a.detach()和a的区别你可以打印出来看一下,其实a.detach()是没有梯度的,所以相当于一个单纯的数字,和model1就脱离了联系,这样model2和model1就是完全分离开来的两个图,但是如果用的是a则model2和model1则仍然公用一个图,所以导致了错误。可以看下面示意图(这个是我猜测,帮助理解):

    左边相当于直接用a而右边则用a.detach(),类似的在Gan网络里面D的输入可以改为G的输出y_fake.detach()。

    但有一点需要注意的是,两个网络一定没有需要共同更新的 ,假如上面的optimizer2 = torch.optim.Adam(itertools.chain(model1.parameters(),model2.parameters()), lr=0.001,betas=(0.5, 0.999)),则还是用retain_graph=True保险,因为.detach则model2反向不会传播到model1,导致不对model1里面参数更新。

    方案一可见:https://github.com/growvv/GAN-Pytorch/blob/93b49bd7ce395c2035df1d036daad83a67a9c691/Simple-GAN/simple_gan.py

    方案二可见:https://github.com/growvv/GAN-Pytorch/blob/257b267ea60af80212adc3dc5ad4cf28aeed00f6/CycleGAN/train.py


     

  • 相关阅读:
    利用matplotlib进行数据可视化
    《操作系统》课程笔记(Ch11-文件系统实现)
    《操作系统》课程笔记(Ch10-文件系统)
    《数据库原理》课程笔记 (Ch06-查询处理和优化)
    《计算机网络》课程笔记 (Ch05-网络层:控制平面)
    《计算机网络》课程笔记 (Ch04-网络层:数据平面)
    《计算机网络》课程笔记 (Ch03-运输层)
    东南大学《软件测试》课程复习笔记
    《数据库原理》课程笔记 (Ch05-数据库存储结构)
    《操作系统》课程笔记(Ch09-虚拟内存)
  • 原文地址:https://www.cnblogs.com/lfri/p/15556172.html
Copyright © 2011-2022 走看看