zoukankan      html  css  js  c++  java
  • 用TensorFlow2.0构建分类模型对数据集fashion_mnist进行分类

    import tensorflow as tf
    import tensorflow.keras as keras
    import matplotlib.pyplot as plt
    import pandas as pd
    
    #加载数据
    fasion_mnist = keras.datasets.fashion_mnist
    (X_train_all, y_train_all), (X_test, y_test) = fasion_mnist.load_data()
    X_valid, X_train = X_train_all[:5000], X_train_all[5000:]
    y_valid, y_train = y_train_all[:5000], y_train_all[5000:]
    
    #构建模型
    model = keras.Sequential([
        keras.layers.Flatten(input_shape=[28, 28]),
        keras.layers.Dense(300, activation='sigmoid'),
        keras.layers.Dense(100, activation='sigmoid'),
        keras.layers.Dense(10, activation='sigmoid')
    ])
    
    model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    history = model.fit(x=X_train, y = y_train, epochs=10, validation_data=(X_valid,y_valid))
    
    print(history)
    
    #画出训练结果
    pd.DataFrame(history.history).plot()
    plt.show()

    若那啥模型的标签是一个类别索引,则

    loss='sparse_categorical_crossentropy'

    若模型的标签也是一个向量,表示属于每一类的概率,则
    loss='categorical_crossentropy'
    可以认为加了sparse之后对标签y进行了one_hot编码
  • 相关阅读:
    flex+java+blazeds 多通道好文
    如何保持PC客户端一直处于登录状态
    函数进阶
    数据类型扩展
    python编码规范
    xpath轴定位
    IDEA Terminal 运行mvn命令报该命令不是内部命令
    java环境安装Firefox驱动/IE驱动
    java环境添加chrome驱动
    java安装selenium webdriver环境
  • 原文地址:https://www.cnblogs.com/loubin/p/12573622.html
Copyright © 2011-2022 走看看