[Source Code Walkthrough] CycleGAN (Part 2): Training

    Source code: https://github.com/aitorzip/PyTorch-CycleGAN

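    The training code is in train.py. It first defines the networks, two generators (A2B and B2A) and two discriminators (A and B), along with their optimizers. Each optimizer holds only the parameters of the networks it is meant to train. optimizer_G covers both generators, while optimizer_D_A and optimizer_D_B each cover a single discriminator, so a step() call on one optimizer never changes the weights of the others. Note that loss_G.backward() still writes gradients into the discriminators' parameters, because the fake images pass through netD_A and netD_B. optimizer_G.step() ignores those gradients, and optimizer_D_A.zero_grad() / optimizer_D_B.zero_grad() clear them before the discriminators are updated.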

    import itertools
    import torch
    from models import Generator, Discriminator  # network definitions in the repo's models.py

    ###### Definition of variables ######
    # Networks (opt holds the command-line options parsed with argparse in train.py)
    netG_A2B = Generator(opt.input_nc, opt.output_nc)
    netG_B2A = Generator(opt.output_nc, opt.input_nc)
    netD_A = Discriminator(opt.input_nc)
    netD_B = Discriminator(opt.output_nc)
    # Optimizers & LR schedulers
    # One Adam optimizer covers both generators; each discriminator gets its own
    optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                                    lr=opt.lr, betas=(0.5, 0.999))
    optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(0.5, 0.999))
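    The comment "# Optimizers & LR schedulers" mentions learning-rate schedulers, which this excerpt does not show. As I recall, train.py wraps each optimizer in torch.optim.lr_scheduler.LambdaLR. The multiplier comes from a small helper in the repo's utils.py that keeps the learning rate constant for the first half of training and then decays it linearly to zero. The stdlib-only sketch below reproduces that multiplier. The function name and the defaults (200 epochs, decay from epoch 100, base lr 0.0002) are assumptions to check against your copy of the repo.

```python
def lr_multiplier(epoch, n_epochs=200, offset=0, decay_start_epoch=100):
    """Factor applied to the base learning rate (opt.lr) at a given epoch.

    Assumed schedule: constant 1.0 until decay_start_epoch, then a linear
    decay that reaches 0 at n_epochs. offset shifts the count for a run
    resumed from a later epoch.
    """
    if n_epochs <= decay_start_epoch:
        raise ValueError("decay must start before training ends")
    elapsed = max(0, epoch + offset - decay_start_epoch)
    return 1.0 - elapsed / (n_epochs - decay_start_epoch)


if __name__ == "__main__":
    base_lr = 0.0002  # assumed default for opt.lr
    for epoch in (0, 100, 150, 199):
        print(epoch, base_lr * lr_multiplier(epoch))
```

    torch.optim.lr_scheduler.LambdaLR multiplies each parameter group's initial lr by the value lr_lambda(epoch) returns. A function like this can therefore be passed as lr_lambda, with the scheduler's step() called once per epoch.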

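    Next comes the data. With unaligned=True, ImageDataset pairs each A image with a randomly chosen B image, so the two training sets do not need to be paired.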

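    from PIL import Image
    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader
    from datasets import ImageDataset  # dataset class in the repo's datasets.py

    # Dataset loader
    # Upscale by 1.12x, randomly crop back to opt.size, and randomly flip (data augmentation).
    # Newer torchvision versions prefer transforms.InterpolationMode.BICUBIC over Image.BICUBIC.
    transforms_ = [ transforms.Resize(int(opt.size*1.12), Image.BICUBIC),
                    transforms.RandomCrop(opt.size),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    # Map pixel values from [0, 1] to [-1, 1]
                    transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]
    dataloader = DataLoader(ImageDataset(opt.dataroot, transforms_=transforms_, unaligned=True),
                            batch_size=opt.batchSize, shuffle=True, num_workers=opt.n_cpu)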
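    The stdlib-only sketch below makes the preprocessing numbers concrete. It is not the repo's code. It covers three things the loader does: it computes the resize target from opt.size, it applies Normalize with mean 0.5 and std 0.5, and it samples unaligned pairs, drawing each B image at random instead of by the same index as A. The helper names and the example size of 256 are illustrative assumptions. The unaligned logic follows my reading of the repo's datasets.py rather than quoting it.

```python
import random


def resize_target(size):
    """Shorter-side length passed to transforms.Resize before RandomCrop(size)."""
    return int(size * 1.12)


def normalize(value, mean=0.5, std=0.5):
    """Per-channel Normalize: (x - mean) / std, mapping [0, 1] to [-1, 1]."""
    return (value - mean) / std


def denormalize(value, mean=0.5, std=0.5):
    """Inverse of normalize, e.g. to turn a generator output back into [0, 1]."""
    return value * std + mean


def unaligned_pair(files_A, files_B, index, rng=random):
    """Hypothetical helper: A is chosen by index, B uniformly at random (unpaired data)."""
    item_A = files_A[index % len(files_A)]
    item_B = files_B[rng.randrange(len(files_B))]
    return item_A, item_B


if __name__ == "__main__":
    print(resize_target(256))                       # Resize target when opt.size = 256
    print([normalize(v) for v in (0.0, 0.5, 1.0)])  # pixel range after Normalize
    rng = random.Random(0)
    print(unaligned_pair(["a0.jpg", "a1.jpg"], ["b0.jpg", "b1.jpg", "b2.jpg"], 1, rng))
```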

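    With the networks and data in place, each training iteration computes the losses, backpropagates the gradients and updates the networks: first the two generators, then each of the two discriminators.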

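    Generator loss = identity loss + adversarial (GAN) loss + cycle-consistency loss. The identity terms are weighted by 5.0 and the cycle terms by 10.0. In train.py, criterion_GAN is torch.nn.MSELoss (a least-squares GAN objective), criterion_identity and criterion_cycle are torch.nn.L1Loss, and target_real / target_fake are tensors filled with 1.0 and 0.0.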

    ###### Generators A2B and B2A ######
    optimizer_G.zero_grad()

    # Identity loss
    # G_A2B(B) should equal B if real B is fed
    same_B = netG_A2B(real_B)
    loss_identity_B = criterion_identity(same_B, real_B)*5.0
    # G_B2A(A) should equal A if real A is fed
    same_A = netG_B2A(real_A)
    loss_identity_A = criterion_identity(same_A, real_A)*5.0

    # GAN loss
    fake_B = netG_A2B(real_A)
    pred_fake = netD_B(fake_B)
    loss_GAN_A2B = criterion_GAN(pred_fake, target_real)

    fake_A = netG_B2A(real_B)
    pred_fake = netD_A(fake_A)
    loss_GAN_B2A = criterion_GAN(pred_fake, target_real)

    # Cycle loss
    recovered_A = netG_B2A(fake_B)
    loss_cycle_ABA = criterion_cycle(recovered_A, real_A)*10.0

    recovered_B = netG_A2B(fake_A)
    loss_cycle_BAB = criterion_cycle(recovered_B, real_B)*10.0

    # Total loss
    loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
    loss_G.backward()

    optimizer_G.step()
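    The weighting inside loss_G is easy to lose in the code. The stdlib-only sketch below recomputes it on tiny hand-made lists. It assumes, as in train.py, that criterion_GAN is a mean squared error against a target of 1.0 and that criterion_identity and criterion_cycle are mean absolute (L1) errors. The flat lists stand in for image tensors and discriminator scores, and the function names are illustrative.

```python
def l1(a, b):
    """Mean absolute error, like torch.nn.L1Loss with its default 'mean' reduction."""
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def mse_to(pred, target):
    """Mean squared error against a constant target (1.0 = real, 0.0 = fake)."""
    return sum((p - target) ** 2 for p in pred) / len(pred)


def generator_loss(real_A, real_B, same_A, same_B, pred_fake_A, pred_fake_B,
                   recovered_A, recovered_B, lambda_identity=5.0, lambda_cycle=10.0):
    """Recompute loss_G on flat lists.

    same_X      : output of the generator that maps into domain X, fed a real X image
    pred_fake_X : scores of discriminator D_X for the generated fake_X
    recovered_X : X -> other domain -> X reconstruction
    """
    loss_identity = lambda_identity * (l1(same_A, real_A) + l1(same_B, real_B))
    # The generators want their fakes to be scored as real (target 1.0).
    loss_gan = mse_to(pred_fake_B, 1.0) + mse_to(pred_fake_A, 1.0)
    loss_cycle = lambda_cycle * (l1(recovered_A, real_A) + l1(recovered_B, real_B))
    return loss_identity + loss_gan + loss_cycle


if __name__ == "__main__":
    print(generator_loss(real_A=[0.0, 0.0], real_B=[0.5, 0.5],
                         same_A=[0.1, -0.1], same_B=[0.5, 0.5],
                         pred_fake_A=[1.0], pred_fake_B=[0.5],
                         recovered_A=[0.2, 0.0], recovered_B=[0.5, 0.5]))
```

    In this toy call, the identity, GAN and cycle parts contribute 0.5, 0.25 and 1.0. Even a small reconstruction error therefore dominates once it is multiplied by 10.0.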

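    Discriminator A: loss = (loss on real samples + loss on fake samples) × 0.5. Real A images should be classified as real (target_real) and generated images as fake (target_fake). The fakes are taken from fake_A_buffer, which holds previously generated images.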

    ###### Discriminator A ######
    optimizer_D_A.zero_grad()

    # Real loss
    pred_real = netD_A(real_A)
    loss_D_real = criterion_GAN(pred_real, target_real)

    # Fake loss
    fake_A = fake_A_buffer.push_and_pop(fake_A)
    pred_fake = netD_A(fake_A.detach())
    loss_D_fake = criterion_GAN(pred_fake, target_fake)

    # Total loss
    loss_D_A = (loss_D_real + loss_D_fake)*0.5
    loss_D_A.backward()

    optimizer_D_A.step()
    ###################################
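    fake_A_buffer.push_and_pop(fake_A) feeds the discriminator a mix of current and previously generated images. This image history buffer keeps the discriminator from seeing only the latest generator output. Below is a plain-Python sketch of the idea that stores list items instead of tensors. The class name is hypothetical. The default capacity of 50 and the 0.5 swap probability follow my reading of the repo's utils.ReplayBuffer and should be treated as assumptions.

```python
import random


class ReplayBufferSketch:
    """History of generated samples, sometimes returned in place of fresh ones."""

    def __init__(self, max_size=50, rng=None):
        if max_size <= 0:
            raise ValueError("max_size must be positive")
        self.max_size = max_size
        self.data = []
        self.rng = rng if rng is not None else random.Random()

    def push_and_pop(self, batch):
        """Offer each new sample to the buffer; return a batch of the same length."""
        to_return = []
        for element in batch:
            if len(self.data) < self.max_size:
                # Buffer not full yet: store the new sample and also return it.
                self.data.append(element)
                to_return.append(element)
            elif self.rng.random() > 0.5:
                # Return a random old sample and put the new one in its slot.
                i = self.rng.randrange(self.max_size)
                to_return.append(self.data[i])
                self.data[i] = element
            else:
                # Return the new sample without storing it.
                to_return.append(element)
        return to_return
```

    Once the buffer is full, every returned sample is either a fresh fake or an older one swapped out of the buffer. The discriminator therefore keeps being trained against earlier generator outputs as well.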

    Discriminator B: the same structure, using real B images and generated fake_B samples taken from fake_B_buffer.

    ###### Discriminator B ######
    optimizer_D_B.zero_grad()

    # Real loss
    pred_real = netD_B(real_B)
    loss_D_real = criterion_GAN(pred_real, target_real)

    # Fake loss
    fake_B = fake_B_buffer.push_and_pop(fake_B)
    pred_fake = netD_B(fake_B.detach())
    loss_D_fake = criterion_GAN(pred_fake, target_fake)

    # Total loss
    loss_D_B = (loss_D_real + loss_D_fake)*0.5
    loss_D_B.backward()

    optimizer_D_B.step()
    ###################################
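    Both discriminator updates above use the same objective. Because criterion_GAN is a mean squared error in train.py, each discriminator is pushed towards 1.0 on real images and 0.0 on fakes, and the two terms are averaged. The stdlib-only sketch below computes this on made-up scores. The function names are illustrative.

```python
def mse_to(pred, target):
    """Mean squared error against a constant target (1.0 = real, 0.0 = fake)."""
    return sum((p - target) ** 2 for p in pred) / len(pred)


def discriminator_loss(pred_real, pred_fake):
    """Least-squares discriminator loss, averaged as in loss_D_A / loss_D_B."""
    loss_real = mse_to(pred_real, 1.0)
    loss_fake = mse_to(pred_fake, 0.0)
    return (loss_real + loss_fake) * 0.5


if __name__ == "__main__":
    print(discriminator_loss([1.0, 1.0], [0.0, 0.0]))  # perfect discriminator
    print(discriminator_loss([0.5, 0.5], [0.5, 0.5]))  # undecided discriminator
```

    The CycleGAN paper notes that dividing the discriminator objective by 2 slows down how fast D learns relative to G. That is the role of the 0.5 factor here.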

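    Note that in both discriminator losses the fake samples fake_A and fake_B are passed through detach(), which cuts them out of the computation graph. Backpropagating the discriminator loss therefore stops at the discriminator instead of continuing into the generators that produced the fakes. No gradients are computed for the generators' parameters in this step, which avoids unnecessary computation.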

    Original article: https://www.cnblogs.com/wzyuan/p/11893348.html