zoukankan      html  css  js  c++  java
  • 《深度学习框架PyTorch入门与实践》示例——AI插画师:生成对抗网络

    运行环境:Anaconda3 + PyCharm + PyTorch + python3

    这是《深度学习框架PyTorch入门与实践》的第七章示例,利用生成对抗网络生成动漫人物头像。

    作者陈云的GitHub:https://github.com/chenyuntc/pytorch-book/tree/master/chapter07-AnimeGAN 这里有实现该示例的代码。

    原书给的动漫人物头像数据集百度网盘链接https://pan.baidu.com/s/1eSifHcA 提取码:g5qa 失效了,不知道后续会不会补资源,先贴在这里。

    这篇博客讲解详细,并且有可实现的代码及数据集百度网盘资源,建议参考:https://www.cnblogs.com/wanghui-garcia/p/10785579.html

    如何安装visdom:https://zhuanlan.zhihu.com/p/138534069

    本篇作为学习笔记,在代码中给一些注释。

    model.py

    定义生成器:

    
    
    class NetG(nn.Module):
    """
    生成器定义
    """

    def __init__(self, opt):
    super(NetG, self).__init__()
    ngf = opt.ngf # 生成器feature map数

    self.main = nn.Sequential(
    # 输入是一个nz维度的噪声,我们可以认为它是一个1*1*nz的feature map
    nn.ConvTranspose2d(opt.nz, ngf * 8, 4, 1, 0, bias=False),
    # torch.nn.ConvTranspose2d(in_channels: int,
    # out_channels: int,
    # kernel_size: Union[T, Tuple[T, T]],
    # stride: Union[T, Tuple[T, T]] = 1,
    # padding: Union[T, Tuple[T, T]] = 0,
    # output_padding: Union[T, Tuple[T, T]] = 0,
    # groups: int = 1, bias: bool = True,
    # dilation: int = 1,
    # padding_mode: str = 'zeros')
    nn.BatchNorm2d(ngf * 8),
    nn.ReLU(True),
    # torch.nn.ReLU(inplace: bool = False)
    # ReLU输出为max(0, x)
    # 输出形状:(ngf * 8) * 4 * 4
    # 上一步的输出形状:(ngf*8) x 4 x 4

    nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
    nn.BatchNorm2d(ngf * 4),
    nn.ReLU(True),
    # 上一步的输出形状: (ngf*4) x 8 x 8

    nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
    nn.BatchNorm2d(ngf * 2),
    nn.ReLU(True),
    # 上一步的输出形状: (ngf*2) x 16 x 16

    nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
    nn.BatchNorm2d(ngf),
    nn.ReLU(True),
    # 上一步的输出形状:(ngf) x 32 x 32

    nn.ConvTranspose2d(ngf, 3, 5, 3, 1, bias=False),
    nn.Tanh() # 输出范围 -1~1 故而采用Tanh
    # 输出形状:3 x 96 x 96
    )

    def forward(self, input):
    return self.main(input)

    定义判别器:

    class NetD(nn.Module):
        """
        判别器定义
        """
    
        def __init__(self, opt):
            super(NetD, self).__init__()
            ndf = opt.ndf
            self.main = nn.Sequential(
                # 输入 3 x 96 x 96
                nn.Conv2d(3, ndf, 5, 3, 1, bias=False),
                nn.LeakyReLU(0.2, inplace=True),
                # torch.nn.LeakyReLU(negative_slope: float = 0.01, inplace: bool = False)
                # LeakyReLU(x)=max(0,x)+negative_slope∗min(0,x)
                # 与ReLU相比,LeakyReLU的输出中小于0的部分为趋近于0的负数
                # 输出 (ndf) x 32 x 32
    
                nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 2),
                nn.LeakyReLU(0.2, inplace=True),
                # 输出 (ndf*2) x 16 x 16
    
                nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 4),
                nn.LeakyReLU(0.2, inplace=True),
                # 输出 (ndf*4) x 8 x 8
    
                nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 8),
                nn.LeakyReLU(0.2, inplace=True),
                # 输出 (ndf*8) x 4 x 4
    
                nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
                nn.Sigmoid()  # 输出一个数(概率)
            )
    
        def forward(self, input):
            return self.main(input).view(-1)

    main.py

    模型配置参数:

    # coding:utf8
    import os
    import ipdb
    import torch as t
    import torchvision as tv
    import tqdm
    from model import NetG, NetD
    from torchnet.meter import AverageValueMeter
    
    
    class Config(object):
        # data_path = 'data/'  # 数据集存放路径
        data_path = 'E:/pycharm projects/book1/chapter7/data/'
        # 存放图片的文件夹。注意,需要在data中建立一个文件夹,保存图片,才能使用dataloader
        num_workers = 4  # 多进程加载数据所用的进程数
        image_size = 96  # 图片尺寸
        batch_size = 256
        max_epoch = 1
        lr1 = 2e-4  # 生成器的学习率
        lr2 = 2e-4  # 判别器的学习率
        beta1 = 0.5  # Adam优化器的beta1参数
        gpu = True  # 是否使用GPU
        nz = 100  # 噪声维度
        ngf = 64  # 生成器feature map数
        ndf = 64  # 判别器feature map数
    
        # save_path = 'imgs/'  # 生成图片保存路径
        save_path = 'E:/pycharm projects/book1/chapter7/images'
    
        vis = True  # 是否使用visdom可视化
        env = 'GAN'  # visdom的env,在窗口工具栏中中选择该环境,显示训练图片结果
        plot_every = 20  # 每间隔20 batch,visdom画图一次
    
        debug_file = 'debug/debug.txt'  # 存在该文件则进入debug模式
        d_every = 1  # 每1个batch训练一次判别器
        g_every = 5  # 每5个batch训练一次生成器
        save_every = 10  # 每10个epoch保存一次模型
        netd_path = None  # 'checkpoints/netd_.pth' #预训练模型
        netg_path = None  # 'checkpoints/netg_211.pth'
    
        # 只测试不训练
        gen_img = 'result.png'
        # 从512张生成的图片中保存最好的64张
        gen_num = 64
        gen_search_num = 512
        gen_mean = 0  # 噪声的均值
        gen_std = 1  # 噪声的方差
    
    
    opt = Config()

    训练:

    def train(**kwargs):
        # 可以接收任意数量关键词参数的kwargs
        for k_, v_ in kwargs.items():
            setattr(opt, k_, v_)
            # setattr() 函数对应函数 getattr(),用于设置属性值,该属性不一定是存在的
            # setattr(object, name, value)
    
        device = t.device('cuda') if opt.gpu else t.device('cpu')
        if opt.vis:
            from visualize import Visualizer
            vis = Visualizer(opt.env)
    
        # 数据
        transforms = tv.transforms.Compose([
            tv.transforms.Resize(opt.image_size),
            # torchvision.transforms.Resize(size, interpolation=2) 重新定义尺寸
            tv.transforms.CenterCrop(opt.image_size),
            # torchvision.transforms.CenterCrop(size) 在中心裁剪给定的图像
            tv.transforms.ToTensor(),
            tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    
        dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
        dataloader = t.utils.data.DataLoader(dataset,
                                             batch_size=opt.batch_size,
                                             shuffle=True,
                                             num_workers=opt.num_workers,
                                             drop_last=True
                                             )
    
        # 网络
        netg, netd = NetG(opt), NetD(opt)    # 将定义的参数输入网络中
        map_location = lambda storage, loc: storage     #将Tensor默认加载入内存中,待有需要时再移至显存中
        if opt.netd_path:
            netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
        if opt.netg_path:
            netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
        netd.to(device)
        netg.to(device)
    
        # 定义优化器和损失
        optimizer_g = t.optim.Adam(netg.parameters(), opt.lr1, betas=(opt.beta1, 0.999))
        # torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
        optimizer_d = t.optim.Adam(netd.parameters(), opt.lr2, betas=(opt.beta1, 0.999))
        criterion = t.nn.BCELoss().to(device)   # 计算二分类误差
    
        # 真图片label为1,假图片label为0
        # noises为生成网络的输入
        true_labels = t.ones(opt.batch_size).to(device)
        fake_labels = t.zeros(opt.batch_size).to(device)
        fix_noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)
        noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)
    
        errord_meter = AverageValueMeter()  # 测量并返回添加到其中的任何数字集合的平均值和标准偏差。 例如,测量一组示例中的平均损失很有用。
        errorg_meter = AverageValueMeter()
    
        epochs = range(opt.max_epoch)
        for epoch in iter(epochs):
            for ii, (img, _) in tqdm.tqdm(enumerate(dataloader)):
                real_img = img.to(device)
    
                if ii % opt.d_every == 0:
                    # 训练判别器
                    optimizer_d.zero_grad()     # 梯度清零
                    # 尽可能的把真图片判别为正确
                    output = netd(real_img)
                    error_d_real = criterion(output, true_labels)
                    error_d_real.backward()     # 反向传播
    
                    # 尽可能把假图片判别为错误
                    noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                    fake_img = netg(noises).detach()  # 根据噪声生成假图
                    # detach()进行计算图截断,避免反向传播将梯度传到生成器中
                    output = netd(fake_img)
                    error_d_fake = criterion(output, fake_labels)
                    error_d_fake.backward()
                    optimizer_d.step()      # 更新参数
    
                    error_d = error_d_fake + error_d_real
    
                    errord_meter.add(error_d.item())
    
                if ii % opt.g_every == 0:
                    # 训练生成器
                    optimizer_g.zero_grad()
                    noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                    fake_img = netg(noises)
                    output = netd(fake_img)
                    error_g = criterion(output, true_labels)
                    error_g.backward()
                    optimizer_g.step()
                    errorg_meter.add(error_g.item())
    
                if opt.vis and ii % opt.plot_every == opt.plot_every - 1:
                    ## 可视化
                    import os
                    if os.path.exists(opt.debug_file):
                        ipdb.set_trace()
                    fix_fake_imgs = netg(fix_noises)
                    vis.images(fix_fake_imgs.detach().cpu().numpy()[:64] * 0.5 + 0.5, win='fixfake')
                    vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5, win='real')
                    vis.plot('errord', errord_meter.value()[0])
                    vis.plot('errorg', errorg_meter.value()[0])
    
            if (epoch+1) % opt.save_every == 0:
                # 保存模型、图片
                tv.utils.save_image(fix_fake_imgs.data[:64], '%s/%s.png' % (opt.save_path, epoch), normalize=True,
                                    range=(-1, 1))
                t.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
                t.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
                errord_meter.reset()
                errorg_meter.reset()

    测试(生成):

    @t.no_grad()    # 数据不需要计算梯度,也不会进行反向传播
    def generate(**kwargs):
        """
        随机生成动漫头像,并根据netd的分数选择较好的
        """
        for k_, v_ in kwargs.items():
            setattr(opt, k_, v_)
        
        device = t.device('cuda') if opt.gpu else t.device('cpu')
    
        netg, netd = NetG(opt).eval(), NetD(opt).eval()
        # Sets the module in evaluation mode 置为评价模式,与训练模式不同,有些参数不需要存储
        noises = t.randn(opt.gen_search_num, opt.nz, 1, 1).normal_(opt.gen_mean, opt.gen_std)
        noises = noises.to(device)
    
        map_location = lambda storage, loc: storage
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
        netd.to(device)
        netg.to(device)
    
        # 生成图片,并计算图片在判别器的分数
        fake_img = netg(noises)
        scores = netd(fake_img).detach()   # 此处可以不用detach()?
    
        # 挑选最好的某几张
        indexs = scores.topk(opt.gen_num)[1]
        result = []
        for ii in indexs:
            result.append(fake_img.data[ii])
        # 保存图片
        tv.utils.save_image(t.stack(result), opt.gen_img, normalize=True, range=(-1, 1))

    执行main函数:

    if __name__ == '__main__':
        import fire
        fire.Fire()
        train(gpu=True, vis=True, batch_size=256, max_epoch=200)
        # generate

    训练过程在visdom中显示,左边是生成的动漫头像,右边是训练样本,下面两个分别是判别和生成器的误差。

  • 相关阅读:
    k8s pod 定义 文件 yaml
    动态规划(1)
    Angular封装WangEditor富文本组件
    使用Angular CDK实现一个Service弹出Toast组件
    让Angular自定义组件支持form表单验证
    等了半年,m1芯片原生支持的.NET IDE出现了
    eachrt
    dgy
    开源节流
    company
  • 原文地址:https://www.cnblogs.com/huangliu1111/p/14108631.html
Copyright © 2011-2022 走看看