zoukankan      html  css  js  c++  java
  • Auto-Encoders实战

    Outline

    • Auto-Encoder

    • Variational Auto-Encoders

    Auto-Encoder

    51-AutoEncoders实战-autoencoder.jpg

    创建编解码器

    import os
    import tensorflow as tf
    import numpy as np
    from tensorflow import keras
    from tensorflow.keras import Sequential, layers
    from PIL import Image
    from matplotlib import pyplot as plt
    
    # Seed the TensorFlow and NumPy global RNGs so runs are repeatable.
    tf.random.set_seed(22)
    np.random.seed(22)
    # '2' is meant to hide TensorFlow's C++ INFO and WARNING log lines (TF convention).
    # NOTE(review): TensorFlow may read this variable when it is first imported; because
    # it is set after `import tensorflow` here, some messages may still appear -- consider
    # moving this assignment above the imports.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    # Explicit check rather than `assert`, which is stripped when Python runs with -O.
    if not tf.__version__.startswith('2.'):
        raise RuntimeError('This script requires TensorFlow 2.x, found %s' % tf.__version__)
    
    
    def save_images(imgs, name, grid=10, tile=28):
        """Tile up to ``grid * grid`` greyscale images onto one square canvas and save it.

        Images are placed column by column: the first ``grid`` images fill the
        left-most column from top to bottom, then the next column, and so on.

        Args:
            imgs: indexable sequence of 2-D arrays, each ``tile`` x ``tile`` pixels.
                Assumed uint8; the training loop in this script passes a uint8
                array of shape [b, 28, 28].
            name: output path handed to ``Image.save``.
            grid: number of tiles along each side of the canvas (default 10).
            tile: side length in pixels of one image (default 28).

        If fewer than ``grid * grid`` images are supplied, the remaining cells are
        left black instead of raising ``IndexError``. Extra images are ignored.
        """
        side = grid * tile
        new_im = Image.new('L', (side, side))  # 'L' = 8-bit greyscale, initialised to 0 (black)

        count = min(len(imgs), grid * grid)
        for index in range(count):
            # Column-major placement: x advances once every `grid` images, y cycles fastest.
            col, row = divmod(index, grid)
            im = Image.fromarray(imgs[index], mode='L')
            new_im.paste(im, (col * tile, row * tile))

        new_im.save(name)
    
    
    # Hyper-parameters
    h_dim = 20  # latent size: each 784-pixel image is compressed to 20 numbers
    batchsz = 512
    lr = 1e-3

    (x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
    # Scale the uint8 pixel values from [0, 255] to float32 in [0, 1].
    x_train = x_train.astype(np.float32) / 255.
    x_test = x_test.astype(np.float32) / 255.

    # Reconstruction needs no labels, so only the images go into the pipelines.
    train_db = (tf.data.Dataset.from_tensor_slices(x_train)
                .shuffle(batchsz * 5)
                .batch(batchsz))
    test_db = tf.data.Dataset.from_tensor_slices(x_test).batch(batchsz)

    print(x_train.shape, y_train.shape)
    print(x_test.shape, y_test.shape)
    
    
    class AE(keras.Model):
        """Fully connected auto-encoder: flattened image -> latent code -> reconstruction logits.

        Args:
            latent_dim: size of the bottleneck code. Defaults to the module-level ``h_dim``,
                read when the model is constructed.
            input_dim: features per flattened input; 784 = 28 * 28 for Fashion-MNIST.
        """

        def __init__(self, latent_dim=None, input_dim=784):
            super(AE, self).__init__()
            if latent_dim is None:
                latent_dim = h_dim

            # Encoder: [b, input_dim] -> [b, latent_dim]; the last layer is linear.
            self.encoder = Sequential([
                layers.Dense(256, activation=tf.nn.relu),
                layers.Dense(128, activation=tf.nn.relu),
                layers.Dense(latent_dim)
            ])

            # Decoder: [b, latent_dim] -> [b, input_dim]. The last layer has no activation,
            # so the outputs are logits; apply a sigmoid to obtain pixel values in [0, 1].
            self.decoder = Sequential([
                layers.Dense(128, activation=tf.nn.relu),
                layers.Dense(256, activation=tf.nn.relu),
                layers.Dense(input_dim)
            ])

        def call(self, inputs, training=None):
            """Encode and then decode ``inputs``; returns reconstruction logits."""
            # [b, input_dim] ==> [b, latent_dim]
            h = self.encoder(inputs)

            # [b, latent_dim] ==> [b, input_dim]
            x_hat = self.decoder(h)

            return x_hat
    
    
    # Create the auto-encoder and materialise its weights for flattened 28x28 inputs.
    # The batch dimension is left open (None); the shape is passed as a tuple.
    model = AE()
    model.build(input_shape=(None, 28 * 28))
    model.summary()
    
    (60000, 28, 28) (60000,)
    (10000, 28, 28) (10000,)
    Model: "ae"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    sequential (Sequential)      multiple                  236436    
    _________________________________________________________________
    sequential_1 (Sequential)    multiple                  237200    
    =================================================================
    Total params: 473,636
    Trainable params: 473,636
    Non-trainable params: 0
    _________________________________________________________________
    

    训练

    # `learning_rate` is the keyword accepted across TF 2.x. The legacy `lr` alias is
    # rejected by the newer Keras optimizers.
    optimizer = tf.optimizers.Adam(learning_rate=lr)

    for epoch in range(10):

        for step, x in enumerate(train_db):

            # [b,28,28] ==> [b,784]
            x = tf.reshape(x, [-1, 784])

            with tf.GradientTape() as tape:
                # model(x) returns logits: the decoder's final Dense layer has no activation.
                x_rec_logits = model(x)

                # Per-sample binary cross-entropy against the [0, 1] input pixels.
                # from_logits=True applies the sigmoid inside the loss.
                rec_loss = tf.losses.binary_crossentropy(x,
                                                         x_rec_logits,
                                                         from_logits=True)
                # Average over the batch so every sample contributes to the gradient.
                # A min would back-propagate through only the single best-reconstructed image.
                rec_loss = tf.reduce_mean(rec_loss)

            grads = tape.gradient(rec_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            if step % 100 == 0:
                print(epoch, step, float(rec_loss))

        # Evaluation, once per epoch (not once per training step): reconstruct the
        # first test batch and save the reconstructions as an image grid.
        x = next(iter(test_db))
        logits = model(tf.reshape(x, [-1, 784]))
        x_hat = tf.sigmoid(logits)  # logits -> pixel intensities in [0, 1]
        # [b,784] ==> [b,28,28]
        x_hat = tf.reshape(x_hat, [-1, 28, 28])

        # Scale to [0, 255] and store as 8-bit integers for the greyscale PNG.
        x_imgs = (x_hat.numpy() * 255.).astype(np.uint8)
        os.makedirs('ae_images', exist_ok=True)
        save_images(x_imgs, 'ae_images/rec_epoch_%d.png' % epoch)
    
    0 0 0.09717604517936707
    0 100 0.12493347376585007
    1 0 0.09747321903705597
    1 100 0.12291513383388519
    2 0 0.10048121958971024
    2 100 0.12292417883872986
    3 0 0.10093794018030167
    3 100 0.12260882556438446
    4 0 0.10006923228502274
    4 100 0.12275046110153198
    5 0 0.0993042066693306
    5 100 0.12257824838161469
    6 0 0.0967678651213646
    6 100 0.12443818897008896
    7 0 0.0965462476015091
    7 100 0.12179268896579742
    8 0 0.09197664260864258
    8 100 0.12110235542058945
    9 0 0.0913471132516861
    9 100 0.12342415750026703
    
    
    
  • 相关阅读:
    C++ primer plus读书笔记——第6章 分支语句和逻辑运算符
    C++ primer plus读书笔记——第7章 函数——C++的编程模块
    C++ primer plus读书笔记——第5章 循环和关系表达式
    C++ primer plus读书笔记——第4章 复合类型
    C++ primer plus读书笔记——第3章 处理数据
    C++ primer plus读书笔记——第2章 开始学习C++
    10款好用到爆的Vim插件,你知道几个?
    程序员最讨厌的100件事,瞬间笑喷了,哈哈~~
    20 个最常用的 Git 命令用法说明及示例
    史上最全的Nginx配置参数中文说明
  • 原文地址:https://www.cnblogs.com/abdm-989/p/14123449.html
Copyright © 2011-2022 走看看