from __future__ import print_function import os import matplotlib as mpl import tarfile import matplotlib.image as mpimg from matplotlib import pyplot as plt import mxnet as mx from mxnet import gluon from mxnet import ndarray as nd from mxnet.gluon import nn, utils from mxnet import autograd import numpy as np epochs = 2 # Set low by default for tests, set higher when you actually run this code. batch_size = 64 latent_z_size = 100 use_gpu = False ctx = mx.gpu() if use_gpu else mx.cpu() lr = 0.0002 beta1 = 0.5 data_path = 'lfw_dataset' with tarfile.open("lfw-deepfunneled.tgz") as tar: tar.extractall(path=data_path)
target_wd = 64
target_ht = 64
img_list = []


def transform(data, target_wd, target_ht):
    """Turn a decoded image into a normalised, batched, channel-first array.

    The image is resized with ``mx.image.imresize(data, target_wd,
    target_ht)``, its last axis (channels) is moved to the front, and its
    values are mapped with ``x / 127.5 - 1`` (i.e. 0..255 becomes -1..1).
    A single-channel result is tiled to three channels.

    Returns the array with an extra leading axis of length 1 so that the
    per-image results can be concatenated into one batch.
    """
    resized = mx.image.imresize(data, target_wd, target_ht)
    # Channel-last -> channel-first.
    channel_first = nd.transpose(resized, (2, 0, 1))
    # Rescale to [-1, 1].
    normalised = channel_first.astype(np.float32) / 127.5 - 1
    # Greyscale input: repeat the single channel three times to get RGB.
    if normalised.shape[0] == 1:
        normalised = nd.tile(normalised, (3, 1, 1))
    batched_shape = (1,) + normalised.shape
    return normalised.reshape(batched_shape)


# Collect every .jpg found anywhere under the extracted dataset directory.
for path, _, fnames in os.walk(data_path):
    for fname in fnames:
        if fname.endswith('.jpg'):
            img = os.path.join(path, fname)
            img_arr = transform(mx.image.imread(img), target_wd, target_ht)
            img_list.append(img_arr)

# Stack all images into one array and serve it in mini-batches.
train_data = mx.io.NDArrayIter(data=nd.concatenate(img_list),
                               batch_size=batch_size)
def visualize(img_arr):
    """Draw one channel-first image with values in [-1, 1] on the current axes.

    The array is converted to NumPy, transposed to channel-last, mapped back
    with ``(x + 1) * 127.5`` and cast to ``uint8`` before being shown with the
    axes switched off.
    """
    channel_first = img_arr.asnumpy()
    channel_last = channel_first.transpose(1, 2, 0)
    pixels = ((channel_last + 1.0) * 127.5).astype(np.uint8)
    plt.imshow(pixels)
    plt.axis('off')


# Preview four training samples (indices 10..13) side by side.
for position in range(1, 5):
    plt.subplot(1, 4, position)
    visualize(img_list[position + 9][0])
plt.show()
# ---------------------------------------------------------------------------
# Generator: latent vector Z (1 x 1 spatially) -> (nc) x 64 x 64 image.
# ---------------------------------------------------------------------------
nc = 3
ngf = 64
netG = nn.Sequential()
with netG.name_scope():
    # Input is Z, going into a transposed convolution (kernel 4, stride 1,
    # no padding). State size afterwards: (ngf*8) x 4 x 4.
    netG.add(nn.Conv2DTranspose(ngf * 8, 4, 1, 0, use_bias=False))
    netG.add(nn.BatchNorm())
    netG.add(nn.Activation('relu'))
    # Three stride-2 stages, each doubling the spatial size:
    #   (ngf*4) x 8 x 8  ->  (ngf*2) x 16 x 16  ->  (ngf) x 32 x 32
    for g_channels in (ngf * 4, ngf * 2, ngf):
        netG.add(nn.Conv2DTranspose(g_channels, 4, 2, 1, use_bias=False))
        netG.add(nn.BatchNorm())
        netG.add(nn.Activation('relu'))
    # Output stage. State size: (nc) x 64 x 64, squashed by tanh.
    netG.add(nn.Conv2DTranspose(nc, 4, 2, 1, use_bias=False))
    netG.add(nn.Activation('tanh'))

