zoukankan      html  css  js  c++  java
  • tensorflow入门学习及MNIST手写数字识别学习

    tensorflow入门学习及MNIST手写数字识别学习

    1. tensorflow安装

    如果使用pycharm设置来安装会报下载超时的错误
    使用命令安装也需要下载很久
    
    //使用这个命令安装就很快
    pip install tensorflow==2.0.0rc1 -i https://pypi.tuna.tsinghua.edu.cn/simple
    
    

    2. 安装tensorflow结果查看

    import tensorflow as tf
    # 查看安装版本
    print(tf.__version__)
    # 打印所有数据集
    print(dir(tf.keras.datasets))
    
    #打印结果
    2.0.0-rc1
    ['__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '_sys', 'boston_housing', 'cifar10', 'cifar100', 'fashion_mnist', 'imdb', 'mnist', 'reuters']
    
    

    3. MNIST数据集可视化

    # 加载数据集
    mnist = tf.keras.datasets.mnist
    # 包括两个数据集,一个是训练数据(60000个),一个是测试数据(10000个)
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    # 打印数据格式
    print(x_train.shape, y_train.shape)
    print(x_test.shape, y_test.shape)
    
    #打印结果
    (60000, 28, 28) (60000,)
    (10000, 28, 28) (10000,)
    
    # 显示图片
    # 导入依赖库
    import matplotlib.pyplot as plt
    # 随机选取一个数并查看label
    image_index = 15000  # 范围是[0,59999)
    # 图片显示(彩色)
    # plt.imshow(x_train[image_index])
    # 图片灰度显示(黑白)
    plt.imshow(x_train[image_index], cmap='Greys')
    plt.show()
    # 打印图片所代表的数字
    print(y_train[image_index])
    

    4. MNIST数据集格式转换

    import numpy as np
    # 将图片从28*28扩充为32*32
    x_train = np.pad(x_train, ((0, 0), (2, 2), (2, 2)), 'constant', constant_values=0)
    print(x_train.shape)
    # 数据类型转换
    x_train = x_train.astype('float32')
    # 数据正则化
    x_train /= 255
    # 数据维度转换([n,h,w,c])
    x_train = x_train.reshape(x_train.shape[0], 32, 32, 1)
    print(x_train.shape)
    
    # 打印结果
    (60000, 32, 32)
    (60000, 32, 32, 1)
    

    5. 构建LeNet模型

    # 构建LeNet模型
    model = tf.keras.models.Sequential([
        # 第一层:卷积层编码
        # 参数:卷积核个数,卷积核大小,填充方式,激活函数,输入数据格式
        tf.keras.layers.Conv2D(filters=6, kernel_size=(5, 5), padding='valid', activation=tf.nn.relu,
                               input_shape=(32, 32, 1)),
        # 第二层:池化层
        tf.keras.layers.AveragePooling2D(pool_size=(2, 2), strides=(2, 2), padding='same'),
        # 第三层:卷积层
        tf.keras.layers.Conv2D(filters=16, kernel_size=(5, 5), padding='valid', activation=tf.nn.relu),
        # 第四层:池化层
        tf.keras.layers.AveragePooling2D(pool_size=(2, 2), strides=(2, 2), padding='same'),
        # 扁平处理(多维数据转换为一维数据)
        tf.keras.layers.Flatten(),
    
        # 第五、六、七层:全连接层
        tf.keras.layers.Dense(units=120, activation=tf.nn.relu),
        tf.keras.layers.Dense(units=84, activation=tf.nn.relu),
        tf.keras.layers.Dense(units=10, activation=tf.nn.softmax)
    ])
    
    print(model.summary())
    
    # 模型训练
    
    # 超参数设置
    num_epochs = 10
    batch_size = 64
    learning_rate = 0.001
    adam_optimizer = tf.keras.optimizers.Adam(learning_rate)
    
    model.compile(optimizer=adam_optimizer,
                  loss=tf.keras.losses.sparse_categorical_crossentropy,
                  metrics=['accuracy'])
    
    # 用于计算学习时间
    import datetime
    
    start_time = datetime.datetime.now()
    
    model.fit(x=x_train,
              y=y_train,
              batch_size=batch_size,
              epochs=num_epochs)
    end_time = datetime.datetime.now()
    time_cost = end_time - start_time
    print("time_cost = ", time_cost)  # CPU time cost: 5min, GPU time cost: less than 1min
    
    # 模型保存(保存在当前目录下)
    model.save('lenet_model.h5')
    

    6. 使用模型来识别手写的数字

    # 使用模型识别图片数字
    import tensorflow as tf
    
    # 加载已经训练好的LeNet模型
    model = tf.keras.models.load_model('lenet_model.h5')
    model.summary()
    
    
    import cv2
    import matplotlib.pyplot as plt
    
    # 第一步:读取图片
    img = cv2.imread('numberImages/9.jpg')  # 8.png
    print(img.shape)
    plt.imshow(img)
    plt.show()
    
    # 第二步:将图片转为灰度图
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    print(img.shape)
    plt.imshow(img, cmap='Greys')
    plt.show()
    
    # 第三步:将图片的底色和字的颜色取反
    img = cv2.bitwise_not(img)
    plt.imshow(img, cmap='Greys')
    plt.show()
    
    # 第四步:将底变成纯白色,将字变成纯黑色
    img[img <= 100] = 0
    img[img > 140] = 255  # 130
    plt.imshow(img)
    plt.show()
    
    # 显示图片
    plt.imshow(img, cmap='Greys')
    plt.show()
    
    # 第五步:将图片尺寸缩放为输入规定尺寸
    img = cv2.resize(img, (32, 32))
    plt.show()
    
    # 第六步:将数据类型转为float32
    img = img.astype('float32')
    
    # 第七步:数据正则化
    img /= 255
    
    # 第八步:增加维度为输入的规定格式
    img = img.reshape(1, 32, 32, 1)
    print(img.shape)
    
    # 第九步:预测
    pred = model.predict(img)
    
    # 第十步:输出结果
    print('=========================')
    print('预测结果为:', pred.argmax())
    
    

    9.jpg

    # 运行结果
    Model: "sequential"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    conv2d (Conv2D)              (None, 28, 28, 6)         156       
    _________________________________________________________________
    average_pooling2d (AveragePo (None, 14, 14, 6)         0         
    _________________________________________________________________
    conv2d_1 (Conv2D)            (None, 10, 10, 16)        2416      
    _________________________________________________________________
    average_pooling2d_1 (Average (None, 5, 5, 16)          0         
    _________________________________________________________________
    flatten (Flatten)            (None, 400)               0         
    _________________________________________________________________
    dense (Dense)                (None, 120)               48120     
    _________________________________________________________________
    dense_1 (Dense)              (None, 84)                10164     
    _________________________________________________________________
    dense_2 (Dense)              (None, 10)                850       
    =================================================================
    Total params: 61,706
    Trainable params: 61,706
    Non-trainable params: 0
    _________________________________________________________________
    (651, 650, 3)
    (651, 650)
    (1, 32, 32, 1)
    =========================
    预测结果为: 9
    

    LetNet模型简介








    资料:

    链接:https://pan.baidu.com/s/1kJz0IvNIMpYU3-ZLwM20cQ 
    提取码:ty40
    版权声明:本文为博主原创文章,转载请附上博文链接!
  • 相关阅读:
    创建zull工程时pom文件报错failed to read artifact descriptor for org.springframework.cloud:spring-cloud
    利用eureka构建一个简单的springCloud分布式集群
    《信息安全专业导论》第十一周学习总结
    Nmap
    Excel数据统计与分析
    python模拟进程状态
    《信息安全专业导论》第9周学习总结
    俄罗斯方块
    《信息安全专业导论》第八周学习总结
    熟悉编程语言
  • 原文地址:https://www.cnblogs.com/zq98/p/13654036.html
Copyright © 2011-2022 走看看