zoukankan      html  css  js  c++  java
  • keras—神经网络CNN—MNIST手写数字识别

     1 from keras.datasets import mnist
     2 from keras.utils import np_utils
     3 from plot_image_1 import plot_image_1
     4 from plot_prediction_1 import plot_image_labels_prediction_1
     5 from show_train_history import show_train_history
     6 import numpy as np
     7 import pandas as pd
     8 from keras.models import Sequential
     9 from keras.layers import Dense,Dropout,Flatten,Conv2D,MaxPooling2D
    # Train a small CNN on MNIST, plot the training curves, evaluate on the
    # test split and report a confusion matrix of labels vs. predictions.

    # Seed NumPy's global random generator.
    np.random.seed(10)

    # Load the train/test splits and report their sizes and shapes.
    (x_Train, y_Train), (x_Test, y_Test) = mnist.load_data()
    print('train data=', len(x_Train))
    print('test data=', len(x_Test))
    print('x_train_image:', x_Train.shape)
    print('y_train_label:', y_Train.shape)

    # Conv2D expects a trailing channel axis, so reshape to (N, 28, 28, 1)
    # and convert to float32 before scaling.
    x_Train4D = x_Train.reshape(x_Train.shape[0], 28, 28, 1).astype('float32')
    x_Test4D = x_Test.reshape(x_Test.shape[0], 28, 28, 1).astype('float32')

    # Scale by 1/255 (presumably maps 8-bit pixel values into [0, 1]).
    x_Train4D_normalize = x_Train4D / 255
    x_Test4D_normalize = x_Test4D / 255

    # Convert the integer class labels to categorical (one-hot) form, as
    # required by the categorical_crossentropy loss used below.
    y_TrainOneHot = np_utils.to_categorical(y_Train)
    y_TestOneHot = np_utils.to_categorical(y_Test)

    # Two convolution + max-pooling stages, then a dense classifier head.
    model = Sequential()
    model.add(Conv2D(filters=16,
                     kernel_size=(5, 5),
                     padding='same',
                     input_shape=(28, 28, 1),
                     activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(filters=36,
                     kernel_size=(5, 5),
                     padding='same',
                     activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(10, activation='softmax'))

    # summary() prints the table itself and returns None; wrapping it in
    # print() would emit an extra "None" line.
    model.summary()

    model.compile(loss='categorical_crossentropy',
                  optimizer='adam', metrics=['accuracy'])

    # validation_split=0.2 asks Keras to hold back part of the training data
    # for the validation metrics plotted below.
    train_history = model.fit(x=x_Train4D_normalize,
                              y=y_TrainOneHot, validation_split=0.2,
                              epochs=5, batch_size=300, verbose=2)

    # Older Keras records the metric as 'acc', newer releases as 'accuracy';
    # pick whichever key this run actually produced to avoid a KeyError.
    acc_key = 'acc' if 'acc' in train_history.history else 'accuracy'
    show_train_history(train_history, acc_key, 'val_' + acc_key)
    show_train_history(train_history, 'loss', 'val_loss')

    # evaluate() returns [loss, accuracy] given the metrics compiled above.
    scores = model.evaluate(x_Test4D_normalize, y_TestOneHot)
    print()
    print('accuracy', scores[1])

    # Sequential.predict_classes() was removed from current Keras; taking the
    # argmax over the softmax outputs yields the same class indices.
    prediction = np.argmax(model.predict(x_Test4D_normalize), axis=-1)
    print("prediction[:10]", prediction[:10])
    plot_image_labels_prediction_1(x_Test, y_Test, prediction, idx=0)

    # Confusion matrix: rows are true labels, columns are predicted labels.
    # Printed explicitly so it is visible when run as a script, not only as
    # the last expression of a notebook cell.
    print(pd.crosstab(y_Test, prediction,
                      rownames=['label'], colnames=['predict']))
    1 import matplotlib.pyplot as plt
    def plot_image_1(image):
        """Display a single image on the current figure at 2x2 inches.

        Args:
            image: Array-like image data accepted by ``plt.imshow``.
        """
        # Resize whatever figure is current rather than creating a new one.
        plt.gcf().set_size_inches(2, 2)
        plt.imshow(image, cmap='binary')
        plt.show()
     1 import matplotlib.pyplot as plt
     def plot_image_labels_prediction_1(image, labels, prediction, idx, num=10):
         """Show consecutive images in a 5x5 grid, titled with label/prediction.

         Args:
             image: Indexable collection of images; ``image[i]`` is passed to
                 ``imshow``.
             labels: Indexable collection of true labels, parallel to ``image``.
             prediction: Indexable collection of predicted labels parallel to
                 ``image``; may be empty or ``None`` to omit predictions from
                 the titles.
             idx: Index of the first image to show.
             num: Number of images to show. Capped at 25 (the grid size) and
                 at the number of images available from ``idx`` onward.
         """
         # Cap at the 5x5 grid and at the remaining images, so a request that
         # runs past the end of the data shows what exists instead of raising
         # IndexError halfway through drawing.
         num = min(num, 25, len(image) - idx)
         show_prediction = prediction is not None and len(prediction) > 0

         fig = plt.gcf()
         fig.set_size_inches(12, 14)
         for slot in range(num):
             current = idx + slot
             ax = plt.subplot(5, 5, slot + 1)
             ax.imshow(image[current], cmap='binary')
             title = "label=" + str(labels[current])
             if show_prediction:
                 title += ",predict=" + str(prediction[current])
             ax.set_title(title, fontsize=10)
             # Hide axis ticks; they carry no information for image tiles.
             ax.set_xticks([])
             ax.set_yticks([])
         plt.show()
    1 import matplotlib.pyplot as plt
    def show_train_history(train_history, train, validation):
        """Plot a training metric and its validation counterpart per epoch.

        Args:
            train_history: Object exposing a ``history`` mapping from metric
                name to a sequence of values.
            train: Key of the training metric in ``train_history.history``;
                also used as the y-axis label.
            validation: Key of the validation metric in
                ``train_history.history``.
        """
        # Draw the training curve first, then the validation curve, so the
        # legend entries below line up with the plotted series.
        for metric_key in (train, validation):
            plt.plot(train_history.history[metric_key])
        plt.title('Train History')
        plt.ylabel(train)
        plt.xlabel('Epoch')
        # Show the legend in the upper-left corner.
        plt.legend(['train', 'validation'], loc='upper left')
        plt.show()

    萍水相逢逢萍水,浮萍之水水浮萍!
  • 相关阅读:
    C#嵌套类
    C#8.0接口默认实现特性
    asp.net Server.Transfer
    clickjacking 攻击
    frame标签和frameset
    javascript打开窗口
    Linux 之 LNMP服务器搭建-PHP
    Linux 之 LNMP服务器搭建-前期准备
    Linux 之 LNMP服务器搭建-Nginx
    Linux 之 Samba服务器
  • 原文地址:https://www.cnblogs.com/AIBigTruth/p/9735670.html
Copyright © 2011-2022 走看看