# ---------------------------------------------------------------------------
# Discriminator: (nc) x 64 x 64 image -> a single un-squashed score.
# ---------------------------------------------------------------------------
ndf = 64
netD = nn.Sequential()
with netD.name_scope():
    # Input is (nc) x 64 x 64. State size afterwards: (ndf) x 32 x 32.
    # The first stage has no batch normalisation.
    netD.add(nn.Conv2D(ndf, 4, 2, 1, use_bias=False))
    netD.add(nn.LeakyReLU(0.2))
    # Three stride-2 stages, each halving the spatial size:
    #   (ndf*2) x 16 x 16  ->  (ndf*4) x 8 x 8  ->  (ndf*8) x 4 x 4
    for d_channels in (ndf * 2, ndf * 4, ndf * 8):
        netD.add(nn.Conv2D(d_channels, 4, 2, 1, use_bias=False))
        netD.add(nn.BatchNorm())
        netD.add(nn.LeakyReLU(0.2))
    # Final 4x4 convolution collapses the map to 1 x 1 x 1. No sigmoid is
    # applied here.
    netD.add(nn.Conv2D(1, 4, 1, 0, use_bias=False))
# Binary cross-entropy loss with the sigmoid built in.
loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()

# Initialise the generator first, then the discriminator, each with its own
# Normal(0.02) initializer on the selected context.
for net in (netG, netD):
    net.initialize(mx.init.Normal(0.02), ctx=ctx)


def _adam_trainer(net):
    """Return an Adam trainer over all parameters of `net`.

    Uses the module-level `lr` and `beta1` settings.
    """
    optimizer_params = {'learning_rate': lr, 'beta1': beta1}
    return gluon.Trainer(net.collect_params(), 'adam', optimizer_params)


# One trainer per network so each can be stepped independently.
trainerG = _adam_trainer(netG)
trainerD = _adam_trainer(netD)
from datetime import datetime
import logging
import time

# Target labels for a full batch: 1 for real images, 0 for generated ones.
real_label = nd.ones((batch_size,), ctx=ctx)
fake_label = nd.zeros((batch_size,), ctx=ctx)


def facc(label, pred):
    """Return the fraction of discriminator outputs that match `label`.

    `pred` holds the raw discriminator outputs. The discriminator ends in a
    plain convolution and the sigmoid is applied inside the loss, so these
    are logits: an output above 0 corresponds to a probability above 0.5 and
    is counted as "real".
    """
    pred = pred.ravel()
    label = label.ravel()
    # Threshold at 0 (logit space), not 0.5 (probability space).
    return ((pred > 0) == label).mean()


metric = mx.metric.CustomMetric(facc)

stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
logging.basicConfig(level=logging.DEBUG)

for epoch in range(epochs):
    tic = time.time()
    btic = time.time()
    train_data.reset()
    # `enumerate` replaces a hand-maintained counter that shadowed the
    # builtin `iter`.
    for iteration, batch in enumerate(train_data):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        data = batch.data[0].as_in_context(ctx)
        latent_z = mx.nd.random_normal(
            0, 1, shape=(batch_size, latent_z_size, 1, 1), ctx=ctx)

        with autograd.record():
            # Train with real images.
            output = netD(data).reshape((-1, 1))
            errD_real = loss(output, real_label)
            metric.update([real_label, ], [output, ])

            # Train with fake images. `detach()` keeps this backward pass
            # from reaching the generator.
            fake = netG(latent_z)
            output = netD(fake.detach()).reshape((-1, 1))
            errD_fake = loss(output, fake_label)
            errD = errD_real + errD_fake
            errD.backward()
            metric.update([fake_label, ], [output, ])

        trainerD.step(batch.data[0].shape[0])

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        with autograd.record():
            fake = netG(latent_z)
            output = netD(fake).reshape((-1, 1))
            # The generator is rewarded when D labels its output as real.
            errG = loss(output, real_label)
            errG.backward()

        trainerG.step(batch.data[0].shape[0])

        # Print log information every ten batches.
        if iteration % 10 == 0:
            name, acc = metric.get()
            logging.info('speed: %s samples/s',
                         batch_size / (time.time() - btic))
            logging.info(
                'discriminator loss = %f, generator loss = %f, '
                'binary training acc = %f at iter %d epoch %d',
                nd.mean(errD).asscalar(), nd.mean(errG).asscalar(),
                acc, iteration, epoch)
        btic = time.time()

    name, acc = metric.get()
    metric.reset()
    # Optional per-epoch reporting (disabled):
    # logging.info(' binary training acc at epoch %d: %s=%f' % (epoch, name, acc))
    # logging.info('time: %f' % (time.time() - tic))

    # Visualize one generated image for each epoch (disabled):
    # fake_img = fake[0]
    # visualize(fake_img)
    # plt.show()
# Walk through latent space: start from one random latent vector, then add a
# constant offset to every component before each subsequent image.
num_image = 12
latent_z = mx.nd.random_normal(0, 1, shape=(1, latent_z_size, 1, 1), ctx=ctx)
step = 0.05
for i in range(num_image):
    img = netG(latent_z)
    # The 3 x 4 grid holds exactly the 12 images requested above.
    plt.subplot(3, 4, i + 1)
    visualize(img[0])
    # Use the named step instead of a duplicated literal.
    latent_z += step
plt.show()