zoukankan      html  css  js  c++  java
  • tf2 fashion_mnist 入门

    学习使用tf2

    视频教程传送门

    知识点:

    loss="sparse_categorical_crossentropy"

    这里的 sparse 表示标签 y 是整数类别编号(没有做过 one-hot),损失函数内部相当于先对 y 做 one-hot 再计算交叉熵;如果 y 已经做过 one-hot 编码,则应使用 categorical_crossentropy。

    #!/usr/bin/env python
    # coding: utf-8
    
    # In[1]:
    
    
    import tensorflow as tf
    import tensorflow.keras as k
    import numpy as np
    import matplotlib.pyplot as plt
    
    
    # In[21]:
    
    
    # Load Fashion-MNIST and carve a validation set out of the training data.
    fashion_mnist = k.datasets.fashion_mnist
    (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
    # Hold out the FIRST 5000 training images for validation and train on all
    # remaining images. Assigning the slices the other way round would train on
    # only 5000 images and validate on the much larger remainder.
    x_valid, x_train = x_train[:5000], x_train[5000:]
    y_valid, y_train = y_train[:5000], y_train[5000:]
    
    
    # In[7]:
    
    
    def show_single_img(img):
        """Draw one image with matplotlib's "binary" colormap and show the figure."""
        colormap = "binary"
        plt.imshow(img, cmap=colormap)
        plt.show()
    
    
    # In[8]:
    
    
    # Preview the first validation image (the original misspelling `x_vaild` raised NameError).
    show_single_img(x_valid[0])
    
    
    # In[16]:
    
    
    def grid_figsize(n_rows, n_cols, cell_width=1.4, cell_height=1.6):
        """Return a matplotlib ``figsize`` tuple (width, height) for an image grid.

        Width scales with the number of columns and height with the number of
        rows, so every cell gets roughly ``cell_width`` x ``cell_height`` inches.
        """
        return (n_cols * cell_width, n_rows * cell_height)


    def show_imgs(n_rows, n_cols, x, y, classes):
        """Show the first ``n_rows * n_cols`` images of ``x`` in a grid titled by class name.

        Args:
            n_rows: number of grid rows.
            n_cols: number of grid columns.
            x: indexable collection of 2-D images; needs at least n_rows * n_cols items.
            y: labels aligned with ``x``; each label is used as an index into ``classes``.
            classes: class names, indexed by label.
        """
        # figsize is (width, height): width follows the column count, height the row count.
        plt.figure(figsize=grid_figsize(n_rows, n_cols))
        # Row-major traversal: cell i sits at row i // n_cols, column i % n_cols.
        for index in range(n_rows * n_cols):
            plt.subplot(n_rows, n_cols, index + 1)  # subplot positions are 1-based
            plt.imshow(x[index], cmap="binary")
            plt.title(classes[y[index]])
            plt.axis("off")
        # Display explicitly, consistent with show_single_img and plot_curve.
        plt.show()


    # Human-readable names for the integer labels 0..9.
    classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
    
    
    # In[17]:
    
    
    # Preview the first five training images with their class names on one row.
    show_imgs(n_rows=1, n_cols=5, x=x_train[:5], y=y_train[:5], classes=classes)
    
    
    # In[24]:
    
    
    # Build the model: flatten each 28x28 input, two ReLU hidden layers,
    # then a 10-way softmax output layer.
    model = k.Sequential([
        k.layers.Flatten(input_shape=[28, 28]),
        k.layers.Dense(300, activation="relu"),
        k.layers.Dense(100, activation="relu"),
        k.layers.Dense(10, activation="softmax"),
    ])

    # Labels are integer class ids (they index `classes` directly), hence the
    # "sparse" variant of categorical cross-entropy.
    model.compile(
        loss="sparse_categorical_crossentropy",
        optimizer="adam",
        metrics=["accuracy"],
    )
    
    
    # In[25]:
    
    
    # Train for 10 epochs, reporting validation metrics after each epoch.
    history = model.fit(
        x_train,
        y_train,
        epochs=10,
        validation_data=(x_valid, y_valid),
    )
    
    
    # In[27]:
    
    
    import pandas as pd


    def plot_curve(history):
        """Plot every per-epoch metric stored in ``history.history`` on one chart.

        Each key of ``history.history`` becomes one line. The y-axis is fixed to
        [0, 1], so any values above 1 (e.g. a large early loss) fall outside
        the visible range.
        """
        metrics_frame = pd.DataFrame(history.history)
        metrics_frame.plot(figsize=(8, 5))
        plt.grid(True)
        axes = plt.gca()
        axes.set_ylim(0, 1)
        plt.show()


    plot_curve(history)
    
    
    # In[ ]:
    View Code

    使用 sklearn 的 StandardScaler 对数据集进行标准化(减均值、除以标准差)操作

    #!/usr/bin/env python
    # coding: utf-8
    
    # In[5]:
    
    
    import tensorflow as tf
    import tensorflow.keras as k
    import numpy as np
    import matplotlib.pyplot as plt
    
    
    # In[6]:
    
    
    # Load Fashion-MNIST and carve a validation set out of the training data.
    fashion_mnist = k.datasets.fashion_mnist
    (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
    # Hold out the FIRST 5000 training images for validation and train on all
    # remaining images. Assigning the slices the other way round would train on
    # only 5000 images and validate on the much larger remainder.
    x_valid, x_train = x_train[:5000], x_train[5000:]
    y_valid, y_train = y_train[:5000], y_train[5000:]
    
    
    # In[7]:
    
    
    from sklearn.preprocessing import StandardScaler

    # Every pixel value is flattened into one feature column, so the scaler
    # learns a single scalar mean/std. It is fitted on the training images only
    # and then reused unchanged for the validation and test images.
    scaler = StandardScaler()


    def _standardize_images(images, fit=False):
        """Scale a stack of images with ``scaler`` and reshape the result to (-1, 28, 28)."""
        column = images.astype(np.float32).reshape(-1, 1)
        scaled = scaler.fit_transform(column) if fit else scaler.transform(column)
        return scaled.reshape(-1, 28, 28)


    x_train_scaled = _standardize_images(x_train, fit=True)
    x_valid_scaled = _standardize_images(x_valid)
    x_test_scaled = _standardize_images(x_test)
    
    
    # In[8]:
    
    
    # Build the model: flatten each 28x28 input, two ReLU hidden layers,
    # then a 10-way softmax output layer.
    model = k.Sequential([
        k.layers.Flatten(input_shape=[28, 28]),
        k.layers.Dense(300, activation="relu"),
        k.layers.Dense(100, activation="relu"),
        k.layers.Dense(10, activation="softmax"),
    ])

    # Labels are integer class ids, hence the "sparse" categorical cross-entropy.
    model.compile(
        loss="sparse_categorical_crossentropy",
        optimizer="adam",
        metrics=["accuracy"],
    )
    
    
    # In[9]:
    
    
    # Train on the standardized arrays produced by the StandardScaler cell;
    # passing the raw x_train / x_valid would silently skip the scaling this
    # script exists to demonstrate. Labels are unchanged by scaling.
    history = model.fit(x_train_scaled, y_train, epochs=10,
                        validation_data=(x_valid_scaled, y_valid))
    
    
    # In[10]:
    
    
    import pandas as pd


    def plot_curve(history):
        """Draw all per-epoch metrics from a training history as line plots.

        One line per key of ``history.history``. The y-axis is clamped to
        [0, 1]; values above 1 are not visible.
        """
        frame = pd.DataFrame(history.history)
        frame.plot(figsize=(8, 5))
        plt.grid(True)
        current_axes = plt.gca()
        current_axes.set_ylim(0, 1)
        plt.show()


    plot_curve(history)
    
    
    # In[ ]:
    View Code
  • 相关阅读:
    paper 89:视频图像去模糊常用处理方法
    paper 88:人脸检测和识别的Web服务API
    paper 87:行人检测资源(下)代码数据【转载,以后使用】
    paper 86:行人检测资源(上)综述文献【转载,以后使用】
    paper 85:机器统计学习方法——CART, Bagging, Random Forest, Boosting
    paper 84:机器学习算法--随机森林
    paper 83:前景检测算法_1(codebook和平均背景法)
    paper 82:边缘检测的各种微分算子比较(Sobel,Robert,Prewitt,Laplacian,Canny)
    paper 81:HDR成像技术
    paper 80 :目标检测的图像特征提取之(一)HOG特征
  • 原文地址:https://www.cnblogs.com/superxuezhazha/p/12257140.html
Copyright © 2011-2022 走看看