# Resume training from a saved checkpoint: restore the model, optimizer,
# and LR-scheduler state, then continue the epoch loop where it stopped.
checkpoint = torch.load('.pth')  # TODO: replace '.pth' with the real checkpoint path
net.load_state_dict(checkpoint['net'])

criterion_mse = torch.nn.MSELoss().to(cfg.device)
criterion_L1 = L1Loss()

# Only optimize parameters that still require gradients (frozen layers skipped).
optimizer = torch.optim.Adam(
    [p for p in net.parameters() if p.requires_grad], lr=cfg.lr)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=cfg.n_steps, gamma=cfg.gamma)

optimizer.load_state_dict(checkpoint['optimizer'])
# BUG FIX: the original did `scheduler.load_state_dict = checkpoint['lr_schedule']`,
# which overwrites the bound method with the saved dict instead of calling it —
# the scheduler state was never actually restored on resume.
scheduler.load_state_dict(checkpoint['lr_schedule'])
start_epoch = checkpoint['epoch']

for idx_epoch in range(start_epoch + 1, 80):
    # NOTE(review): the original loop unpacked `()` from train_loader and called
    # criterion_mse(,) — placeholders. Assuming (input, target) batches; confirm
    # against the dataset definition.
    for idx_iter, (inputs, targets) in enumerate(train_loader):
        outputs = net(inputs)
        loss = criterion_mse(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Since PyTorch 1.1, scheduler.step() must run AFTER the epoch's optimizer
    # steps (the original called it at the top of the epoch loop).
    scheduler.step()

    if idx_epoch % 1 == 0:  # save a checkpoint every epoch
        checkpoint = {
            "net": net.state_dict(),                # model weights
            'optimizer': optimizer.state_dict(),    # optimizer state
            "epoch": idx_epoch,                     # last completed epoch
            'lr_schedule': scheduler.state_dict(),  # how the LR evolves
        }
        torch.save(checkpoint, os.path.join(save_path, filename))
直接训练 (training from scratch):
  a: mean PSNR 28.160327919812364, mean SSIM 0.8067064184409644
  b: mean PSNR 25.01364162100755,  mean SSIM 0.7600019779915981
  c: mean PSNR 25.83471135230011,  mean SSIM 0.7774989383731079
断点续训 (resumed from checkpoint):
  a: mean PSNR 28.15391601255439,  mean SSIM 0.8062857339309237
  b: mean PSNR 25.01115760689137,  mean SSIM 0.7596963993692107
  c: mean PSNR 25.842269038618145, mean SSIM 0.7772710729947427
断点续训的效果基本和直接训练一致,但仍有些差别。一个明显的可疑点:代码中 `scheduler.load_state_dict= checkpoint['lr_schedule']` 是赋值而不是调用,学习率调度器的状态实际上并没有被恢复;此外 DataLoader 乱序和 CUDA 的非确定性也会引入微小差异。后面会继续分析。
(Resumed training matches direct training closely, with small residual differences. One clear suspect: `scheduler.load_state_dict= checkpoint['lr_schedule']` assigns instead of calls, so the LR-scheduler state is never actually restored; DataLoader shuffling and CUDA nondeterminism also contribute. To be analyzed further.)