zoukankan      html  css  js  c++  java
  • 机器学习-识别手写数字0-9

     1 import os
     2 os.environ['TF_CPP_MIN_LOG_LEVEL']='2' # to hidden the messages from tensorflow
     3 from tensorflow import keras
     4 from  tensorflow.keras import layers
     5 import numpy as np
     6 import matplotlib.pyplot as plt
     7 # from tensorflow.keras.datasets import mnist
     8 # mnist is the handwriting number dataset 0-9
     9 import sys
    10 
    11 def load_mnist(path):
    12     file=np.load(path)
    13     x_train,y_train=file['x_train'],file['y_train']
    14     x_test,y_test=file['x_test'],file['y_test']
    15     file.close()
    16     return (x_train, y_train), (x_test, y_test)
    17 
    18 
    19 def check(N,pages):
    20     idx=0
    21     for page in range(pages):
    22         for i in range(N):
    23             for j in range(N):
    24                 num = i*N+j
    25                 plt.subplot(N,N,num+1)
    26                 plt.imshow(x_train[num+idx],cmap=plt.get_cmap('gray'))
    27         idx+=N*N
    28         plt.show()
    29         # show the plot
    30 
    31 
    32 path="C:/Users/77007/Desktop/python/pythonProject1/mnist.npz"
    33 (x_train,y_train),(x_test, y_test)=load_mnist(path)
    34 # print(x_train.shape) # 60000 张图片 pix 28*28
    35 # print(y_train.shape) # 60000 个结果对应0-9其中一个
    36 x_train=x_train.reshape(-1,784).astype("float32")/255.0
    37 x_test=x_test.reshape(-1,784).astype("float32")/255.0
    38 # print(x_train.shape)
    39 # print(x_test.shape)
    40 
    41 # Sequential API (convenient, not flexible)
    42 model=keras.Sequential(
    43     [
    44         keras.Input(shape=(28*28)), # for print the model
    45         layers.Dense(512,activation='relu'),
    46         layers.Dense(256,activation='relu'),
    47         layers.Dense(10),
    48     ]
    49 )
    50 # another definition method
    51 ''' 
    52 model=keras.Sequential()
    53 model.add(keras.Input(shape=784))
    54 model.add(layers.Dense(512,activation='relu'))
    55 model.add(layers.Dense(256,activation='relu'))
    56 model.add(layers.Dense(10))
    57 '''
    58 # print(model.summary())
    59 # sys.exit()
    60 
    61 # Functional API (more flexible)
    62 inputs=keras.Input(shape=(784))
    63 x=layers.Dense(512,activation='relu',name='first_layer')(inputs)
    64 x=layers.Dense(256,activation='relu',name='second_layer')(x)
    65 outputs=layers.Dense(10,activation='softmax')(x)
    66 model=keras.Model(inputs=inputs,outputs=outputs) # change the model (OR you can comment this line using the model upward)
    67 # print(model.summary())
    68 # for printing summary of our model, we can add name='xxx' feature of each layers
    69 
    70 model.compile(
    71     loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    72     # Sequential API using True, Functional API using False
    73     optimizer=keras.optimizers.Adam(lr=0.001),
    74     # learing rate=0.001
    75     metrics=["accuracy"],
    76 )
    77 
    78 model.fit(x_train,y_train,batch_size=32,epochs=5,verbose=2)
    79 model.evaluate(x_test,y_test,batch_size=32,verbose=2)
Train on 60000 samples
Epoch 1/5
60000/60000 - 6s - loss: 0.1865 - accuracy: 0.9425
Epoch 2/5
60000/60000 - 5s - loss: 0.0800 - accuracy: 0.9749
Epoch 3/5
60000/60000 - 5s - loss: 0.0541 - accuracy: 0.9826
Epoch 4/5
60000/60000 - 5s - loss: 0.0395 - accuracy: 0.9872
Epoch 5/5
60000/60000 - 5s - loss: 0.0341 - accuracy: 0.9890
10000/1 - 0s - loss: 0.0392 - accuracy: 0.9792
    ~~Jason_liu O(∩_∩)O
  • 相关阅读:
    jquery ajax 跨域请求【原】
    纯js异步无刷新请求(只支持IE)【原】
    正则表达式高级用法【原】
    所有HTTP请求参数及报文查看SERVLET【原】
    AES加密【转】
    Object.prototype.toString.call() 区分对象类型
    js中的preventDefault与stopPropagation详解
    在项目中如何利用分页插件呢?
    Iframe 在项目中的使用总结
    在项目中那个少用if else 语句,精简代码,便于维护的方法(1)
  • 原文地址:https://www.cnblogs.com/JasonCow/p/14524922.html
Copyright © 2011-2022 走看看