模型使用的是苹果转橘子(apple2orange)数据集。可能由于模型较大而训练图片数量不足(1000张左右),有些图片的转换效果不是很好。
模型是在天池平台上运行的。另外还需要导入utils.py文件,其内容放在文末。
import glob import random import os import torch from torch.utils.data import Dataset from PIL import Image import utils import torchvision.transforms as transforms from torch.autograd import Variable from PIL import Image import matplotlib.pyplot as plt %matplotlib inline import torchvision.utils as vutils import numpy as np import torch.nn as nn import torch.nn.functional as F import itertools import torchvision
定义一些超参数
# ---- GPU selection ----
gpu_id = [0]
# Sets CUDA_VISIBLE_DEVICES to the ids above (see utils.cuda_devices).
# NOTE(review): this only takes effect if CUDA has not been initialised yet.
utils.cuda_devices(gpu_id)
# Device to run on: the first visible GPU, or the CPU when CUDA is unavailable.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ---- Hyper-parameters ----
epochs = 2500   # total number of training epochs
batch_size = 50
size = 64       # side length of the square training crops
lr = 0.0002     # Adam learning rate
# NOTE(review): n_critic and z_dim are not read anywhere else in this notebook;
# they look like leftovers from a WGAN script.
n_critic = 5
z_dim = 100
导入数据集
class ImageDataset(Dataset):
    """Unpaired two-domain image dataset.

    Images are read from ``<root>/<mode>A`` and ``<root>/<mode>B`` (for example
    ``trainA`` / ``trainB``). Each item is a dict ``{'A': image_A, 'B': image_B}``.

    Args:
        root: dataset root directory.
        transforms_: list of torchvision transforms applied to every image;
            ``None`` means no transform (the RGB PIL image is returned).
        unaligned: if True, the B image is drawn at random instead of by index.
        mode: split prefix, e.g. ``'train'`` or ``'test'``.
    """

    def __init__(self, root, transforms_=None, unaligned=False, mode='train'):
        # Compose(None) would fail on first use, so treat None as "no transforms".
        self.transform = transforms.Compose(list(transforms_ or []))
        self.unaligned = unaligned
        # All files under `<root>/<mode>A` and `<root>/<mode>B`, sorted so the
        # index -> file mapping is deterministic (shuffling is the loader's job).
        self.files_A = sorted(glob.glob(os.path.join(root, '%sA' % mode) + '/*.*'))
        self.files_B = sorted(glob.glob(os.path.join(root, '%sB' % mode) + '/*.*'))

    def _load(self, path):
        """Open `path`, convert it to 3-channel RGB, apply the transform, close the file."""
        # Grayscale/RGBA files would break a 3-channel Normalize; convert() also
        # forces the pixels to be read, so the file can be closed right away.
        with Image.open(path) as img:
            return self.transform(img.convert('RGB'))

    def __getitem__(self, index):
        # `__getitem__` enables dataset[index] access.
        item_A = self._load(self.files_A[index % len(self.files_A)])
        if self.unaligned:
            # Unaligned: pair A with a randomly chosen B image.
            item_B = self._load(self.files_B[random.randint(0, len(self.files_B) - 1)])
        else:
            item_B = self._load(self.files_B[index % len(self.files_B)])
        return {'A': item_A, 'B': item_B}

    def __len__(self):
        # Size of the larger domain; the smaller one wraps around via `index %`.
        return max(len(self.files_A), len(self.files_B))
