  • Dimensionality reduction of MNIST data with an AutoEncoder in Keras

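    import keras
    import matplotlib.pyplot as plt
    from keras.datasets import mnist
    
    # Load MNIST; the training labels are not needed because the autoencoder is unsupervised
    (x_train, _), (x_test, y_test) = mnist.load_data()
    
    # Scale pixel values to [0, 1] and flatten each 28x28 image into a 784-dim vector
    x_train = x_train.astype('float32') / 255
    x_test = x_test.astype('float32') / 255
    x_train = x_train.reshape(x_train.shape[0], -1)
    x_test = x_test.reshape(x_test.shape[0], -1)
    encoding_dim = 2  # size of the bottleneck, i.e. the target dimensionality
    
    # Encoder: 784 -> 128 -> 32 -> 8 -> 2; the last layer is linear, so the codes are unbounded
    encoder = keras.models.Sequential([
        keras.Input(shape=(784,)),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dense(32, activation='relu'),
        keras.layers.Dense(8, activation='relu'),
        keras.layers.Dense(encoding_dim)
    ])
    
    # Decoder mirrors the encoder: 2 -> 8 -> 32 -> 128 -> 784.
    # sigmoid keeps reconstructions in [0, 1], the same range as the scaled inputs
    # (tanh would allow values in [-1, 1] that no pixel can take)
    decoder = keras.models.Sequential([
        keras.Input(shape=(encoding_dim,)),
        keras.layers.Dense(8, activation='relu'),
        keras.layers.Dense(32, activation='relu'),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dense(784, activation='sigmoid')
    ])
    
    # Chain encoder and decoder and train the network to reproduce its own input
    AutoEncoder = keras.models.Sequential([
        encoder,
        decoder
    ])
    AutoEncoder.compile(optimizer='adam', loss='mse')
    AutoEncoder.fit(x_train, x_train, epochs=10, batch_size=256)
    
    # Encode the test set to 2-D and plot it, coloring each point by its digit label
    predict = encoder.predict(x_test)
    plt.scatter(predict[:, 0], predict[:, 1], c=y_test, s=2, cmap='tab10')
    plt.colorbar()
    plt.show()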
    

      

     

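    After reducing the data to two dimensions, the resulting plot looks like this:

    [Figure: 2-D encoder outputs for the MNIST test set, colored by digit label]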
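Each Dense layer in the model computes `activation(x @ W + b)`, so the whole autoencoder is just a chain of such maps. As a framework-free illustration, the sketch below (assuming NumPy; random, untrained weights; random stand-in data rather than MNIST) runs the same 784 → 128 → 32 → 8 → 2 encoder and mirrored decoder on a small batch and computes the mean squared reconstruction error that `loss='mse'` minimizes. The helper names `make_stack` and `forward` are hypothetical, and the sigmoid output layer is an assumption chosen to match inputs scaled to [0, 1].

```python
import numpy as np

rng = np.random.default_rng(0)

def make_stack(sizes, activations):
    """Create (W, b, activation) triples; W has shape (n_in, n_out) like a Dense kernel."""
    layers = []
    for n_in, n_out, act in zip(sizes[:-1], sizes[1:], activations):
        w = rng.normal(0.0, 0.05, size=(n_in, n_out)).astype(np.float32)
        b = np.zeros(n_out, dtype=np.float32)
        layers.append((w, b, act))
    return layers

def forward(layers, x):
    """Apply activation(x @ W + b) layer by layer."""
    for w, b, act in layers:
        x = x @ w + b
        if act == "relu":
            x = np.maximum(x, 0.0)
        elif act == "sigmoid":
            x = 1.0 / (1.0 + np.exp(-x))
        # act is None: linear output, as in Dense(encoding_dim)
    return x

encoding_dim = 2
encoder = make_stack([784, 128, 32, 8, encoding_dim], ["relu", "relu", "relu", None])
decoder = make_stack([encoding_dim, 8, 32, 128, 784], ["relu", "relu", "relu", "sigmoid"])

# Stand-in for a batch of 5 flattened, [0, 1]-scaled images (random values, not MNIST)
x = rng.random((5, 784), dtype=np.float32)

codes = forward(encoder, x)      # one 2-D point per image
recon = forward(decoder, codes)  # one 784-dim reconstruction per image
mse = float(np.mean((recon - x) ** 2))  # mean over pixels and over the batch
print(codes.shape, recon.shape, mse)
```

Training (what `AutoEncoder.fit` does) adjusts the `W` and `b` arrays with a gradient-based optimizer to drive this error down; the 2-D `codes` produced by the trained encoder are the points in the scatter plot.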

  • Original article: https://www.cnblogs.com/yytxdy/p/11831049.html