zoukankan      html  css  js  c++  java
  • 用TensorFlow2.0构建分类模型对数据集fashion_mnist进行分类

    import tensorflow as tf
    import tensorflow.keras as keras
    import matplotlib.pyplot as plt
    import pandas as pd
    
    #加载数据
    fasion_mnist = keras.datasets.fashion_mnist
    (X_train_all, y_train_all), (X_test, y_test) = fasion_mnist.load_data()
    X_valid, X_train = X_train_all[:5000], X_train_all[5000:]
    y_valid, y_train = y_train_all[:5000], y_train_all[5000:]
    
    #构建模型
    model = keras.Sequential([
        keras.layers.Flatten(input_shape=[28, 28]),
        keras.layers.Dense(300, activation='sigmoid'),
        keras.layers.Dense(100, activation='sigmoid'),
        keras.layers.Dense(10, activation='sigmoid')
    ])
    
    model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    history = model.fit(x=X_train, y = y_train, epochs=10, validation_data=(X_valid,y_valid))
    
    print(history)
    
    #画出训练结果
    pd.DataFrame(history.history).plot()
    plt.show()

    若那啥模型的标签是一个类别索引,则

    loss='sparse_categorical_crossentropy'

    若模型的标签也是一个向量,表示属于每一类的概率,则
    loss='categorical_crossentropy'
    可以认为加了sparse之后对标签y进行了one_hot编码
  • 相关阅读:
    strncpy (Strings) – C 中文开发手册
    HTML track label 属性
    Java面试题:常用的Web服务器有哪些?
    鲲鹏920上安装ovs
    基于AC控制器+VXLAN解决方案
    二层MAC学习及BUM报文转发
    基于mac表的vxlan转发
    Agile Controller vxlan
    设置鲲鹏916/920通过pxe安装os
    ovs-vxlan&vlan
  • 原文地址:https://www.cnblogs.com/loubin/p/12573622.html
Copyright © 2011-2022 走看看