首先不妨假设最简单的一种情况:
假设$G$和$D$的损失函数:
那么计算梯度有:
第一种正确的方式:
import torch from torch import nn def set_requires_grad(net: nn.Module, mode=True): for p in net.parameters(): p.requires_grad_(mode) print(f"Pytorch version: {torch.__version__} ") X = torch.ones(size=[1, 1, 1, 1]).requires_grad_(False) G = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0, bias=False) G.weight.data.fill_(0.5) G_optim = torch.optim.SGD(G.parameters(), lr=1.0) D = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0, bias=False) D.weight.data.fill_(0.7) D_optim = torch.optim.SGD(D.parameters(), lr=1.0)print(f"Init grad: {G.weight.grad} {D.weight.grad}") print(f"Init weight: {G.weight.detach()} {D.weight.detach()} ") # Zero gradient of 2 layers. G_optim.zero_grad() D_optim.zero_grad() # Forward pass. Y = G(X) # Calculate D loss. D_loss = D(Y.detach()) ** 2 # Calculate G loss. G_loss = D(Y) ** 2 # Backward D loss. D_loss.backward(retain_graph=True) print(f"Checkpoint 1 grad: {G.weight.grad} {D.weight.grad}") print(f"Checkpoint 1 weight: {G.weight.detach()} {D.weight.detach()} ") # Backward G loss. set_requires_grad(D, False) # Turn off D's grad to avoid redundant gradient accumulation on D. G_loss.backward() set_requires_grad(D, True) print(f"Checkpoint 2 grad: {G.weight.grad} {D.weight.grad}") print(f"Checkpoint 2 weight: {G.weight.detach()} {D.weight.detach()} ") # Update G. G_optim.step() print(f"Checkpoint 3 grad: {G.weight.grad} {D.weight.grad}") print(f"Checkpoint 3 weight: {G.weight.detach()} {D.weight.detach()} ") # Update D. D_optim.step() print(f"Checkpoint 4 grad: {G.weight.grad} {D.weight.grad}") print(f"Checkpoint 4 weight: {G.weight.detach()} {D.weight.detach()} ")
运行结果:
Pytorch version: 1.9.0 Init grad: None None Init weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) Checkpoint 1 grad: None tensor([[[[0.3500]]]]) Checkpoint 1 weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) Checkpoint 2 grad: tensor([[[[0.4900]]]]) tensor([[[[0.3500]]]]) Checkpoint 2 weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) Checkpoint 3 grad: tensor([[[[0.4900]]]]) tensor([[[[0.3500]]]]) Checkpoint 3 weight: tensor([[[[0.0100]]]]) tensor([[[[0.7000]]]]) Checkpoint 4 grad: tensor([[[[0.4900]]]]) tensor([[[[0.3500]]]]) Checkpoint 4 weight: tensor([[[[0.0100]]]]) tensor([[[[0.3500]]]])
分析:
此时,$x = 1.0, y = 0.5, z = 0.7, heta_G = 0.5, heta_D = 0.7$,
首先checkpoint 1处,D loss的梯度反传到D网络上得到了 $2 y^2 cdot heta_D = 2 imes 0.25 imes 0.7 = 0.35$,没有反传到G网络。
其次checkpoint 2处,G loss的梯度反传,D网络梯度不受影响(因为所有网络参数的requires_grad := False),在G网络上得到了 $2 imes 0.5 imes 0.7^2 imes 1.0 = 0.49$。注意,这里的D网络参数 $ heta_D = 0.7$,因为尽管此时D loss已经反传,但是没有D optimizer的step()就还没有更新D网络。
最后checkpoint 3和4处,就是两个optimizer的step()分别更新G网络和D网络,这两个step()之间的先后顺序对最终的网络更新结果没什么影响。
注意,这种做法更新G网络时,对应的是更新前的D网络。
在Pytorch 1.2上的结果:
Pytorch version: 1.2.0 Init grad: None None Init weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) Checkpoint 1 grad: None tensor([[[[0.3500]]]]) Checkpoint 1 weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) Checkpoint 2 grad: tensor([[[[0.4900]]]]) tensor([[[[0.3500]]]]) Checkpoint 2 weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) Checkpoint 3 grad: tensor([[[[0.4900]]]]) tensor([[[[0.3500]]]]) Checkpoint 3 weight: tensor([[[[0.0100]]]]) tensor([[[[0.7000]]]]) Checkpoint 4 grad: tensor([[[[0.4900]]]]) tensor([[[[0.3500]]]]) Checkpoint 4 weight: tensor([[[[0.0100]]]]) tensor([[[[0.3500]]]])
可以看到,也是一样的。
错误的做法:
在 G_loss.backward() 前后不进行对D网络的网络参数的requires_grad的关和开,使得G loss反传了多余的梯度到D网络上。
第二种正确的方式:
import torch from torch import nn
print(f"Pytorch version: {torch.__version__} ") X = torch.ones(size=[1, 1, 1, 1]).requires_grad_(False) G = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0, bias=False) G.weight.data.fill_(0.5) G_optim = torch.optim.SGD(G.parameters(), lr=1.0) D = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0, bias=False) D.weight.data.fill_(0.7) D_optim = torch.optim.SGD(D.parameters(), lr=1.0) print(f"Init grad: {G.weight.grad} {D.weight.grad}") print(f"Init weight: {G.weight.detach()} {D.weight.detach()} ") # Forward pass. Y = G(X) # Zero gradient of D. D_optim.zero_grad() # Calculate D loss. D_loss = D(Y.detach()) ** 2 # Backward D loss. D_loss.backward() # Update D. D_optim.step() print(f"Checkpoint 1 grad: {G.weight.grad} {D.weight.grad}") print(f"Checkpoint 1 weight: {G.weight.detach()} {D.weight.detach()} ") # Zero gradient of G. G_optim.zero_grad() # Calculate G loss. G_loss = D(Y) ** 2 # Backward G loss. G_loss.backward() # Update G. G_optim.step() print(f"Checkpoint 2 grad: {G.weight.grad} {D.weight.grad}") print(f"Checkpoint 2 weight: {G.weight.detach()} {D.weight.detach()} ")
分析:
这种方式就很明了,更新D网络和更新G网络完全分开。
此时,$x = 1.0, y = 0.5, z = 0.7, heta_G = 0.5, heta_D = 0.7$,
首先checkpoint 1处,D loss的梯度反传到D网络上得到了 $2 y^2 cdot heta_D = 2 imes 0.25 imes 0.7 = 0.35$,没有反传到G网络。
其次checkpoint 2处,G loss的梯度同时反传到了G网络和D网络上,但是由于只更新G网络,D网络上的梯度会在下一个iteration中被zero_grad()清零。G网络上的梯度是 $2 imes 0.5 imes 0.35^2 imes 1.0 = 0.1225$,注意此时的D网络参数已经从 $0.7$ 更新为 $0.7 - 1.0 imes 0.35 = 0.35$(梯度下降:原参数减去学习率乘梯度,得新参数)。
注意,这种做法更新G网络时,对应的是已经更新后的D网络。事实上,我认为这种做法更正确,同时在逻辑上也更加清晰、更好理解。
运行结果:
Pytorch version: 1.9.0 Init grad: None None Init weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) Checkpoint 1 grad: None tensor([[[[0.3500]]]]) Checkpoint 1 weight: tensor([[[[0.5000]]]]) tensor([[[[0.3500]]]]) Checkpoint 2 grad: tensor([[[[0.1225]]]]) tensor([[[[0.5250]]]]) Checkpoint 2 weight: tensor([[[[0.3775]]]]) tensor([[[[0.3500]]]])
Pytorch version: 1.2.0 Init grad: None None Init weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) Checkpoint 1 grad: None tensor([[[[0.3500]]]]) Checkpoint 1 weight: tensor([[[[0.5000]]]]) tensor([[[[0.3500]]]]) Checkpoint 2 grad: tensor([[[[0.1225]]]]) tensor([[[[0.5250]]]]) Checkpoint 2 weight: tensor([[[[0.3775]]]]) tensor([[[[0.3500]]]])
一种错误的方式:
import torch from torch import nn print(f"Pytorch version: {torch.__version__} ") X = torch.ones(size=[1, 1, 1, 1]).requires_grad_(False) G = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0, bias=False) G.weight.data.fill_(0.5) G_optim = torch.optim.SGD(G.parameters(), lr=1.0) D = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0, bias=False) D.weight.data.fill_(0.7) D_optim = torch.optim.SGD(D.parameters(), lr=1.0) print(f"Init grad: {G.weight.grad} {D.weight.grad}") print(f"Init weight: {G.weight.detach()} {D.weight.detach()} ") # Forward pass. Y = G(X) # Zero gradient of G & D. G_optim.zero_grad() D_optim.zero_grad() # Calculate D loss. D_loss = D(Y.detach()) ** 2 # Calculate G loss. G_loss = D(Y) ** 2 # Backward D loss. D_loss.backward() # Update D. D_optim.step() print(f"Checkpoint 1 grad: {G.weight.grad} {D.weight.grad}") print(f"Checkpoint 1 weight: {G.weight.detach()} {D.weight.detach()} ") # Backward G loss. G_loss.backward() # Update G. G_optim.step() print(f"Checkpoint 2 grad: {G.weight.grad} {D.weight.grad}") print(f"Checkpoint 2 weight: {G.weight.detach()} {D.weight.detach()} ")
运行结果:
Pytorch version: 1.2.0 Init grad: None None Init weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) Checkpoint 1 grad: None tensor([[[[0.3500]]]]) Checkpoint 1 weight: tensor([[[[0.5000]]]]) tensor([[[[0.3500]]]]) Checkpoint 2 grad: tensor([[[[0.2450]]]]) tensor([[[[0.7000]]]]) Checkpoint 2 weight: tensor([[[[0.2550]]]]) tensor([[[[0.3500]]]])
Pytorch version: 1.9.0 Init grad: None None Init weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) Checkpoint 1 grad: None tensor([[[[0.3500]]]]) Checkpoint 1 weight: tensor([[[[0.5000]]]]) tensor([[[[0.3500]]]]) Traceback (most recent call last): File "D:/Program/PycharmProjects/Test/test.py", line 65, in <module> G_loss.backward() File "D:ProgramAnacondaenvspy38_torch19libsite-packages orch\_tensor.py", line 255, in backward torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs) File "D:ProgramAnacondaenvspy38_torch19libsite-packages orchautograd\__init__.py", line 147, in backward Variable._execution_engine.run_backward( RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1, 1, 1, 1]] is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
分析:
这种做法的错误核心就在于,你在计算G loss时(前向传播时)使用的是更新前的D网络,但是你在G loss反向传播时D网络已经变成了更新后的,
这个错误在较低版本(1.2.0)的Pytorch上并没有报错,我们可以看到它在计算G网络的梯度时,似乎用了更新前后的 $ heta_D$ 相乘 $2 imes 0.5 imes (0.7 imes 0.35) imes 1.0 = 0.245$,而非单纯更新前的 ${ heta_D}^2 = 0.7^2$,或者单纯更新后的 ${ heta_D}^2 = 0.35^2$.
至于在较高版本(1.9.0)的Pytorch上则直接报错了,估计是因为step()更新了D网络之后,G loss对应的计算图被破坏了,因此直接报了一个 "inplace operation" 错误。
因此,使用低版本的Pytorch时千万一定要注意这种比较隐蔽的错误写法!!!!