1. Normally, one forward pass corresponds to one parameter update. The code is:
import torch.nn.functional as F

optimizer.zero_grad()          # clear any leftover gradients
y = model(x)                   # forward pass
loss_mse = F.mse_loss(y, x)    # MSE reconstruction loss against the input
loss_mse.backward()            # compute gradients
optimizer.step()               # apply one update
In fact, this only takes three extra lines of code.
2. When you want to effectively enlarge the batch size (multiple forward passes per update), the code is:
if i_ite == 1:                       # zero gradients only on the very first iteration
    optimizer.zero_grad()
y = model(x)
loss_mse = F.mse_loss(y, x)          # these lines are unchanged; we effectively just add the 3 lines of if/step/zero_grad
loss_mse.backward()                  # gradients accumulate in .grad across iterations
if i_ite % batchsize_loss == 0:      # every batchsize_loss iterations...
    optimizer.step()                 # ...apply one accumulated update
    optimizer.zero_grad()            # ...and reset gradients for the next window
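For completeness, here is a minimal self-contained sketch of the full accumulation loop (the nn.Linear model, the random data, and batchsize_loss = 4 are illustrative placeholders, not from the original). Since .backward() sums gradients, the loss is divided by batchsize_loss so each update matches the mean-loss gradient of one large batch; this scaling is a common convention and was not in the snippet above.

import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(10, 10)                          # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
batchsize_loss = 4                                 # accumulate over 4 mini-batches

optimizer.zero_grad()                              # zero once before the loop
for i_ite in range(1, 101):                        # 100 mini-batches, 1-indexed as above
    x = torch.randn(8, 10)                         # placeholder mini-batch
    y = model(x)
    loss_mse = F.mse_loss(y, x) / batchsize_loss   # scale so summed grads match the large-batch mean
    loss_mse.backward()                            # gradients accumulate in .grad
    if i_ite % batchsize_loss == 0:
        optimizer.step()                           # one update per batchsize_loss forwards
        optimizer.zero_grad()                      # reset for the next accumulation window

Here zeroing the gradients once before the loop plays the role of the `if i_ite == 1` check, and with 8 samples per mini-batch each optimizer.step() sees an effective batch of 8 * batchsize_loss = 32.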