# Dataset loader. Augmentation pipeline: resize so the shorter side becomes
# int(size * 1.12) (bicubic), random-crop to size x size, random left-right
# flip, convert to a float tensor in [0, 1], then normalise each of the three
# channels with mean 0.5 / std 0.5, i.e. to [-1, 1].
transforms_ = [
    transforms.Resize(int(size * 1.12), Image.BICUBIC),
    transforms.RandomCrop(size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
dataloader = torch.utils.data.DataLoader(
    ImageDataset(r'dataset/apple2orange', transforms_=transforms_, unaligned=True),
    batch_size=batch_size,
    shuffle=True,
)
# Show up to 64 images from one training batch of domain B to sanity-check the
# data pipeline. make_grid(normalize=True) min-max rescales the normalised
# tensors to [0, 1] for display. The DataLoader yields CPU tensors, so no
# device round-trip is needed before plotting.
real_batch = next(iter(dataloader))['B']
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(vutils.make_grid(real_batch[:64], padding=2, normalize=True).permute(1, 2, 0).numpy())
定义模型
class ResidualBlock(nn.Module):
    """Residual block: two reflect-padded 3x3 convs with an identity skip.

    The branch is pad -> conv -> InstanceNorm -> ReLU -> pad -> conv ->
    InstanceNorm, and its output is added to the input, so the channel count
    and spatial size are preserved.
    """

    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        layers = []
        for stage in range(2):
            layers.append(nn.ReflectionPad2d(1))
            layers.append(nn.Conv2d(in_features, in_features, 3))
            layers.append(nn.InstanceNorm2d(in_features))
            if stage == 0:
                # Only the first stage is followed by an (in-place) ReLU.
                layers.append(nn.ReLU(inplace=True))
        # The attribute name and layer order define the state_dict keys.
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual
class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout:
        * reflect-pad and a 7x7 conv to ``ngf`` channels;
        * two stride-2 3x3 convs, each doubling the channels and halving H and W;
        * ``n_residual_blocks`` ResidualBlocks;
        * two stride-2 transposed convs that undo the downsampling;
        * reflect-pad, a 7x7 conv to ``output_nc`` channels, then Tanh, so outputs
          lie in [-1, 1].

    For H and W divisible by 4 the output has the same spatial size as the input.

    Args:
        input_nc: number of input image channels.
        output_nc: number of output image channels.
        n_residual_blocks: number of residual blocks at the bottleneck.
        ngf: channels of the first conv layer. The default of 64 is the original
            hard-coded width and keeps the layer layout and state_dict keys unchanged.
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=2, ngf=64):
        super(Generator, self).__init__()

        # Initial convolution block
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, 7),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(inplace=True),
        ]

        # Downsampling: ngf -> 2*ngf -> 4*ngf channels, H and W halved each step
        in_features = ngf
        out_features = in_features * 2
        for _ in range(2):
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
            out_features = in_features * 2

        # Residual blocks
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Upsampling: 4*ngf -> 2*ngf -> ngf channels, H and W doubled each step
        out_features = in_features // 2
        for _ in range(2):
            model += [
                nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
            out_features = in_features // 2

        # Output layer. Use the tracked width rather than a hard-coded 64 so it
        # stays consistent with the layers above for any ngf.
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_features, output_nc, 7),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)
class Discriminator(nn.Module):
    """Fully convolutional discriminator that returns one score per image.

    Four 4x4 conv stages (64-128-256-512 channels; the first has no
    normalisation) feed a final 4x4 conv to a 1-channel score map. That map is
    averaged over its spatial dimensions, giving an output of shape (N, 1).
    """

    def __init__(self, input_nc):
        super(Discriminator, self).__init__()
        # First stage: conv + LeakyReLU, no normalisation.
        layers = [nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
                  nn.LeakyReLU(0.2, inplace=True)]
        # (in_channels, out_channels, stride) of each conv -> InstanceNorm -> LeakyReLU stage.
        for c_in, c_out, step in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            layers.extend([
                nn.Conv2d(c_in, c_out, 4, stride=step, padding=1),
                nn.InstanceNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # Final conv down to a single-channel score map.
        layers.append(nn.Conv2d(512, 1, 4, padding=1))
        # The attribute name and layer order define the state_dict keys.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.model(x)
        # Global average pooling over the spatial dims, then flatten to (N, 1).
        pooled = F.avg_pool2d(scores, scores.shape[2:])
        return pooled.view(scores.shape[0], -1)
实例化模型
# Two generators (A->B and B->A) and one discriminator per domain, all for
# 3-channel images. They are constructed in this order, which fixes how the
# global RNG is consumed for weight initialisation.
netG_A2B, netG_B2A = Generator(3, 3), Generator(3, 3)
netD_A, netD_B = Discriminator(3), Discriminator(3)

# Least-squares (MSE) adversarial loss plus L1 losses for cycle consistency
# and identity mapping.
criterion_GAN, criterion_cycle, criterion_identity = (
    torch.nn.MSELoss(),
    torch.nn.L1Loss(),
    torch.nn.L1Loss(),
)

# utils.cuda calls .cuda() on every object. nn.Module.cuda() moves a module in
# place, so the returned list is not needed.
utils.cuda([
    netG_A2B, netG_B2A, netD_A, netD_B,
    criterion_GAN, criterion_cycle, criterion_identity,
])
# Optimizers: one Adam instance updates both generators jointly (their
# parameter iterators are chained), plus one Adam per discriminator. All of
# them use the same learning rate and betas=(0.5, 0.999).
optimizer_G = torch.optim.Adam(
    itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
    lr=lr,
    betas=(0.5, 0.999),
)
optimizer_D_A, optimizer_D_B = (
    torch.optim.Adam(net.parameters(), lr=lr, betas=(0.5, 0.999))
    for net in (netD_A, netD_B)
)
每次训练中断后,都需要重新运行下面这段代码,以加载最新保存的模型作为本次运行的起点。
# Resume training from the newest checkpoint in ckpt_dir, if one exists.
# NOTE(review): the directory name mentions "celeba" although the data is
# apple2orange; it is left unchanged so existing checkpoints are still found.
ckpt_dir = './checkpoints1/celeba_cyclegan'
utils.mkdir(ckpt_dir)
try:
    # map_location=device lets a checkpoint written on a GPU be restored on a
    # CPU-only machine (and vice versa).
    ckpt = utils.load_checkpoint(ckpt_dir, map_location=device)
except OSError:
    # There is no 'latest_checkpoint' index yet (or the file it names is
    # missing), so start from scratch. Other errors, such as a corrupt file or
    # mismatched state_dict keys/shapes, propagate instead of silently
    # restarting training on partially loaded weights.
    print(' [*] No checkpoint!')
    start_epoch = 0
else:
    start_epoch = ckpt['epoch']
    netD_A.load_state_dict(ckpt['netD_A'])
    netD_B.load_state_dict(ckpt['netD_B'])
    netG_A2B.load_state_dict(ckpt['netG_A2B'])
    netG_B2A.load_state_dict(ckpt['netG_B2A'])
    optimizer_G.load_state_dict(ckpt['optimizer_G'])
    optimizer_D_A.load_state_dict(ckpt['optimizer_D_A'])
    optimizer_D_B.load_state_dict(ckpt['optimizer_D_B'])
class ReplayBuffer:
    """Pool of previously generated images that are fed to a discriminator.

    Until the pool holds ``max_size`` images, every incoming image is stored and
    also returned unchanged. Once the pool is full, each incoming image is, with
    probability about 1/2, swapped with a randomly chosen stored image (the
    stored one is returned); otherwise it is returned as-is.
    """

    def __init__(self, max_size=50):
        # Raise instead of assert: asserts are stripped under `python -O`.
        if max_size <= 0:
            raise ValueError('Empty buffer or trying to create a black hole. Be careful.')
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Push a batch of images and return a batch of the same size.

        Args:
            data: a batch tensor whose first dimension indexes images.

        Returns:
            A tensor of the same shape, detached from the autograd graph, mixing
            the new images with images drawn from the pool.
        """
        to_return = []
        # detach() replaces the legacy `.data` access: no gradients flow into the pool.
        for element in data.detach():
            element = torch.unsqueeze(element, 0)  # restore the batch dimension
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            elif random.uniform(0, 1) > 0.5:
                i = random.randint(0, self.max_size - 1)
                to_return.append(self.data[i].clone())
                self.data[i] = element
            else:
                to_return.append(element)
        # The deprecated Variable wrapper is unnecessary: tensors carry autograd state.
        return torch.cat(to_return)
定义一些用到的变量
# Inputs & targets memory allocation.
# Fall back to CPU tensors when CUDA is unavailable instead of failing here.
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
# Pre-allocated (batch_size, 3, size, size) input buffers, refilled with copy_
# in every training iteration.
input_A = Tensor(batch_size, 3, size, size)
input_B = Tensor(batch_size, 3, size, size)
# Constant targets for the MSE adversarial loss. They are shaped
# (batch_size, 1) to match the discriminator output (N, 1). With 1-D targets,
# MSELoss broadcasts to (N, N) and emits a UserWarning.
target_real = Tensor(batch_size, 1).fill_(1.0)
target_fake = Tensor(batch_size, 1).fill_(0.0)
# Pools of previously generated images that are fed to the discriminators.
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()
A = []  # make_grid images of generated A-domain samples (netG_B2A outputs), for display
B = []  # make_grid images of generated B-domain samples (netG_A2B outputs), for display

for epoch in range(start_epoch, epochs):
    for i, batch in enumerate(dataloader):
        # input_A / input_B are pre-allocated with exactly `batch_size` rows, so
        # a shorter final batch cannot be copied into them. Skip only that case;
        # a full last batch is still trained on.
        if batch['A'].size(0) != batch_size:
            continue

        # Set model input (N, 3, H, W); copy_ reuses the pre-allocated buffers.
        real_A = input_A.copy_(batch['A'])
        real_B = input_B.copy_(batch['B'])
        real_A, real_B, target_real, target_fake = utils.cuda(
            [real_A, real_B, target_real, target_fake])

        # -------- Generators A2B and B2A --------
        optimizer_G.zero_grad()

        # Identity loss: G_A2B(B) should equal B, G_B2A(A) should equal A.
        same_B = netG_A2B(real_B)
        loss_identity_B = criterion_identity(same_B, real_B) * 5.0
        same_A = netG_B2A(real_A)
        loss_identity_A = criterion_identity(same_A, real_A) * 5.0

        # GAN loss: each generator tries to make the other domain's
        # discriminator output "real". The discriminator output is (N, 1); if
        # the targets are 1-D, MSELoss broadcasts and warns, but the value is
        # unchanged because the targets are constant.
        fake_B = netG_A2B(real_A)
        pred_fake = netD_B(fake_B)
        loss_GAN_A2B = criterion_GAN(pred_fake, target_real)
        fake_A = netG_B2A(real_B)
        pred_fake = netD_A(fake_A)
        loss_GAN_B2A = criterion_GAN(pred_fake, target_real)

        # Cycle loss: A -> B -> A and B -> A -> B should reconstruct the input.
        recovered_A = netG_B2A(fake_B)
        loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * 10.0
        recovered_B = netG_A2B(fake_A)
        loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * 10.0

        # Total loss
        loss_G = (loss_identity_A + loss_identity_B
                  + loss_GAN_A2B + loss_GAN_B2A
                  + loss_cycle_ABA + loss_cycle_BAB)
        loss_G.backward()
        optimizer_G.step()

        # -------- Discriminator A --------
        optimizer_D_A.zero_grad()
        pred_real = netD_A(real_A)
        loss_D_real = criterion_GAN(pred_real, target_real)
        # The discriminator sees a mix of current and earlier fakes from the pool.
        # The pool output gets its own name so that fake_A keeps the current
        # translation for the sample images below.
        fake_A_pool = utils.cuda(fake_A_buffer.push_and_pop(fake_A))
        pred_fake = netD_A(fake_A_pool.detach())
        loss_D_fake = criterion_GAN(pred_fake, target_fake)
        loss_D_A = (loss_D_real + loss_D_fake) * 0.5
        loss_D_A.backward()
        optimizer_D_A.step()

        # -------- Discriminator B --------
        optimizer_D_B.zero_grad()
        pred_real = netD_B(real_B)
        loss_D_real = criterion_GAN(pred_real, target_real)
        fake_B_pool = utils.cuda(fake_B_buffer.push_and_pop(fake_B))
        pred_fake = netD_B(fake_B_pool.detach())
        loss_D_fake = criterion_GAN(pred_fake, target_fake)
        loss_D_B = (loss_D_real + loss_D_fake) * 0.5
        loss_D_B.backward()
        optimizer_D_B.step()

        # -------- Logging / samples --------
        if (i + 1) % 15 == 0:
            print("Epoch: (%3d) (%5d/%5d)" % (epoch, i + 1, len(dataloader)))
            # Training runs for thousands of epochs, so samples are saved only
            # every 5th epoch.
            if (epoch + 1) % 5 == 0:
                save_dir = './sample_images_while_training/cycleGAN'
                utils.mkdir(save_dir)
                with torch.no_grad():
                    # The generators end in Tanh, so outputs lie in [-1, 1].
                    # save_image expects [0, 1] and clamps negatives to 0, so
                    # rescale first.
                    torchvision.utils.save_image(
                        fake_A.detach() * 0.5 + 0.5,
                        '%s/Epoch_(%d)_(%dof%d)_fake_A.jpg' % (save_dir, epoch, i + 1, len(dataloader)),
                        nrow=10)
                    torchvision.utils.save_image(
                        fake_B.detach() * 0.5 + 0.5,
                        '%s/Epoch_(%d)_(%dof%d)_fake_B.jpg' % (save_dir, epoch, i + 1, len(dataloader)),
                        nrow=10)
                    A.append(vutils.make_grid(fake_A.detach().cpu(), padding=2, normalize=True))
                    B.append(vutils.make_grid(fake_B.detach().cpu(), padding=2, normalize=True))

    # End of epoch: save a checkpoint and keep only the two newest.
    utils.save_checkpoint({'epoch': epoch + 1,
                           'netD_A': netD_A.state_dict(),
                           'netD_B': netD_B.state_dict(),
                           'netG_A2B': netG_A2B.state_dict(),
                           'netG_B2A': netG_B2A.state_dict(),
                           'optimizer_G': optimizer_G.state_dict(),
                           'optimizer_D_A': optimizer_D_A.state_dict(),
                           'optimizer_D_B': optimizer_D_B.state_dict()},
                          '%s/Epoch_(%d).ckpt' % (ckpt_dir, epoch + 1),
                          max_keep=2)
显示训练过程中生成的图片
# Show the most recently collected grids of generated images. A holds
# netG_B2A outputs and B holds netG_A2B outputs; both are appended during
# training, so [-1] is the latest one.
if A and B:
    plt.figure(figsize=(15, 15))
    # Generated A-domain images
    plt.subplot(1, 2, 1)
    plt.axis("off")
    plt.title("A")
    plt.imshow(A[-1].permute(1, 2, 0).numpy())
    # Generated B-domain images
    plt.subplot(1, 2, 2)
    plt.axis("off")
    plt.title("B")
    plt.imshow(B[-1].permute(1, 2, 0).numpy())
    plt.show()
else:
    print(' [*] No generated images have been collected yet.')
utils.py文件,其中定义了将数据转移到cuda、保存模型(checkpoint)、加载模型等函数
import os
import shutil

import torch


def mkdir(paths):
    """Create each directory in `paths` (one path or a list/tuple of paths), including parents.

    Existing directories are left alone.
    """
    if not isinstance(paths, (list, tuple)):
        paths = [paths]
    for path in paths:
        os.makedirs(path, exist_ok=True)


def cuda_devices(gpu_ids):
    """Restrict the visible GPUs by setting CUDA_VISIBLE_DEVICES to `gpu_ids`.

    NOTE: this only has an effect if CUDA has not been initialised yet in this process.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(i) for i in gpu_ids)


def cuda(xs):
    """Move a tensor/module, or each item of a list/tuple, to the current CUDA device.

    Without CUDA the input is returned unchanged (a list/tuple comes back as a
    list), so callers can also unpack the result on CPU-only machines.
    """
    available = torch.cuda.is_available()
    if not isinstance(xs, (list, tuple)):
        return xs.cuda() if available else xs
    return [x.cuda() for x in xs] if available else list(xs)


def save_checkpoint(state, save_path, is_best=False, max_keep=None):
    """Save `state` to `save_path` and update the checkpoint index.

    The index file ``latest_checkpoint`` in the same directory lists checkpoint
    file names, newest first, one per line.

    Args:
        state: object passed to torch.save.
        save_path: destination file path.
        is_best: additionally copy the checkpoint to ``best_model.ckpt``.
        max_keep: if given, delete indexed checkpoints beyond the newest `max_keep`.
    """
    torch.save(state, save_path)

    save_dir = os.path.dirname(save_path)
    list_path = os.path.join(save_dir, 'latest_checkpoint')
    ckpt_name = os.path.basename(save_path)

    previous = []
    if os.path.exists(list_path):
        with open(list_path) as f:
            previous = [line.strip() for line in f if line.strip()]
    # Newest first. Drop an older entry with the same name so the file that was
    # just written can never be deleted as "stale".
    ckpt_list = [ckpt_name] + [name for name in previous if name != ckpt_name]

    if max_keep is not None:
        for stale in ckpt_list[max_keep:]:
            stale_path = os.path.join(save_dir, stale)
            if os.path.exists(stale_path):
                os.remove(stale_path)
        del ckpt_list[max_keep:]

    # One name per line: load_checkpoint reads the first line as the newest checkpoint.
    with open(list_path, 'w') as f:
        f.writelines(name + '\n' for name in ckpt_list)

    if is_best:
        # Copy from the full path (not the bare file name) so this works from any cwd.
        shutil.copyfile(save_path, os.path.join(save_dir, 'best_model.ckpt'))


def load_checkpoint(ckpt_dir_or_file, map_location=None, load_best=False):
    """Load a checkpoint with torch.load.

    Args:
        ckpt_dir_or_file: a checkpoint file, or a directory written by save_checkpoint.
            For a directory, the newest indexed checkpoint is loaded, or
            ``best_model.ckpt`` when `load_best` is True.
        map_location: forwarded to torch.load.
        load_best: see above.

    Raises:
        OSError: if the index file or the checkpoint file does not exist.
    """
    if os.path.isdir(ckpt_dir_or_file):
        if load_best:
            ckpt_path = os.path.join(ckpt_dir_or_file, 'best_model.ckpt')
        else:
            with open(os.path.join(ckpt_dir_or_file, 'latest_checkpoint')) as f:
                # strip() instead of [:-1]: robust to a missing trailing newline.
                ckpt_path = os.path.join(ckpt_dir_or_file, f.readline().strip())
    else:
        ckpt_path = ckpt_dir_or_file
    ckpt = torch.load(ckpt_path, map_location=map_location)
    print(' [*] Loading checkpoint from %s succeed!' % ckpt_path)
    return ckpt