zoukankan      html  css  js  c++  java
  • MNIST手写字母识别(一)

    这是个分类应用入门:使用softmax分类,简单来说就是概论转化为0-1区间的一个数字

    读取数据集

    1 # 导入相关库
    2 import tensorflow as tf
    3 from tensorflow.examples.tutorials.mnist import input_data
    4 mnist=input_data.read_data_sets("D:/MNIST",one_hot=True)

     独热编码(one hot encoding)

    一种稀疏向量,其中:一个元素设为1,所有其他元素均设为0

    独热编码常用于表示拥有有限个可能值的字符串或标识符

    工作流程:

      1 import tensorflow as tf
      2 import matplotlib.pyplot as plt
      3 import numpy as np
      4 import tensorflow.examples.tutorials.mnist.input_data as input_data  
      5 mnist=input_data.read_data_sets("MNIST_data",one_hot=True)  #读取数据
      6 import os  #可加可不加,屏蔽通知消息
      7 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
      8 
      9 print('训练集 train 数量:',mnist.train.num_examples,
     10       ',验证集 validation 数量:',mnist.validation.num_examples,
     11       ',测试集 test 数量:',mnist.test.num_examples)
     12 
     13 # print('train images shape:',mnist.train.images.shape,
     14 #       'labels shaple:',mnist.train.labels.shape)
     15 
     16 # print((mnist.train.images[0].reshape(28,28)))
     17 # print(len(mnist.train.images[0].shape))
     18 
     19 # def plot_image(image):
     20 #     plt.imshow(image.reshape(28,28))
     21 #     plt.show()
     22 #
     23 # plot_image(mnist.train.images[1])
     24 # plt.imshow(mnist.train.images[20000].reshape(14,56))
     25 # plt.show()
     26 
     27 # print(mnist.train.labels[1])
     28 # print(np.argmax(mnist.train.labels[1]))
     29 # mnist_no_one_hot=input_data.read_data_sets("MNIST_data",one_hot=False)
     30 # print(mnist_no_one_hot.train.labels[0:10])
     31 #
     32 # print('validation images:',mnist.validation.images.shape,'labels:',mnist.validation.labels.shape)
     33 #
     34 # print('test images:',mnist.test.images.shape,'labels:',mnist.test.labels.shape)
     35 #
     36 # batch_images_xs,batch_labels_ys=mnist.train.next_batch(batch_size=10)
     37 # print(mnist.train.labels[0:10])
     38 # # print(batch_labels_ys)
     39 
     40 # mnist中每张图片共有28*28=784个像素点
     41 x=tf.placeholder(tf.float32,[None,784])
     42 # 0-9一共10个数字->10个类别
     43 y=tf.placeholder(tf.float32,[None,10])
     44 
     45 # 定义模型变量(以正态分布的随机数初始化权重W,以常数0初始化偏置b)
     46 W=tf.Variable(tf.random_normal([784,10],mean=0.0,stddev=1.0))
     47 b=tf.Variable(tf.zeros([10]))
     48 
     49 # 前向计算
     50 forward=tf.matmul(x,W)+b
     51 #softmax分类
     52 pred=tf.nn.softmax(forward)
     53 # 定义交叉熵损失函数
     54 loss_function=tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))
     55 
     56 # 设置训练参数
     57 train_epochs=150 # 训练轮数
     58 batch_size=50 # 单次训练样本数(批次大小)
     59 total_batch=int(mnist.train.num_examples/batch_size) # 一轮训练的批次数
     60 display_step=1 # 显示粒度
     61 learning_rate=0.04 # 学习率
     62 
     63 # 分类模型构建与训练实践
     64 #选择优化器,梯度下降
     65 optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)
     66 
     67 # 定义准确率,检查预测类别tf.argmax(pred,1)与实际类别tf.argmax(y,1)的匹配情况
     68 correct_prediction=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
     69 # 准确率,将布尔值转化为浮点数,并计算平均值
     70 accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
     71 
     72 # 声明会话
     73 sess=tf.Session()
     74 init=tf.global_variables_initializer()
     75 sess.run(init)
     76 
     77 # 训练模型
     78 for epoch in range(train_epochs):
     79       for batch in range(total_batch):
     80             xs, ys = mnist.train.next_batch(batch_size)  # 读取批次数据
     81             sess.run(optimizer, feed_dict={x: xs, y: ys})  # 执行批次训练
     82 
     83       # total_batch批次训练完成之后,使用验证数据计算误差与准确率,验证集没有分批。
     84       loss, acc = sess.run([loss_function, accuracy],feed_dict={x: mnist.validation.images, y: mnist.validation.labels})
     85 
     86       # 打印训练过程中的详细信息
     87       if (epoch + 1)% display_step==0:
     88             print("train_epoch:", '%02d' % (epoch + 1), "loss=", "{:.9f}".format(loss),"accuracy=", '{:.4f}'.format(acc))
     89 print("train finished!")
     90 
     91 # 在测试集上评估模型准确率
     92 accu_test=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
     93 print("test accuracy:",accu_test)
     94 
     95 # 在验证集上评估模型准确率
     96 accu_validation=sess.run(accuracy,feed_dict={x:mnist.validation.images,y:mnist.validation.labels})
     97 print("validatin accuracy:",accu_validation)
     98 
     99 # 在训练集上评估模型准确率
    100 accu_train=sess.run(accuracy,feed_dict={x:mnist.train.images,y:mnist.train.labels})
    101 print("tarin accuracy:",accu_train)
    102 
    103 # 由于pred预测结果是one-hot编码格式,所以需要转化为0~9数字
    104 prediction_result=sess.run(tf.argmax(pred,1),feed_dict={x:mnist.test.images})
    105 
    106 # 查看结果中的前十项
    107 prediction_result[0:10]
    108 # 定义可视化函数
    109 
    110 def plt_images_labels_prediction(images,  # 图像列表
    111                                   labels,  # 标签列表
    112                                   prediction,  # 预测值列表
    113                                   index,  # 从第index个开始显示
    114                                   num=10):  # 缺省依次显示10副
    115       fig = plt.gcf()  # 获取当前图表,get current figure
    116       fig.set_size_inches(10, 12)  # 1英寸等于2.54cm
    117       if num > 25:
    118             num = 25  # 最多显示25个子图
    119       for i in range(0, num):
    120             ax = plt.subplot(5, 5, i + 1)  # 获取当前要处理的子图
    121 
    122             ax.imshow(np.reshape(images[index], (28, 28)),
    123                       cmap='binary')  # 显示第index个图像
    124             title = "labels=" + str(np.argmax(labels[index]))  # 构建该图上要显示的title信息
    125             if len(prediction) > 0:
    126                   title += ",predict=" + str(prediction[index])
    127 
    128             ax.set_title(title)  # 显示图上的title
    129             ax.set_xticks([])  # 不显示坐标轴
    130             ax.set_yticks([])
    131             index += 1
    132       plt.show()
    133 # 可视化预测结果
    134 plt_images_labels_prediction(mnist.test.images,
    135                              mnist.test.labels,
    136                              prediction_result,10,25)

    大概就结束了,相当于机器学习的一个helloworld

  • 相关阅读:
    hdu 2485 Destroying the bus stations 迭代加深搜索
    hdu 2487 Ugly Windows 模拟
    hdu 2492 Ping pong 线段树
    hdu 1059 Dividing 多重背包
    hdu 3315 My Brute 费用流,费用最小且代价最小
    第四天 下载网络图片显示
    第三天 单元测试和数据库操作
    第二天 布局文件
    第一天 安卓简介
    Android 获取存储空间
  • 原文地址:https://www.cnblogs.com/hly97/p/12853115.html
Copyright © 2011-2022 走看看