zoukankan      html  css  js  c++  java
  • gluon实现DCGAN(深度卷积对抗生成网络)

    from __future__ import print_function
    import os
    import matplotlib as mpl
    import tarfile
    import matplotlib.image as mpimg
    from matplotlib import pyplot as plt
    
    import mxnet as mx
    from mxnet import gluon
    from mxnet import ndarray as nd
    from mxnet.gluon import nn, utils
    from mxnet import autograd
    import numpy as np
    
    
    # ---- Training hyper-parameters -------------------------------------
    epochs = 2 # Set low by default for tests, set higher when you actually run this code.
    batch_size = 64
    # Number of channels in the random noise vector fed to the generator.
    latent_z_size = 100

    # Device selection: flip use_gpu to run on the first GPU instead of the CPU.
    use_gpu = False
    ctx = mx.gpu() if use_gpu else mx.cpu()

    # Adam settings shared by the generator and discriminator trainers.
    lr = 0.0002
    beta1 = 0.5

    # ---- Dataset ---------------------------------------------------------
    data_path = 'lfw_dataset'
    # Unpack the archive only when the target directory does not exist yet, so
    # re-running the script does not re-extract the whole dataset every time.
    # Delete the directory to force a fresh extraction (e.g. after an
    # interrupted run left it incomplete).
    if not os.path.isdir(data_path):
        with tarfile.open("lfw-deepfunneled.tgz") as tar:
            tar.extractall(path=data_path)

    # Every image is resized to target_wd x target_ht before training.
    target_wd = 64
    target_ht = 64
    # Filled by the loading loop with one (1, 3, H, W) array per image.
    img_list = []
    
    def transform(data, target_wd, target_ht):
        """Convert a decoded image into a single-sample, channel-first batch.

        The image is resized to ``target_wd`` x ``target_ht``, its axes are
        permuted so the last (channel) axis comes first, pixel values are
        rescaled from [0, 255] to [-1, 1], and a leading batch axis of size 1
        is added. A single-channel image is repeated three times along the
        channel axis so the result always has three channels.

        Assumes ``data`` is an NDArray laid out with channels last, as produced
        by ``mx.image.imread`` -- TODO confirm for other callers.
        """
        resized = mx.image.imresize(data, target_wd, target_ht)
        # Channel-last -> channel-first.
        channel_first = nd.transpose(resized, (2, 0, 1))
        # Map [0, 255] onto [-1, 1].
        scaled = channel_first.astype(np.float32) / 127.5 - 1
        # Greyscale input: replicate the lone channel into three.
        if scaled.shape[0] == 1:
            scaled = nd.tile(scaled, (3, 1, 1))
        batched_shape = (1,) + scaled.shape
        return scaled.reshape(batched_shape)
    
    # Walk the dataset directory tree, decode every '.jpg' file and append its
    # preprocessed (1, 3, H, W) array to img_list.
    for dirpath, _, filenames in os.walk(data_path):
        for filename in filenames:
            if filename.endswith('.jpg'):
                decoded = mx.image.imread(os.path.join(dirpath, filename))
                img_list.append(transform(decoded, target_wd, target_ht))
    # Join all per-image arrays along the batch axis and serve them in
    # mini-batches of batch_size.
    train_data = mx.io.NDArrayIter(data=nd.concatenate(img_list), batch_size=batch_size)
    def visualize(img_arr):
        """Draw one channel-first image with values in [-1, 1] on the current axes.

        The array is converted to numpy, moved to channel-last layout, mapped
        back to the 0..255 range as uint8, and shown with the axis hidden.
        """
        channel_last = img_arr.asnumpy().transpose(1, 2, 0)
        pixels = ((channel_last + 1.0) * 127.5).astype(np.uint8)
        plt.imshow(pixels)
        plt.axis('off')
    
    # Preview four training images (indices 10..13) side by side in one row.
    for slot in range(1, 5):
        plt.subplot(1, 4, slot)
        # Each list entry has a leading batch axis of size 1; [0] drops it.
        visualize(img_list[slot + 9][0])
    plt.show()

    # Build the generator: a stack of transposed convolutions that upsamples a
    # (latent_z_size) x 1 x 1 noise vector to an (nc) x 64 x 64 image.
    nc = 3    # channels of the generated image
    ngf = 64  # base number of generator feature maps
    netG = nn.Sequential()
    with netG.name_scope():
        # (out_channels, stride, padding) of each upsampling stage; the
        # kernel size is always 4. Resulting state sizes:
        #   (ngf*8) x 4 x 4 -> (ngf*4) x 8 x 8 -> (ngf*2) x 16 x 16 -> (ngf) x 32 x 32
        for out_channels, stride, padding in ((ngf * 8, 1, 0),
                                              (ngf * 4, 2, 1),
                                              (ngf * 2, 2, 1),
                                              (ngf, 2, 1)):
            netG.add(nn.Conv2DTranspose(out_channels, 4, stride, padding, use_bias=False))
            netG.add(nn.BatchNorm())
            netG.add(nn.Activation('relu'))
        # Output stage: (nc) x 64 x 64, squashed to [-1, 1] by tanh to match
        # the normalisation applied to the training images.
        netG.add(nn.Conv2DTranspose(nc, 4, 2, 1, use_bias=False))
        netG.add(nn.Activation('tanh'))
    
    # Build the discriminator: strided convolutions that reduce an
    # (nc) x 64 x 64 image to a single raw score per sample (no sigmoid here).
    ndf = 64  # base number of discriminator feature maps
    netD = nn.Sequential()
    with netD.name_scope():
        # First stage has no batch norm: (nc) x 64 x 64 -> (ndf) x 32 x 32
        netD.add(nn.Conv2D(ndf, 4, 2, 1, use_bias=False))
        netD.add(nn.LeakyReLU(0.2))
        # Three halving stages with batch norm:
        #   (ndf*2) x 16 x 16 -> (ndf*4) x 8 x 8 -> (ndf*8) x 4 x 4
        for multiplier in (2, 4, 8):
            netD.add(nn.Conv2D(ndf * multiplier, 4, 2, 1, use_bias=False))
            netD.add(nn.BatchNorm())
            netD.add(nn.LeakyReLU(0.2))
        # Final 4x4 convolution collapses the map to 1 x 1 x 1.
        netD.add(nn.Conv2D(1, 4, 1, 0, use_bias=False))
    # Binary cross-entropy that takes the discriminator's raw scores
    # (the sigmoid is applied inside the loss).
    loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()

    # Initialise both networks on ctx with a Normal initializer (sigma 0.02).
    netG.initialize(init=mx.init.Normal(0.02), ctx=ctx)
    netD.initialize(init=mx.init.Normal(0.02), ctx=ctx)

    # One Adam trainer per network, both using the same lr and beta1.
    trainerG = gluon.Trainer(
        netG.collect_params(),
        optimizer='adam',
        optimizer_params={'learning_rate': lr, 'beta1': beta1})
    trainerD = gluon.Trainer(
        netD.collect_params(),
        optimizer='adam',
        optimizer_params={'learning_rate': lr, 'beta1': beta1})
    from datetime import datetime
    import time
    import logging
    
    # Constant target vectors, one entry per sample in a batch:
    # 1 marks a real image, 0 marks a generated one.
    real_label = nd.ones(shape=(batch_size,), ctx=ctx)
    fake_label = nd.zeros(shape=(batch_size,), ctx=ctx)
    
    def facc(label, pred):
        """Return the fraction of discriminator scores on the correct side.

        ``pred`` holds the discriminator's raw scores (logits): the network
        ends in a plain convolution and the sigmoid is applied inside the
        loss. A probability above 0.5 therefore corresponds to a logit above
        0, so the decision threshold here is 0, not 0.5.

        Both arguments are flattened first and must support ``.ravel()``
        (presumably numpy arrays handed over by ``CustomMetric`` -- confirm).
        ``label`` is expected to contain 0/1 values.
        """
        pred = pred.ravel()
        label = label.ravel()
        # sigmoid(x) > 0.5  <=>  x > 0
        return ((pred > 0) == label).mean()
    # Running accuracy of the discriminator, evaluated with facc.
    metric = mx.metric.CustomMetric(feval=facc)

    # Minute-resolution timestamp of when this run started.
    stamp = '{:%Y_%m_%d-%H_%M}'.format(datetime.now())
    # Show everything from DEBUG upwards, including the training progress logs.
    logging.basicConfig(level=logging.DEBUG)
    
    # Alternating GAN training: for every batch first update the discriminator
    # on real and generated images, then update the generator so that the
    # discriminator scores its images as real.
    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        train_data.reset()
        for batch_idx, batch in enumerate(train_data):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            data = batch.data[0].as_in_context(ctx)
            # Number of samples used to normalise the gradient in step().
            num_samples = batch.data[0].shape[0]
            latent_z = mx.nd.random_normal(0, 1, shape=(batch_size, latent_z_size, 1, 1), ctx=ctx)

            with autograd.record():
                # train with real image
                output = netD(data).reshape((-1, 1))
                errD_real = loss(output, real_label)
                metric.update([real_label,], [output,])

                # train with fake image; detach() keeps the discriminator's
                # loss from propagating gradients into the generator
                fake = netG(latent_z)
                output = netD(fake.detach()).reshape((-1, 1))
                errD_fake = loss(output, fake_label)
                errD = errD_real + errD_fake
                errD.backward()
                metric.update([fake_label,], [output,])

            trainerD.step(num_samples)

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            with autograd.record():
                fake = netG(latent_z)
                output = netD(fake).reshape((-1, 1))
                # Generated images are scored against the "real" label, so the
                # loss is low when the discriminator is fooled.
                errG = loss(output, real_label)
                errG.backward()

            trainerG.step(num_samples)

            # Print log information every ten batches
            if batch_idx % 10 == 0:
                name, acc = metric.get()
                logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                logging.info('discriminator loss = %f, generator loss = %f, binary training acc = %f at iter %d epoch %d'
                         %(nd.mean(errD).asscalar(),
                           nd.mean(errG).asscalar(), acc, batch_idx, epoch))
            btic = time.time()

        # Accuracy is accumulated per epoch; read it, then start afresh.
        name, acc = metric.get()
        metric.reset()
        # Optional per-epoch summary (disabled):
        # logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        # logging.info('time: %f' % (time.time() - tic))

        # Visualize one generated image for each epoch
        # fake_img = fake[0]
        # visualize(fake_img)
        # plt.show()

    # Walk through latent space: start from one random noise vector and shift
    # every component by `step` between images, drawing each generated image
    # into a 3 x 4 grid.
    num_image = 12
    latent_z = mx.nd.random_normal(0, 1, shape=(1, latent_z_size, 1, 1), ctx=ctx)
    step = 0.05
    for i in range(num_image):
        img = netG(latent_z)
        plt.subplot(3,4,i+1)
        visualize(img[0])
        # Use the declared step size instead of a duplicated literal.
        latent_z += step
    plt.show()

  • 相关阅读:
    safari调试iphone
    git 本地仓库关联远程仓库
    video 自动播放及循环播放问题
    webpack4系列之【3. webpack4优化记录】
    展示博客
    第三天冲刺
    第二天冲刺
    第一天冲刺
    UML设计
    Alpha项目冲刺
  • 原文地址:https://www.cnblogs.com/hxjbc/p/8311241.html
Copyright © 2011-2022 走看看