Deep Learning: Generating MNIST Images with a GAN

    import numpy as np
    import os
    import codecs
    import torch
    from PIL import Image
    import PIL.ImageOps
    
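    # MNIST IDX format: a 4-byte magic number (2051 for image files), three 4-byte
    # big-endian counts (number of images, rows, cols), then raw pixel bytes from offset 16.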
    def get_int(b):
        return int(codecs.encode(b, 'hex'), 16)
    
    def extract_image(path, extract_path):
        with open(path, 'rb') as f:
            data = f.read()
            assert get_int(data[:4]) == 2051
            length = get_int(data[4:8])
            num_rows = get_int(data[8:12])
            num_cols = get_int(data[12:16])
            images = []
            parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
            parsed = parsed.reshape(length, num_rows, num_cols)
            
        for image_i, image in enumerate(parsed):
            Image.fromarray(image, 'L').save(os.path.join(extract_path, 'image_{}.jpg'.format(image_i)))
            
    
    image_path = './mnist/t10k-images.idx3-ubyte'
    extract_path = './mnist/data/image'
    
    # Extract the raw MNIST images into individual jpg files
    os.makedirs(extract_path, exist_ok=True)
    extract_image(image_path, extract_path)
    
    import math
    
    def images_square_grid(images, mode):
        save_size = math.floor(np.sqrt(images.shape[0]))
    
        # Scale to 0-255
        images = (((images - images.min()) * 255) / (images.max() - images.min())).astype(np.uint8)
    
        # Put images in a square arrangement
        images_in_square = np.reshape(
                images[:save_size*save_size],
                (save_size, save_size, images.shape[1], images.shape[2], images.shape[3]))
        if mode == 'L':
            images_in_square = np.squeeze(images_in_square, 4)
    
        # Combine images to grid image
        new_im = Image.new(mode, (images.shape[1] * save_size, images.shape[2] * save_size))
        for col_i, col_images in enumerate(images_in_square):
            for image_i, image in enumerate(col_images):
                im = Image.fromarray(image, mode)
                new_im.paste(im, (col_i * images.shape[1], image_i * images.shape[2]))
    
        return new_im
    
    def get_image(image_path, width, height, mode):
        
        image = Image.open(image_path)
        
        if image.size != (width, height):
            face_width = face_height = 108
            j = (image.size[0] - face_width) // 2
            i = (image.size[1] - face_height) // 2
    
            image = image.crop([j, i, j + face_width, i + face_height])
            image = image.resize([width, height], Image.BILINEAR)
        
        return np.array(image.convert(mode))
    
    def get_batch(image_files, width, height, mode):
        data_batch = np.array([get_image(sample_file, width, height, mode) for sample_file in image_files]).astype(np.float32)
        
        if len(data_batch.shape) < 4:
            data_batch = data_batch.reshape(data_batch.shape + (1,))
        
        return data_batch
      
    %matplotlib inline
    import os
    from glob import glob
    from matplotlib import pyplot
    
    data_dir = './mnist/data'
    show_n_images = 25
    
    mnist_images = get_batch(glob(os.path.join(data_dir, 'image/*.jpg'))[:show_n_images], 28, 28, 'L')
    
    pyplot.imshow(images_square_grid(mnist_images, 'L'), cmap='gray')
    
    
    from torch.utils import data
    import torchvision as tv
    
    
    batch_size = 50
    
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(96),
        PIL.ImageOps.grayscale,
        tv.transforms.ToTensor()
    ])
    
    root="d:\work\yoho\dl\dl-study\chapter8\mnist\data"
    
    dataset = tv.datasets.ImageFolder(root, transform=transforms)
    dataloader = data.DataLoader(dataset, batch_size, shuffle=True, num_workers=1, drop_last=True)
    
    
    import torch.nn as nn
    import torch.optim as optim
    from torch.nn.modules import loss
    from torch.autograd import Variable as V
    
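    # Generator (DCGAN-style): transposed convolutions upsample an nz x 1 x 1 noise
    # vector into a target-channel 96 x 96 image (1 channel for MNIST here).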
    class GNet(nn.Module):
        def __init__(self, opt):
            super(GNet, self).__init__()
            
            ngf = opt["ngf"]
            target = opt["target"] or 3
            
            self.main = nn.Sequential(
                nn.ConvTranspose2d( opt["nz"], ngf * 8, 4, 1, 0, bias=False),
                nn.BatchNorm2d(ngf * 8),
                nn.ReLU(True),
                
                nn.ConvTranspose2d( ngf * 8, ngf * 4, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ngf * 4),
                nn.ReLU(True),
                
                nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ngf * 2),
                nn.ReLU(True),
                
                nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ngf),
                nn.ReLU(True),
                
                nn.ConvTranspose2d( ngf, target, 5, 3, 1, bias=False),
                nn.Tanh()
            )
        
        def forward(self, input):
            return self.main(input)
        
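    # Discriminator (DCGAN-style): strided convolutions reduce a 96 x 96 input image
    # to a single probability that the image is real.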
    class DNet(nn.Module):
        def __init__(self, opt):
            super(DNet, self).__init__()
            
            ndf = opt["ndf"]
            input = opt["input"] or 3
            
            self.main = nn.Sequential(
                nn.Conv2d(input, ndf, 5, 3, 1, bias=False),
                nn.LeakyReLU(0.2, inplace=True),
                
                nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 2),
                nn.LeakyReLU(0.2, inplace=True),
                
                nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 4),
                nn.LeakyReLU(0.2, inplace=True),
                
                nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 8),
                nn.LeakyReLU(0.2, inplace=True),
                
                nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
                nn.Sigmoid()
            )
            
        def forward(self, input):
            return self.main(input).view(-1)
            
    
    lr_g = 0.01
    lr_d = 0.01
    ngf = 64
    ndf = 64
    raw_f = 1
    nz = 100
    d_every = 1
    g_every = 5
    
    net_g = GNet({"target": raw_f, "ngf": ngf, 'nz': nz})
    net_d = DNet({"input": raw_f, "ndf": ndf})
    
    opt_g = optim.Adam(net_g.parameters(), lr_g, betas=(0.5, 0.999))
    opt_d = optim.Adam(net_d.parameters(), lr_d, betas=(0.5, 0.999))
    
    criterion = torch.nn.BCELoss()
    
    true_labels = V(torch.ones(batch_size))
    fake_labels = V(torch.zeros(batch_size))
    fix_noises = V(torch.randn(batch_size, nz, 1, 1))
    noises = V(torch.randn(batch_size, nz, 1, 1))
    
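    # One pass over the data: every d_every batches the discriminator is updated on
    # real and fake images; every g_every batches the generator is updated to fool it.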
    def train():
        for ii, (img, _) in enumerate(dataloader):
            real_img = V(img)
            
            if (ii + 1) % d_every == 0:
                opt_d.zero_grad()
                output = net_d(real_img)    
                loss_d = criterion(output, true_labels)    
                loss_d.backward()
    
                noises.data.copy_(torch.randn(batch_size, nz, 1, 1))
                
                fake_img = net_g(noises)
                
                fake_img = fake_img.detach()
                fake_output = net_d(fake_img) 
                loss_fake_d = criterion(fake_output, fake_labels)
                loss_fake_d.backward()
    
                opt_d.step()
    
    
            if (ii + 1) % g_every == 0:
                opt_g.zero_grad()
                noises.data.copy_(torch.randn(batch_size, nz, 1, 1))
                fake_image = net_g(noises)
    
            fake_output = net_d(fake_image)  # score the freshly generated (non-detached) images so gradients flow to the generator
    
                loss_g = criterion(fake_output, true_labels)
    
                loss_g.backward()
                opt_g.step()
    
    
    def print_image():
        fix_fake_imgs = net_g(fix_noises)
        fix_fake_imgs = fix_fake_imgs.data.view(batch_size, 96, 96, 1).numpy()
        pyplot.imshow(images_square_grid(fix_fake_imgs, 'L'), cmap='gray')
        pyplot.show()
    
    
    epochs = 20
    def main():
        for i in range(epochs):
            print("epoch {}".format(i))
            train()
            
            if i % 2 == 0:
                print_image()
    main()
    

    Note: GAN training is slow; run it on a GPU.
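
    A minimal sketch of moving the networks and tensors above onto the GPU (assuming a CUDA device is available; it follows the old-style Variable API already used in this post):

    # Sketch: move models, loss, and tensors to the GPU when available
    if torch.cuda.is_available():
        net_g.cuda()
        net_d.cuda()
        criterion.cuda()
        true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()
        fix_noises, noises = fix_noises.cuda(), noises.cuda()

    # Inside train(), each batch must also be moved: real_img = V(img.cuda()),
    # and print_image() needs fix_fake_imgs.data.cpu() before calling .numpy().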

Original post: https://www.cnblogs.com/htoooth/p/8708086.html