zoukankan      html  css  js  c++  java
  • 使用Keras编写GAN的入门

    使用Keras编写GAN的入门

    Time: 2017-5-31


    前言

    本文主要参考了网页[1]的教程，算法则来自 Ian J. Goodfellow 等人的论文《Generative Adversarial Nets》（NIPS 2014），算法如下：

    gan

    gan

    代码

    %matplotlib inline
    import numpy as np
    import pandas as pd
    
    from keras.models import Model
    from keras.layers import Dense, Activation, Input, Reshape
    from keras.layers import Conv1D, Flatten, Dropout
    from keras.optimizers import SGD, Adam
    
    
    from tqdm import tqdm_notebook as tqdm  # 进度条
    
    
    # Generate random sine-curve training data.
    def sample_data(n_samples=10000, x_vals=np.arange(0, 5, .1), max_offset=1000, mul_range=(1, 2)):
        """Return ``n_samples`` random sine curves sampled at ``x_vals``.

        Each row is ``sin(offset + x_vals * mul) / 2 + 0.5``, so every value lies
        in [0, 1]. ``offset`` is drawn uniformly from [0, max_offset) and ``mul``
        uniformly from [mul_range[0], mul_range[1]).

        Args:
            n_samples: number of curves (rows) to generate; 0 yields an empty
                ``(0, len(x_vals))`` array.
            x_vals: 1-D sample positions shared by every curve (default: 50
                points in [0, 5)).
            max_offset: upper bound of the random phase offset.
            mul_range: ``(low, high)`` range of the random frequency multiplier.
                A tuple default is used so no mutable object is shared between
                calls; any two-element sequence is accepted.

        Returns:
            numpy.ndarray of shape ``(n_samples, len(x_vals))``.
        """
        x_vals = np.asarray(x_vals)
        low, high = mul_range[0], mul_range[1]
        # One (offset, mul) pair per row. Drawing an (n, 2) block consumes the
        # global RNG stream in the same order as a per-sample loop
        # (offset_0, mul_0, offset_1, mul_1, ...), so seeded results match.
        draws = np.random.random((n_samples, 2))
        offsets = draws[:, 0] * max_offset
        muls = low + draws[:, 1] * (high - low)
        return np.sin(offsets[:, None] + x_vals * muls[:, None]) / 2 + .5
        
    # Build the generator network.
    def get_generative(G_in, dense_dim=200, out_dim=50, lr=1e-3):
        """Build and compile the generator G.

        Architecture: ``G_in -> Dense(dense_dim) -> tanh -> Dense(out_dim, tanh)``.

        Args:
            G_in: Keras ``Input`` tensor that carries the noise vector.
            dense_dim: width of the hidden layer.
            out_dim: length of each generated sample.
            lr: learning rate of the SGD optimizer G is compiled with.

        Returns:
            ``(G, G_out)``: the compiled model and its output tensor.
        """
        pre_activation = Dense(dense_dim)(G_in)
        hidden = Activation('tanh')(pre_activation)
        # NOTE(review): tanh bounds the output to [-1, 1], whereas sample_data()
        # produces values in [0, 1]; kept as in the original tutorial.
        G_out = Dense(out_dim, activation='tanh')(hidden)
        G = Model(G_in, G_out)
        # In this file G is never fit directly; make_gan() reuses this optimizer.
        G.compile(loss='binary_crossentropy', optimizer=SGD(lr=lr))
        return G, G_out
        
    # Build the discriminator network.
    def get_discriminative(D_in, lr=1e-3, drate=.25, n_channels=50, conv_sz=5, leak=.2):
        """Build and compile the discriminator D.

        Architecture: reshape each sample to ``(length, 1)`` -> Conv1D(n_channels,
        conv_sz, relu) -> Dropout(drate) -> Flatten -> Dense(n_channels)
        -> Dense(2, sigmoid).

        Args:
            D_in: Keras ``Input`` tensor with one sample per row.
            lr: learning rate of the Adam optimizer.
            drate: dropout rate applied after the convolution.
            n_channels: number of conv filters and width of the hidden Dense layer.
            conv_sz: convolution kernel size.
            leak: not used by this function (kept for interface compatibility).

        Returns:
            ``(D, D_out)``: the compiled model and its 2-unit output tensor.
        """
        stack = [
            Reshape((-1, 1)),
            Conv1D(n_channels, conv_sz, activation='relu'),
            Dropout(drate),
            Flatten(),
            Dense(n_channels),
        ]
        features = D_in
        for layer in stack:
            features = layer(features)
        # Two independent sigmoid units trained with binary cross-entropy against
        # one-hot labels (column 1 = real, column 0 = generated in this file).
        D_out = Dense(2, activation='sigmoid')(features)
        D = Model(D_in, D_out)
        D.compile(loss='binary_crossentropy', optimizer=Adam(lr=lr))
        return D, D_out
    
        
        
    def set_trainability(model, trainable=False):
        """Set the ``trainable`` flag on ``model`` itself and on each of its layers.

        Args:
            model: object with a ``trainable`` attribute and a ``layers`` iterable.
            trainable: value to assign (defaults to freezing).
        """
        for target in (model, *model.layers):
            target.trainable = trainable
            
    def make_gan(GAN_in, G, D):
        """Chain G into a frozen D to form the combined model used to train G.

        D is marked non-trainable before the combined model is built and
        compiled, so the combined model's trainable weights are G's only.

        Args:
            GAN_in: Keras ``Input`` tensor for the noise vector.
            G: compiled generator model.
            D: discriminator model.

        Returns:
            ``(GAN, GAN_out)``: the compiled combined model and its output tensor.
        """
        set_trainability(D, False)
        GAN_out = D(G(GAN_in))
        GAN = Model(GAN_in, GAN_out)
        # Reuse the optimizer G was compiled with.
        GAN.compile(optimizer=G.optimizer, loss='binary_crossentropy')
        return GAN, GAN_out
    
    # Build labelled real + generated batches (used to pretrain and train D).
    def sample_data_and_gen(G, noise_dim=10, n_samples=10000):
        """Return a labelled batch of real and generated samples for training D.

        The first ``n_samples`` rows of ``X`` are real curves from sample_data();
        the last ``n_samples`` rows are G's predictions for uniform [0, 1) noise.
        Labels are one-hot: real rows -> ``[0, 1]``, generated rows -> ``[1, 0]``.

        Args:
            G: generator model exposing ``predict``.
            noise_dim: length of each noise vector fed to G.
            n_samples: number of real rows and of generated rows.

        Returns:
            ``(X, y)``; ``y`` has shape ``(2 * n_samples, 2)``.
        """
        real = sample_data(n_samples=n_samples)
        noise = np.random.uniform(0, 1, size=[n_samples, noise_dim])
        fake = G.predict(noise)
        # Row 0 is the "real" label, row 1 the "generated" label; each repeated
        # n_samples times to match the real-then-generated ordering of X.
        labels = np.repeat(np.array([[0.0, 1.0], [1.0, 0.0]]), n_samples, axis=0)
        return np.concatenate((real, fake)), labels
         
    def pretrain(G, D, noise_dim=10, n_samples=10000, batch_size=32):
        """Fit D for one epoch on a fresh real/generated batch before GAN training.

        Args:
            G: generator used to produce the fake half of the batch.
            D: discriminator to warm up (unfrozen before fitting).
            noise_dim: length of G's noise input.
            n_samples: number of real and of generated samples.
            batch_size: mini-batch size for ``D.fit``.
        """
        inputs, labels = sample_data_and_gen(G, noise_dim=noise_dim, n_samples=n_samples)
        set_trainability(D, True)
        D.fit(inputs, labels, epochs=1, batch_size=batch_size)
        
        
    # Alternating training steps (D step, then G step).
    def sample_noise(G, noise_dim=10, n_samples=10000):
        """Return uniform noise inputs, all labelled as "real", for the G step.

        Every label row is ``[0, 1]``, the same label sample_data_and_gen() gives
        real curves, so training the combined GAN on this batch rewards G for
        outputs D scores as real. ``G`` is not used by this function.

        Returns:
            ``(X, y)`` with ``X`` of shape ``(n_samples, noise_dim)`` drawn from
            [0, 1) and ``y`` of shape ``(n_samples, 2)``.
        """
        noise = np.random.uniform(0, 1, size=[n_samples, noise_dim])
        labels = np.tile([0.0, 1.0], (n_samples, 1))
        return noise, labels
        
    def train(GAN, G, D, epochs=500, n_samples=10000, noise_dim=10, batch_size=32, verbose=False, v_freq=50):
        """Alternately train D and G for ``epochs`` steps.

        Each step makes one ``train_on_batch`` call on D with a fresh batch of
        real + generated samples. It then makes one ``train_on_batch`` call on the
        combined GAN (D frozen) with noise labelled as real.

        Args:
            GAN: combined generator + discriminator model.
            G: generator model.
            D: discriminator model.
            epochs: number of alternating steps.
            n_samples: rows per batch; each batch is one train_on_batch call.
            noise_dim: length of G's noise input.
            batch_size: not used by this function.
            verbose: show a tqdm progress bar and print losses periodically.
            v_freq: print losses every ``v_freq`` steps when ``verbose``.

        Returns:
            ``(d_loss, g_loss)``: lists of the values returned by each
            ``train_on_batch`` call.
        """
        d_loss, g_loss = [], []
        steps = tqdm(range(epochs)) if verbose else range(epochs)
        for epoch in steps:
            # Discriminator step on real + generated data.
            batch, labels = sample_data_and_gen(G, n_samples=n_samples, noise_dim=noise_dim)
            set_trainability(D, True)
            d_loss.append(D.train_on_batch(batch, labels))

            # Generator step through the frozen discriminator.
            batch, labels = sample_noise(G, n_samples=n_samples, noise_dim=noise_dim)
            set_trainability(D, False)
            g_loss.append(GAN.train_on_batch(batch, labels))

            done = epoch + 1
            if verbose and done % v_freq == 0:
                print("Epoch #{}: Generative Loss: {}, Discriminative Loss: {}".format(done, g_loss[-1], d_loss[-1]))
        return d_loss, g_loss
    
    
    # Plot five example training curves (one column per curve, x index on the axis).
    ax = pd.DataFrame(np.transpose(sample_data(5))).plot()
    # Generator: 10-dim noise vector in, 50-point curve out.
    G_in = Input(shape=[10])
    G, G_out = get_generative(G_in)
    G.summary()
    
    # Discriminator: 50-point curve in, 2 class scores out.
    D_in = Input(shape=[50])
    D, D_out = get_discriminative(D_in)
    D.summary()
    
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    input_9 (InputLayer)         (None, 10)                0         
    _________________________________________________________________
    dense_13 (Dense)             (None, 200)               2200      
    _________________________________________________________________
    activation_4 (Activation)    (None, 200)               0         
    _________________________________________________________________
    dense_14 (Dense)             (None, 50)                10050     
    =================================================================
    Total params: 12,250
    Trainable params: 12,250
    Non-trainable params: 0
    _________________________________________________________________
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    input_10 (InputLayer)        (None, 50)                0         
    _________________________________________________________________
    reshape_4 (Reshape)          (None, 50, 1)             0         
    _________________________________________________________________
    conv1d_4 (Conv1D)            (None, 46, 50)            300       
    _________________________________________________________________
    dropout_4 (Dropout)          (None, 46, 50)            0         
    _________________________________________________________________
    flatten_4 (Flatten)          (None, 2300)              0         
    _________________________________________________________________
    dense_15 (Dense)             (None, 50)                115050    
    _________________________________________________________________
    dense_16 (Dense)             (None, 2)                 102       
    =================================================================
    Total params: 115,452
    Trainable params: 115,452
    Non-trainable params: 0
    _________________________________________________________________
    

    png

    png

    # Combined model: noise -> G -> frozen D (only G's weights are trainable here).
    GAN_in = Input([10])
    GAN, GAN_out = make_gan(GAN_in, G, D)
    GAN.summary()
    
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    input_11 (InputLayer)        (None, 10)                0         
    _________________________________________________________________
    model_9 (Model)              (None, 50)                12250     
    _________________________________________________________________
    model_10 (Model)             (None, 2)                 115452    
    =================================================================
    Total params: 127,702
    Trainable params: 12,250
    Non-trainable params: 115,452
    _________________________________________________________________
    
    # Warm up the discriminator for one epoch before adversarial training.
    pretrain(G, D)
    
    Epoch 1/1
    20000/20000 [==============================] - 3s - loss: 0.0072     
    
    # Alternating D/G training with train()'s defaults (500 steps, losses printed every 50).
    d_loss, g_loss = train(GAN, G, D, verbose=True)
    
    Epoch #50: Generative Loss: 4.41527795791626, Discriminative Loss: 0.6733301877975464
    Epoch #100: Generative Loss: 3.8898046016693115, Discriminative Loss: 0.09901376813650131
    Epoch #150: Generative Loss: 6.2410054206848145, Discriminative Loss: 0.034074194729328156
    Epoch #200: Generative Loss: 5.206066608428955, Discriminative Loss: 0.13078376650810242
    Epoch #250: Generative Loss: 3.5144925117492676, Discriminative Loss: 0.07160962373018265
    Epoch #300: Generative Loss: 3.705162525177002, Discriminative Loss: 0.05893774330615997
    Epoch #350: Generative Loss: 3.511479616165161, Discriminative Loss: 0.09775738418102264
    Epoch #400: Generative Loss: 4.141300678253174, Discriminative Loss: 0.03169865906238556
    Epoch #450: Generative Loss: 3.500260829925537, Discriminative Loss: 0.05957922339439392
    Epoch #500: Generative Loss: 2.9797921180725098, Discriminative Loss: 0.10566817969083786
    
    # Plot both per-step loss histories on a log-scaled y axis.
    ax = pd.DataFrame(
        {
            'Generative Loss': g_loss,
            'Discriminative Loss': d_loss,
        }
    ).plot(title='Training loss', logy=True)
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Loss")
    

    png

    png

    # Plot a few generated curves. sample_data_and_gen() stacks N real rows
    # followed by N generated rows, so keep only the second half.
    N_VIEWED_SAMPLES = 2
    data_and_gen, _ = sample_data_and_gen(G, n_samples=N_VIEWED_SAMPLES)
    pd.DataFrame(np.transpose(data_and_gen[N_VIEWED_SAMPLES:])).plot()
    

    png

    png

    # Plot the generated curves (second half of the stacked batch) smoothed
    # with a moving average.
    N_VIEWED_SAMPLES = 2
    SMOOTHING_WINDOW = 5
    data_and_gen, _ = sample_data_and_gen(G, n_samples=N_VIEWED_SAMPLES)
    # rolling(w).mean() is NaN only in the first w - 1 rows; slicing from
    # w - 1 (not w) keeps the first fully-smoothed row.
    pd.DataFrame(np.transpose(data_and_gen[N_VIEWED_SAMPLES:])).rolling(SMOOTHING_WINDOW).mean().iloc[SMOOTHING_WINDOW - 1:].plot()
    

    png

    png

    reference

    [1] http://www.rricard.me/machine/learning/generative/adversarial/networks/keras/tensorflow/2017/04/05/gans-part2.html#Imports

  • 相关阅读:
    超多sql分步骤类型题解
    sql 高级函数
    sql 每天下单的老客数量
    sql
    面试-JAVA常见回答
    查询员工的累计薪水
    背包问题模板
    动态规划理解合集
    SQL语句统计每天、每月、每年的 数据
    机考刷机-树相关
  • 原文地址:https://www.cnblogs.com/flyu6/p/7691130.html
Copyright © 2011-2022 走看看