zoukankan      html  css  js  c++  java
  • TensorFlow 2.0——AutoEncoder 实战

    import tensorflow as tf
    from tensorflow import keras
    import numpy as np
    import matplotlib.pyplot as plt
    from PIL import Image
    import os
    from tensorflow.keras import Sequential, layers
    import sys
    
    # Low-level runtime configuration: make TensorFlow allocate GPU memory on
    # demand instead of reserving it all up front. Memory growth is applied to
    # every visible GPU (TensorFlow can refuse settings that differ between
    # GPUs). With no GPU present the script falls back to the CPU instead of
    # aborting.
    physical_devices = tf.config.experimental.list_physical_devices('GPU')
    if not physical_devices:
        print('No GPU devices found; running on CPU.')
    for gpu in physical_devices:
        tf.config.experimental.set_memory_growth(gpu, True)

    # NOTE: if enabled, these must run before TensorFlow initialises CUDA
    # (ideally before `import tensorflow`) to have any effect.
    # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    # os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # use the 2nd GPU
    
    # Hyper-parameters
    h_dim = 20          # width of the bottleneck code produced by the encoder
    batchsz = 512       # mini-batch size for both train and test pipelines
    learn_rate = 1e-3   # optimizer step size

    # Fashion-MNIST: labels are discarded (bound to `_`) because the
    # autoencoder only reconstructs its input. Pixels are divided by 255.
    (x_train, _), (x_test, _) = keras.datasets.fashion_mnist.load_data()
    x_train = x_train.astype(np.float32) / 255.
    x_test = x_test.astype(np.float32) / 255.
    print('x_train.shape:', x_train.shape)

    # Training batches are shuffled with a buffer of five batches;
    # the test set keeps its original order.
    train_db = (
        tf.data.Dataset.from_tensor_slices(x_train)
        .shuffle(batchsz * 5)
        .batch(batchsz)
    )
    test_db = tf.data.Dataset.from_tensor_slices(x_test).batch(batchsz)
    
    def my_save_img(data, name, cmap=None):
        """Tile up to 100 images into a 10x10 mosaic and save it as a JPEG.

        Args:
            data: array-like (NumPy array, eager tf.Tensor, list of arrays)
                of images, each either 28x28 or flattened to 784 values.
                Only the first 100 are used; unused grid cells stay 0.
            name: file stem; the mosaic is written to
                ``./img_dir/AE_img/<name>.jpg``. Missing directories are
                created.
            cmap: optional Matplotlib colormap forwarded to ``plt.imsave``
                (e.g. ``'gray'`` for a true grayscale render). ``None`` keeps
                Matplotlib's default colormap, as before.
        """
        save_img_path = './img_dir/AE_img/{}.jpg'.format(name)
        # plt.imsave does not create parent directories; without this the
        # first call fails on a fresh checkout.
        os.makedirs(os.path.dirname(save_img_path), exist_ok=True)

        new_img = np.zeros((280, 280))
        for index, each_img in enumerate(np.asarray(data)[:100]):
            row, col = divmod(index, 10)
            row_start, col_start = row * 28, col * 28
            # reshape accepts both 28x28 and flattened 784-value images.
            new_img[row_start:row_start + 28, col_start:col_start + 28] = (
                np.reshape(each_img, (28, 28))
            )

        plt.imsave(save_img_path, new_img, cmap=cmap)
    
    # plt.imshow(new_img)
    # plt.show()
    
    
    # sys.exit(2)
    
    
    #   打印数据图
    # for i in range(16):
    #     plt.subplot(4,4,i+1)
    #     plt.imshow(np.reshape(x_train[i],(28,28,1)))
    # plt.show()
    
    class AE(keras.Model):
        """Fully connected autoencoder: 784 -> 256 -> 128 -> h_dim -> 128 -> 256 -> 784.

        Hidden layers use ReLU. The bottleneck and the final reconstruction
        layer have no activation, so ``call`` returns raw (pre-sigmoid) values.
        """

        def __init__(self):
            super().__init__()
            hidden_units = (256, 128)

            # Encoder: narrow step by step down to the h_dim-wide code.
            self.encoder = Sequential(
                [layers.Dense(units, activation=tf.nn.relu) for units in hidden_units]
                + [layers.Dense(h_dim)]
            )

            # Decoder: mirror image of the encoder, widening back to 28*28.
            self.decoder = Sequential(
                [layers.Dense(units, activation=tf.nn.relu)
                 for units in reversed(hidden_units)]
                + [layers.Dense(28 * 28)]
            )

        def call(self, inputs, training=None, mask=None):
            code = self.encoder(inputs)   # [b, 784] => [b, h_dim]
            return self.decoder(code)     # [b, h_dim] => [b, 784]
    
    
    my_model = AE()
    my_model.build(input_shape=(None, 784))
    my_model.summary()

    # `learning_rate` is the documented keyword; `lr` is a deprecated alias
    # that newer Keras releases reject outright.
    opt = tf.optimizers.Adam(learning_rate=learn_rate)
    for epoch in range(50):
        for step, x in enumerate(train_db):
            # [b, 28, 28] => [b, 784]: flatten each image for the dense layers.
            x = tf.reshape(x, [-1, 784])

            with tf.GradientTape() as tape:
                out = my_model(x)
                # The decoder's last layer has no activation, so the loss
                # applies the sigmoid itself (from_logits=True).
                my_loss = tf.losses.binary_crossentropy(x, out, from_logits=True)
                # my_loss = tf.losses.mean_squared_error(x, out)
                my_loss = tf.reduce_mean(my_loss)
            grads = tape.gradient(my_loss, my_model.trainable_variables)
            opt.apply_gradients(zip(grads, my_model.trainable_variables))
            if step % 100 == 0:
                print(epoch, step, float(my_loss))

        # Evaluation sits outside the step loop, so inference on the first
        # test batch and the two image writes happen once per epoch rather
        # than on every training batch.
        x = next(iter(test_db))
        my_save_img(x, '{}_label'.format(epoch))
        x = tf.reshape(x, [-1, 784])
        logits = my_model(x)
        x_hat = tf.sigmoid(logits)  # pairs with the binary cross-entropy loss
        # x_hat = logits            # use this instead when training with MSE
        x_hat = tf.reshape(x_hat, [-1, 28, 28])
        my_save_img(x_hat, '{}_pre'.format(epoch))
  • 相关阅读:
    如何在android studio上加入OpenCV库
    c++ overload 、override、overwrite
    学习笔记:linux之文件空洞
    windows 编译安装PROJ.4
    RDD:基于内存的集群计算容错抽象
    用Scala语言轻松开发多线程、分布式以及集群式程序
    scala 读雷达数据文件,生成png
    linux下virtualBox挂载物理磁盘,启动第二块硬盘中的系统
    shell检查网络出现异常、僵尸进程、内存过低后,自动重启
    cordova 5.0版本说明
  • 原文地址:https://www.cnblogs.com/cxhzy/p/14163824.html
Copyright © 2011-2022 走看看