zoukankan      html  css  js  c++  java
  • 关于pytorch下GAN loss的backward和step等注意事项

    首先不妨假设最简单的一种情况:

    假设$G$和$D$的损失函数:

    那么计算梯度有:

     

    第一种正确的方式:

    import torch
    from torch import nn
    
    
    def set_requires_grad(net: nn.Module, mode=True):
        for p in net.parameters():
            p.requires_grad_(mode)
    
    
    print(f"Pytorch version: {torch.__version__} 
    ")
    
    X = torch.ones(size=[1, 1, 1, 1]).requires_grad_(False)
    
    G = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0, bias=False)
    G.weight.data.fill_(0.5)
    G_optim = torch.optim.SGD(G.parameters(), lr=1.0)
    
    D = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0, bias=False)
    D.weight.data.fill_(0.7)
    D_optim = torch.optim.SGD(D.parameters(), lr=1.0)print(f"Init grad: {G.weight.grad} {D.weight.grad}")
    print(f"Init weight: {G.weight.detach()} {D.weight.detach()} 
    ")
    
    # Zero gradient of 2 layers.
    G_optim.zero_grad()
    D_optim.zero_grad()
    
    # Forward pass.
    Y = G(X)
    
    # Calculate D loss.
    D_loss = D(Y.detach()) ** 2
    
    # Calculate G loss.
    G_loss = D(Y) ** 2
    
    # Backward D loss.
    D_loss.backward(retain_graph=True)
    
    print(f"Checkpoint 1 grad: {G.weight.grad} {D.weight.grad}")
    print(f"Checkpoint 1 weight: {G.weight.detach()} {D.weight.detach()} 
    ")
    
    # Backward G loss.
    set_requires_grad(D, False)  # Turn off D's grad to avoid redundant gradient accumulation on D.
    G_loss.backward()
    set_requires_grad(D, True)
    
    print(f"Checkpoint 2 grad: {G.weight.grad} {D.weight.grad}")
    print(f"Checkpoint 2 weight: {G.weight.detach()} {D.weight.detach()} 
    ")
    
    # Update G.
    G_optim.step()
    
    print(f"Checkpoint 3 grad: {G.weight.grad} {D.weight.grad}")
    print(f"Checkpoint 3 weight: {G.weight.detach()} {D.weight.detach()} 
    ")
    
    # Update D.
    D_optim.step()
    
    print(f"Checkpoint 4 grad: {G.weight.grad} {D.weight.grad}")
    print(f"Checkpoint 4 weight: {G.weight.detach()} {D.weight.detach()} 
    ")

    运行结果:

    Pytorch version: 1.9.0 
    
    Init grad: None None
    Init weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) 
    
    Checkpoint 1 grad: None tensor([[[[0.3500]]]])
    Checkpoint 1 weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) 
    
    Checkpoint 2 grad: tensor([[[[0.4900]]]]) tensor([[[[0.3500]]]])
    Checkpoint 2 weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) 
    
    Checkpoint 3 grad: tensor([[[[0.4900]]]]) tensor([[[[0.3500]]]])
    Checkpoint 3 weight: tensor([[[[0.0100]]]]) tensor([[[[0.7000]]]]) 
    
    Checkpoint 4 grad: tensor([[[[0.4900]]]]) tensor([[[[0.3500]]]])
    Checkpoint 4 weight: tensor([[[[0.0100]]]]) tensor([[[[0.3500]]]]) 

    分析:

    此时,$x = 1.0, y = 0.5, z = 0.7, heta_G = 0.5, heta_D = 0.7$,

    首先checkpoint 1处,D loss的梯度反传到D网络上得到了 $2 y^2 cdot heta_D = 2 imes 0.25 imes 0.7 = 0.35$,没有反传到G网络。

    其次checkpoint 2处,G loss的梯度反传,D网络梯度不受影响(因为所有网络参数的requires_grad := False),在G网络上得到了 $2 imes 0.5 imes 0.7^2 imes 1.0 = 0.49$。注意,这里的D网络参数 $ heta_D = 0.7$,因为尽管此时D loss已经反传,但是没有D optimizer的step()就还没有更新D网络。

    最后checkpoint 3和4处,就是两个optimizer的step()分别更新G网络和D网络,这两个step()之间的先后顺序对最终的网络更新结果没什么影响。

    注意,这种做法更新G网络时,对应的是更新前的D网络。

    在Pytorch 1.2上的结果:

    Pytorch version: 1.2.0 
    
    Init grad: None None
    Init weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) 
    
    Checkpoint 1 grad: None tensor([[[[0.3500]]]])
    Checkpoint 1 weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) 
    
    Checkpoint 2 grad: tensor([[[[0.4900]]]]) tensor([[[[0.3500]]]])
    Checkpoint 2 weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) 
    
    Checkpoint 3 grad: tensor([[[[0.4900]]]]) tensor([[[[0.3500]]]])
    Checkpoint 3 weight: tensor([[[[0.0100]]]]) tensor([[[[0.7000]]]]) 
    
    Checkpoint 4 grad: tensor([[[[0.4900]]]]) tensor([[[[0.3500]]]])
    Checkpoint 4 weight: tensor([[[[0.0100]]]]) tensor([[[[0.3500]]]]) 

    可以看到,也是一样的。

    错误的做法:

    在 G_loss.backward() 前后不进行对D网络的网络参数的requires_grad的关和开,使得G loss反传了多余的梯度到D网络上。


    第二种正确的方式:

    import torch
    from torch import nn


    print(f"Pytorch version: {torch.__version__} ") X = torch.ones(size=[1, 1, 1, 1]).requires_grad_(False) G = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0, bias=False) G.weight.data.fill_(0.5) G_optim = torch.optim.SGD(G.parameters(), lr=1.0) D = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0, bias=False) D.weight.data.fill_(0.7) D_optim = torch.optim.SGD(D.parameters(), lr=1.0) print(f"Init grad: {G.weight.grad} {D.weight.grad}") print(f"Init weight: {G.weight.detach()} {D.weight.detach()} ") # Forward pass. Y = G(X) # Zero gradient of D. D_optim.zero_grad() # Calculate D loss. D_loss = D(Y.detach()) ** 2 # Backward D loss. D_loss.backward() # Update D. D_optim.step() print(f"Checkpoint 1 grad: {G.weight.grad} {D.weight.grad}") print(f"Checkpoint 1 weight: {G.weight.detach()} {D.weight.detach()} ") # Zero gradient of G. G_optim.zero_grad() # Calculate G loss. G_loss = D(Y) ** 2 # Backward G loss. G_loss.backward() # Update G. G_optim.step() print(f"Checkpoint 2 grad: {G.weight.grad} {D.weight.grad}") print(f"Checkpoint 2 weight: {G.weight.detach()} {D.weight.detach()} ")

    分析:

    这种方式就很明了,更新D网络和更新G网络完全分开。

    此时,$x = 1.0, y = 0.5, z = 0.7, heta_G = 0.5, heta_D = 0.7$,

    首先checkpoint 1处,D loss的梯度反传到D网络上得到了 $2 y^2 cdot heta_D = 2 imes 0.25 imes 0.7 = 0.35$,没有反传到G网络。

    其次checkpoint 2处,G loss的梯度同时反传到了G网络和D网络上,但是由于只更新G网络,D网络上的梯度会在下一个iteration中被zero_grad()清零。G网络上的梯度是 $2 imes 0.5 imes 0.35^2 imes 1.0 = 0.1225$,注意此时的D网络参数已经从 $0.7$ 更新为 $0.7 - 1.0 imes 0.35 = 0.35$(梯度下降:原参数减去学习率乘梯度,得新参数)。

    注意,这种做法更新G网络时,对应的是已经更新后的D网络。事实上,我认为这种做法更正确,同时在逻辑上也更加清晰、更好理解。

    运行结果:

    Pytorch version: 1.9.0 
    
    Init grad: None None
    Init weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) 
    
    Checkpoint 1 grad: None tensor([[[[0.3500]]]])
    Checkpoint 1 weight: tensor([[[[0.5000]]]]) tensor([[[[0.3500]]]]) 
    
    Checkpoint 2 grad: tensor([[[[0.1225]]]]) tensor([[[[0.5250]]]])
    Checkpoint 2 weight: tensor([[[[0.3775]]]]) tensor([[[[0.3500]]]]) 
    Pytorch version: 1.2.0 
    
    Init grad: None None
    Init weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) 
    
    Checkpoint 1 grad: None tensor([[[[0.3500]]]])
    Checkpoint 1 weight: tensor([[[[0.5000]]]]) tensor([[[[0.3500]]]]) 
    
    Checkpoint 2 grad: tensor([[[[0.1225]]]]) tensor([[[[0.5250]]]])
    Checkpoint 2 weight: tensor([[[[0.3775]]]]) tensor([[[[0.3500]]]]) 

    一种错误的方式:

    import torch
    from torch import nn
    
    
    print(f"Pytorch version: {torch.__version__} 
    ")
    
    X = torch.ones(size=[1, 1, 1, 1]).requires_grad_(False)
    
    G = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0, bias=False)
    G.weight.data.fill_(0.5)
    G_optim = torch.optim.SGD(G.parameters(), lr=1.0)
    
    D = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0, bias=False)
    D.weight.data.fill_(0.7)
    D_optim = torch.optim.SGD(D.parameters(), lr=1.0)
    
    print(f"Init grad: {G.weight.grad} {D.weight.grad}")
    print(f"Init weight: {G.weight.detach()} {D.weight.detach()} 
    ")
    
    # Forward pass.
    Y = G(X)
    
    # Zero gradient of G & D.
    G_optim.zero_grad()
    D_optim.zero_grad()
    
    # Calculate D loss.
    D_loss = D(Y.detach()) ** 2
    
    # Calculate G loss.
    G_loss = D(Y) ** 2
    
    # Backward D loss.
    D_loss.backward()
    
    # Update D.
    D_optim.step()
    
    print(f"Checkpoint 1 grad: {G.weight.grad} {D.weight.grad}")
    print(f"Checkpoint 1 weight: {G.weight.detach()} {D.weight.detach()} 
    ")
    
    # Backward G loss.
    G_loss.backward()
    
    # Update G.
    G_optim.step()
    
    print(f"Checkpoint 2 grad: {G.weight.grad} {D.weight.grad}")
    print(f"Checkpoint 2 weight: {G.weight.detach()} {D.weight.detach()} 
    ")

    运行结果:

    Pytorch version: 1.2.0 
    
    Init grad: None None
    Init weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) 
    
    Checkpoint 1 grad: None tensor([[[[0.3500]]]])
    Checkpoint 1 weight: tensor([[[[0.5000]]]]) tensor([[[[0.3500]]]]) 
    
    Checkpoint 2 grad: tensor([[[[0.2450]]]]) tensor([[[[0.7000]]]])
    Checkpoint 2 weight: tensor([[[[0.2550]]]]) tensor([[[[0.3500]]]]) 
    Pytorch version: 1.9.0 
    
    Init grad: None None
    Init weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) 
    
    Checkpoint 1 grad: None tensor([[[[0.3500]]]])
    Checkpoint 1 weight: tensor([[[[0.5000]]]]) tensor([[[[0.3500]]]]) 
    
    Traceback (most recent call last):
      File "D:/Program/PycharmProjects/Test/test.py", line 65, in <module>
        G_loss.backward()
      File "D:ProgramAnacondaenvspy38_torch19libsite-packages	orch\_tensor.py", line 255, in backward
        torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
      File "D:ProgramAnacondaenvspy38_torch19libsite-packages	orchautograd\__init__.py", line 147, in backward
        Variable._execution_engine.run_backward(
    RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1, 1, 1, 1]] is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

    分析:

    这种做法的错误核心就在于,你在计算G loss时(前向传播时)使用的是更新前的D网络,但是你在G loss反向传播时D网络已经变成了更新后的,

    这个错误在较低版本(1.2.0)的Pytorch上并没有报错,我们可以看到它在计算G网络的梯度时,似乎用了更新前后的 $ heta_D$ 相乘 $2 imes 0.5 imes (0.7 imes 0.35) imes 1.0 = 0.245$,而非单纯更新前的 ${ heta_D}^2 = 0.7^2$,或者单纯更新后的 ${ heta_D}^2 = 0.35^2$.

    至于在较高版本(1.9.0)的Pytorch上则直接报错了,估计是因为step()更新了D网络之后,G loss对应的计算图被破坏了,因此直接报了一个 "inplace operation" 错误。

    因此,使用低版本的Pytorch时千万一定要注意这种比较隐蔽的错误写法!!!!

    转载请注明出处:https://dilthey.cnblogs.com/
  • 相关阅读:
    Laravel5.1学习笔记15 数据库1 数据库使用入门
    Laravel5.1学习笔记i14 系统架构6 Facade
    Laravel5.1学习笔记13 系统架构5 Contract
    Laravel5.1学习笔记12 系统架构4 服务容器
    Laravel5.1学习笔记11 系统架构3 服务提供者
    JavaScript之“创意时钟”项目
    JQuery轮播图
    SQL Server之增删改操作
    jQuery之基本选择器Practice
    JQuery---选择器、DOM节点操作练习
  • 原文地址:https://www.cnblogs.com/dilthey/p/15116941.html
Copyright © 2011-2022 走看看