zoukankan      html  css  js  c++  java
  • CNN实战--mnist

    CNN实战--mnist

    dataprocessing

    我一般把数据处理单独写一个函数

    因为网上大多数都是直接在线下载做学习,导致与实际应用的情况不相符,所以我这是直接下载下来并读取,处理数据

    这个数据类型文档说的很清楚

    是图片二进制存储的(图片大小28*28),并且开头有一个magic num (需要跳过它)

    不知道跳几位的可以多尝试一下不同的offset输出长度看是不是整除

    具体数据处理可以看这个(虽然网上一搜就搜到了)

    def read_data():
    
        with open('./t10k-labels.idx1-ubyte','rb') as f:
            y_test=np.frombuffer(f.read(),np.uint8,offset=8)
            y_test=tf.convert_to_tensor(y_test,tf.int32)
            # offset代表从第几个byte后面开始读取,0则是从头开始读 1byte=8bit
            # y_test=tf.one_hot(y_test,10)
    
        with open('./train-labels.idx1-ubyte','rb') as f:
            y_train=np.frombuffer(f.read(),np.uint8,offset=8)
            y_train=tf.convert_to_tensor(y_train,tf.int32)
            # 1*10000
            # y_train=tf.one_hot(y_train,10)
    
        with open('./t10k-images.idx3-ubyte', 'rb') as f:
            x_test = np.frombuffer(f.read(), np.uint8,offset=16).reshape(len(y_test), 28, 28,1)
            x_test=tf.convert_to_tensor(x_test,tf.float32)/255
        # #502098=28*28*60000
    
        with open('./train-images.idx3-ubyte', 'rb') as f:
            x_train = np.frombuffer(f.read(), np.uint8,offset=16).reshape(len(y_train),28,28,1)
            x_train=tf.convert_to_tensor(x_train,dtype=tf.float32)/255
        #78400=28*28*10000
        return x_train,y_train,x_test,y_test
    

    train_model

    #-*- coding:utf-8 -*-
    # @Author : Dummerfu
    # @Time : 2020/4/20 21:42
    import tensorflow as tf
    import data_processing
    import numpy as np
    import os
    
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    
    if __name__ == '__main__':
    
        # x: [60k, 28, 28,1], [10k, 28, 28,1]
        # y: [60k], [10k]
        x_train, y_train, x_test, y_test = data_processing.read_data()
        # print(y_test.shape,x_train.shape)
       
    
        model=tf.keras.models.Sequential([
            # 这里输入层还是要写单个输入的shape
            tf.keras.layers.Conv2D(input_shape=(28,28,1),filters=32,
                                   kernel_size=(3,3),strides=(1,1),padding='SAME',activation='relu'),
            tf.keras.layers.MaxPool2D(pool_size=(2,2),strides=(2,2),padding='SAME'),
            
            tf.keras.layers.Conv2D(filters=64,kernel_size=(3,3),
                                   strides=(1,1),padding='SAME',activation='relu'),
            tf.keras.layers.MaxPool2D(pool_size=(2,2),strides=(2,2),padding='SAME'),
            tf.keras.layers.Dropout(0.7),
            
            tf.keras.layers.Flatten(),
            
            # FC1
            tf.keras.layers.Dense(128,activation='relu'),
            tf.keras.layers.Dropout(0.5),
            # FC2|output
            tf.keras.layers.Dense(10,activation='softmax'),
        ])
        # 查看层的信息
        # print(model.summary())
    	
        # 设置训练参数
        model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])
        
        # 训练(你甚至都不需要自己转onehot)
        # validation_split=x 将训练集*x变为测试集,进行预测
        # verbose=1 显示训练信息
        model.fit(x=x_train,y=y_train,batch_size=32,epochs=5,validation_split=0.3,verbose=1)
        train_loss,train_accu=model.evaluate(x=x_test,y=y_test)
        print(train_loss)
        print(train_accu)
    

    这个才训练到 98.5%好垃圾

    model save|restore

    参考这个

    有两种方式save

    只保存weight和bias,不保存网络结构

    这个知道就好了 其实是我懒得写,可以看那个链接里面写的

    保存网络结构

    import tensorflow as tf
    
    # 这个model是前面的那个model类
    model.save("path")
    # model del
    	# 这里的测试可以自己输入
        x_train,y_train,x_test,y_test=data_processing.read_data()
        restore_model= tf.keras.models.load_model('./my_model.ckpt')
        loss,acc=restore_model.evaluate(x_test,y_test)
        print(loss)
        print(acc)
    

    predict

    	# draw 当然自己随便写,预测数据还是得本地导入
        draw(x_test.numpy()[rad].reshape(28,28),y_test.numpy()[rad])
        restore_model= tf.keras.models.load_model('./my_model.ckpt')
        
        pro=np.argmax(restore_model.predict(x_test.numpy()[rad].reshape(1,28,28,1)))
        print('???',pro)
    
  • 相关阅读:
    問題集リンク(DEV I)
    認定Platformデベロッパー 試験範囲
    React 学习资源
    IIS
    小学校
    リストに項番をつける
    七、JavaScript函数
    六、JavaScript数组
    五、JavaScript流程控制
    四、JavaScript操作符
  • 原文地址:https://www.cnblogs.com/cherrypill/p/13289538.html
Copyright © 2011-2022 走看看