zoukankan      html  css  js  c++  java
  • 使用Keras编写GAN的入门

    使用Keras编写GAN的入门

    Time: 2017-5-31


    前言

    本文主要参考了网页[1]的教程,算法则来自 Ian J. Goodfellow 等人的论文《Generative Adversarial Nets》(2014),如下:

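    (原图缺失:论文中的 GAN 训练算法。)GAN 的目标是如下的极小极大博弈:

    $$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$

    训练时交替进行:先固定生成器 $G$,用真实样本和生成样本更新判别器 $D$;再固定 $D$,更新 $G$,使 $D(G(z))$ 尽量被判为"真"。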

    代码

    %matplotlib inline
    import numpy as np
    import pandas as pd
    
    from keras.models import Model
    from keras.layers import Dense, Activation, Input, Reshape
    from keras.layers import Conv1D, Flatten, Dropout
    from keras.optimizers import SGD, Adam
    
    
    from tqdm import tqdm_notebook as tqdm  # 进度条
    
    
    # Generate random sine-curve samples (the "real" data distribution).
    def sample_data(n_samples=10000, x_vals=np.arange(0, 5, .1), max_offset=1000, mul_range=(1, 2)):
        """Return ``n_samples`` random sine curves rescaled into [0, 1].

        Each row is ``sin(offset + x_vals * mul) / 2 + 0.5``. Here ``offset`` is
        drawn uniformly from [0, max_offset), and ``mul`` is drawn uniformly
        from [mul_range[0], mul_range[1]).

        Args:
            n_samples: number of curves (rows) to generate.
            x_vals: 1-D sample positions shared by every curve. Its length is
                the row length (50 with the default).
            max_offset: upper bound of the random phase offset.
            mul_range: ``(low, high)`` bounds of the random frequency
                multiplier. Any 2-element sequence is accepted.

        Returns:
            float array of shape ``(n_samples, len(x_vals))``. The 2-D shape
            is kept even when ``n_samples == 0``, so the result can always be
            concatenated with generator output.
        """
        x_vals = np.asarray(x_vals)
        # One (offset, mul) pair per row. The draws come out in the same order
        # as two successive scalar np.random.random() calls per row, so seeded
        # results match the former per-row loop.
        draws = np.random.random((n_samples, 2))
        offsets = draws[:, 0] * max_offset
        muls = mul_range[0] + draws[:, 1] * (mul_range[1] - mul_range[0])
        # Broadcast: (n_samples, 1) against (1, len(x_vals)).
        return np.sin(offsets[:, None] + x_vals[None, :] * muls[:, None]) / 2 + .5
        
    # Build the generator model.
    def get_generative(G_in, dense_dim=200, out_dim=50, lr=1e-3):
        """Build and compile the generator network.

        Architecture: ``G_in -> Dense(dense_dim) -> tanh -> Dense(out_dim, tanh)``.

        Args:
            G_in: Keras ``Input`` tensor holding the noise vector.
            dense_dim: width of the hidden layer.
            out_dim: length of each generated sample.
            lr: learning rate for the SGD optimizer.

        Returns:
            ``(G, G_out)``: the compiled generator ``Model`` and its output tensor.
        """
        hidden = Dense(dense_dim)(G_in)
        hidden = Activation('tanh')(hidden)
        # NOTE(review): tanh outputs lie in [-1, 1], while sample_data() produces
        # values in [0, 1]. A sigmoid output may match the data range better.
        G_out = Dense(out_dim, activation='tanh')(hidden)
        generator = Model(G_in, G_out)
        # NOTE(review): `lr=` is the older Keras argument name. Newer Keras
        # versions expect `learning_rate=`. Confirm against the installed version.
        sgd = SGD(lr=lr)
        # This file never fits G directly; the compile gives G the optimizer
        # that make_gan() reuses for the stacked model.
        generator.compile(loss='binary_crossentropy', optimizer=sgd)
        return generator, G_out
        
    # Build the discriminator model.
    def get_discriminative(D_in, lr=1e-3, drate=.25, n_channels=50, conv_sz=5, leak=.2):
        """Build and compile the discriminator network.

        Architecture: reshape the input to ``(steps, 1)`` -> Conv1D(relu) ->
        Dropout -> Flatten -> Dense(n_channels, linear) -> Dense(2, sigmoid).

        Args:
            D_in: Keras ``Input`` tensor holding one sample.
            lr: learning rate for the Adam optimizer.
            drate: dropout rate applied after the convolution.
            n_channels: number of conv filters, also the width of the hidden
                Dense layer.
            conv_sz: convolution kernel size.
            leak: currently unused; no LeakyReLU layer is built.

        Returns:
            ``(D, D_out)``: the compiled discriminator ``Model`` and its
            two-unit output tensor. The labels built in sample_data_and_gen()
            treat column 1 as "real" and column 0 as "generated".
        """
        seq = Reshape((-1, 1))(D_in)  # treat the sample as a 1-channel sequence
        feats = Conv1D(n_channels, conv_sz, activation='relu')(seq)
        feats = Dropout(drate)(feats)
        flat = Flatten()(feats)
        hidden = Dense(n_channels)(flat)  # no activation: linear projection
        D_out = Dense(2, activation='sigmoid')(hidden)
        discriminator = Model(D_in, D_out)
        adam = Adam(lr=lr)
        discriminator.compile(loss='binary_crossentropy', optimizer=adam)
        return discriminator, D_out
    
        
        
    def set_trainability(model, trainable=False):
        """Set the ``trainable`` flag on *model* and on each of its direct layers.

        Args:
            model: object exposing a ``trainable`` attribute and a ``layers``
                iterable (a Keras ``Model`` in this file).
            trainable: value to assign. Defaults to ``False`` (freeze).
        """
        setattr(model, 'trainable', trainable)
        for child in model.layers:
            setattr(child, 'trainable', trainable)
            
    def make_gan(GAN_in, G, D):
        """Stack generator *G* and a frozen discriminator *D* into one model.

        D is frozen before the combined model is compiled. Training the
        returned model is therefore meant to update only G's weights; the
        printed summary lists D's parameters as non-trainable.

        Args:
            GAN_in: Keras ``Input`` tensor for the noise vector.
            G: generator model. Its ``optimizer`` is reused for the stack.
            D: discriminator model.

        Returns:
            ``(GAN, GAN_out)``: the compiled combined ``Model`` and its output tensor.
        """
        set_trainability(D, False)
        GAN_out = D(G(GAN_in))
        stacked = Model(GAN_in, GAN_out)
        stacked.compile(loss='binary_crossentropy', optimizer=G.optimizer)
        return stacked, GAN_out
    
    # Pretrain the discriminator on a mix of real and generated samples.
    def sample_data_and_gen(G, noise_dim=10, n_samples=10000):
        """Build a labelled batch of real sine curves followed by generator output.

        Args:
            G: generator model. Its ``predict`` is called on a
                ``(n_samples, noise_dim)`` array of uniform [0, 1) noise.
            noise_dim: width of each noise vector fed to G.
            n_samples: number of real samples and also of generated samples
                (``2 * n_samples`` rows in total).

        Returns:
            ``(X, y)``. ``X`` stacks the real rows first, then the generated
            rows. ``y`` is one-hot: ``[0, 1]`` for real rows and ``[1, 0]``
            for generated rows.
        """
        real = sample_data(n_samples=n_samples)
        noise = np.random.uniform(0, 1, size=[n_samples, noise_dim])
        fake = G.predict(noise)
        X = np.concatenate((real, fake))
        real_labels = np.tile([0., 1.], (n_samples, 1))
        fake_labels = np.tile([1., 0.], (n_samples, 1))
        y = np.concatenate((real_labels, fake_labels))
        return X, y
         
    def pretrain(G, D, noise_dim=10, n_samples=10000, batch_size=32):
        """Fit the discriminator for one epoch on real vs. generated samples.

        Args:
            G: generator used to produce the fake half of the data.
            D: discriminator to train. It is unfrozen before fitting.
            noise_dim: noise width passed to G.
            n_samples: number of real samples and of fake samples
                (``2 * n_samples`` rows in total).
            batch_size: mini-batch size for ``D.fit``.
        """
        inputs, targets = sample_data_and_gen(G, noise_dim=noise_dim, n_samples=n_samples)
        set_trainability(D, True)
        D.fit(inputs, targets, batch_size=batch_size, epochs=1)
        
        
    # Alternating (adversarial) training.
    def sample_noise(G, noise_dim=10, n_samples=10000):
        """Draw generator inputs labelled as "real" for the adversarial step.

        Args:
            G: not used in the body (presumably kept for symmetry with
                sample_data_and_gen()).
            noise_dim: width of each uniform [0, 1) noise vector.
            n_samples: number of noise vectors.

        Returns:
            ``(X, y)``. ``X`` has shape ``(n_samples, noise_dim)``. Every row
            of ``y`` is ``[0, 1]``, the "real" label: the target the generator
            tries to make D predict.
        """
        noise = np.random.uniform(0, 1, size=[n_samples, noise_dim])
        targets = np.tile([0., 1.], (n_samples, 1))
        return noise, targets
        
    def train(GAN, G, D, epochs=500, n_samples=10000, noise_dim=10, batch_size=32, verbose=False, v_freq=50):
        """Alternate one discriminator step and one generator step per epoch.

        Each epoch:
          1. train D on one batch of ``n_samples`` real + ``n_samples`` generated rows;
          2. train G through the stacked GAN (D frozen) on ``n_samples`` noise
             vectors labelled "real".

        Args:
            GAN: combined model from make_gan().
            G: generator model.
            D: discriminator model.
            epochs: number of alternating steps.
            n_samples: samples per step (D sees twice this many rows).
            noise_dim: width of the noise vectors.
            batch_size: currently unused. Each step is a single
                ``train_on_batch`` call over all samples.
            verbose: wrap the loop in a tqdm progress bar and print losses.
            v_freq: print losses every ``v_freq`` epochs when ``verbose`` is set.

        Returns:
            ``(d_loss, g_loss)``: lists of the per-epoch values returned by
            ``train_on_batch`` for D and for the GAN.
        """
        d_hist, g_hist = [], []
        steps = tqdm(range(epochs)) if verbose else range(epochs)

        for ep in steps:
            # Discriminator step on real + generated data.
            X, y = sample_data_and_gen(G, n_samples=n_samples, noise_dim=noise_dim)
            set_trainability(D, True)
            d_hist.append(D.train_on_batch(X, y))

            # Generator step: D frozen, targets are the "real" label.
            X, y = sample_noise(G, n_samples=n_samples, noise_dim=noise_dim)
            set_trainability(D, False)
            g_hist.append(GAN.train_on_batch(X, y))

            if verbose and (ep + 1) % v_freq == 0:
                print("Epoch #{}: Generative Loss: {}, Discriminative Loss: {}".format(ep + 1, g_hist[-1], d_hist[-1]))

        return d_hist, g_hist
    
    
    # Plot 5 example "real" curves (one line per sample).
    ax = pd.DataFrame(np.transpose(sample_data(5))).plot()
    # Generator: 10-dim noise vector -> sample of length out_dim (default 50).
    G_in = Input(shape=[10])
    G, G_out = get_generative(G_in)
    G.summary()
    
    # Discriminator: 50-point sample -> 2-unit real/generated output.
    D_in = Input(shape=[50])
    D, D_out = get_discriminative(D_in)
    D.summary()
    
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    input_9 (InputLayer)         (None, 10)                0         
    _________________________________________________________________
    dense_13 (Dense)             (None, 200)               2200      
    _________________________________________________________________
    activation_4 (Activation)    (None, 200)               0         
    _________________________________________________________________
    dense_14 (Dense)             (None, 50)                10050     
    =================================================================
    Total params: 12,250
    Trainable params: 12,250
    Non-trainable params: 0
    _________________________________________________________________
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    input_10 (InputLayer)        (None, 50)                0         
    _________________________________________________________________
    reshape_4 (Reshape)          (None, 50, 1)             0         
    _________________________________________________________________
    conv1d_4 (Conv1D)            (None, 46, 50)            300       
    _________________________________________________________________
    dropout_4 (Dropout)          (None, 46, 50)            0         
    _________________________________________________________________
    flatten_4 (Flatten)          (None, 2300)              0         
    _________________________________________________________________
    dense_15 (Dense)             (None, 50)                115050    
    _________________________________________________________________
    dense_16 (Dense)             (None, 2)                 102       
    =================================================================
    Total params: 115,452
    Trainable params: 115,452
    Non-trainable params: 0
    _________________________________________________________________
    

    (图:sample_data(5) 生成的 5 条随机正弦曲线)

    # Stack G and a frozen D into the combined model used for generator updates.
    GAN_in = Input([10])
    GAN, GAN_out = make_gan(GAN_in, G, D)
    GAN.summary()
    
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    input_11 (InputLayer)        (None, 10)                0         
    _________________________________________________________________
    model_9 (Model)              (None, 50)                12250     
    _________________________________________________________________
    model_10 (Model)             (None, 2)                 115452    
    =================================================================
    Total params: 127,702
    Trainable params: 12,250
    Non-trainable params: 115,452
    _________________________________________________________________
    
    # Warm up D for one epoch before adversarial training.
    pretrain(G, D)
    
    Epoch 1/1
    20000/20000 [==============================] - 3s - loss: 0.0072     
    
    # Run the default 500 alternating D/G steps, printing losses every 50 epochs.
    d_loss, g_loss = train(GAN, G, D, verbose=True)
    
    Epoch #50: Generative Loss: 4.41527795791626, Discriminative Loss: 0.6733301877975464
    Epoch #100: Generative Loss: 3.8898046016693115, Discriminative Loss: 0.09901376813650131
    Epoch #150: Generative Loss: 6.2410054206848145, Discriminative Loss: 0.034074194729328156
    Epoch #200: Generative Loss: 5.206066608428955, Discriminative Loss: 0.13078376650810242
    Epoch #250: Generative Loss: 3.5144925117492676, Discriminative Loss: 0.07160962373018265
    Epoch #300: Generative Loss: 3.705162525177002, Discriminative Loss: 0.05893774330615997
    Epoch #350: Generative Loss: 3.511479616165161, Discriminative Loss: 0.09775738418102264
    Epoch #400: Generative Loss: 4.141300678253174, Discriminative Loss: 0.03169865906238556
    Epoch #450: Generative Loss: 3.500260829925537, Discriminative Loss: 0.05957922339439392
    Epoch #500: Generative Loss: 2.9797921180725098, Discriminative Loss: 0.10566817969083786
    
    # Plot both per-epoch loss histories on a log-scaled y axis.
    ax = pd.DataFrame(
        {
            'Generative Loss': g_loss,
            'Discriminative Loss': d_loss,
        }
    ).plot(title='Training loss', logy=True)
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Loss")
    

    (图:生成器与判别器的训练损失曲线,纵轴为对数坐标)

    # Plot the generated samples only: rows after the first N_VIEWED_SAMPLES
    # (which are the real curves).
    N_VIEWED_SAMPLES = 2
    data_and_gen, _ = sample_data_and_gen(G, n_samples=N_VIEWED_SAMPLES)
    pd.DataFrame(np.transpose(data_and_gen[N_VIEWED_SAMPLES:])).plot()
    

    (图:生成器输出的 2 条样本曲线)

    # Plot freshly generated samples smoothed with a moving average.
    N_VIEWED_SAMPLES = 2
    SMOOTH_WINDOW = 5
    data_and_gen, _ = sample_data_and_gen(G, n_samples=N_VIEWED_SAMPLES)
    # Rows after the first N_VIEWED_SAMPLES are the generated curves.
    generated = pd.DataFrame(np.transpose(data_and_gen[N_VIEWED_SAMPLES:]))
    # rolling(w).mean() leaves exactly the first w - 1 rows NaN. Drop only
    # those; the previous `[5:]` also discarded the first valid smoothed value.
    generated.rolling(SMOOTH_WINDOW).mean().iloc[SMOOTH_WINDOW - 1:].plot()
    

    (图:生成样本经 5 点滑动平均平滑后的曲线)

    参考资料

    [1] http://www.rricard.me/machine/learning/generative/adversarial/networks/keras/tensorflow/2017/04/05/gans-part2.html#Imports

  • 相关阅读:
    驱动控制浏览器 和排程算法
    Python简单人脸识别,可调摄像头,基础入门,先简单了解一下吧
    机器学习
    “一拖六”屏幕扩展实战
    Apple iMac性能基准测试
    IDC机房KVM应用案例分析
    突破极限 解决大硬盘上安装Unix新思路
    Domino系统从UNIX平台到windows平台的迁移及备份
    走进集装箱数据中心(附动画详解)
    企业实战之部署Solarwinds Network八部众
  • 原文地址:https://www.cnblogs.com/flyu6/p/7691130.html
Copyright © 2011-2022 走看看