zoukankan      html  css  js  c++  java
  • MNIST手写字母识别(一)

    这是个分类应用入门:使用softmax分类,简单来说就是概论转化为0-1区间的一个数字

    读取数据集

    1 # 导入相关库
    2 import tensorflow as tf
    3 from tensorflow.examples.tutorials.mnist import input_data
    4 mnist=input_data.read_data_sets("D:/MNIST",one_hot=True)

     独热编码(one hot encoding)

    一种稀疏向量,其中:一个元素设为1,所有其他元素均设为0

    独热编码常用于表示拥有有限个可能值的字符串或标识符

    工作流程:

      1 import tensorflow as tf
      2 import matplotlib.pyplot as plt
      3 import numpy as np
      4 import tensorflow.examples.tutorials.mnist.input_data as input_data  
      5 mnist=input_data.read_data_sets("MNIST_data",one_hot=True)  #读取数据
      6 import os  #可加可不加,屏蔽通知消息
      7 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
      8 
      9 print('训练集 train 数量:',mnist.train.num_examples,
     10       ',验证集 validation 数量:',mnist.validation.num_examples,
     11       ',测试集 test 数量:',mnist.test.num_examples)
     12 
     13 # print('train images shape:',mnist.train.images.shape,
     14 #       'labels shaple:',mnist.train.labels.shape)
     15 
     16 # print((mnist.train.images[0].reshape(28,28)))
     17 # print(len(mnist.train.images[0].shape))
     18 
     19 # def plot_image(image):
     20 #     plt.imshow(image.reshape(28,28))
     21 #     plt.show()
     22 #
     23 # plot_image(mnist.train.images[1])
     24 # plt.imshow(mnist.train.images[20000].reshape(14,56))
     25 # plt.show()
     26 
     27 # print(mnist.train.labels[1])
     28 # print(np.argmax(mnist.train.labels[1]))
     29 # mnist_no_one_hot=input_data.read_data_sets("MNIST_data",one_hot=False)
     30 # print(mnist_no_one_hot.train.labels[0:10])
     31 #
     32 # print('validation images:',mnist.validation.images.shape,'labels:',mnist.validation.labels.shape)
     33 #
     34 # print('test images:',mnist.test.images.shape,'labels:',mnist.test.labels.shape)
     35 #
     36 # batch_images_xs,batch_labels_ys=mnist.train.next_batch(batch_size=10)
     37 # print(mnist.train.labels[0:10])
     38 # # print(batch_labels_ys)
     39 
     40 # mnist中每张图片共有28*28=784个像素点
     41 x=tf.placeholder(tf.float32,[None,784])
     42 # 0-9一共10个数字->10个类别
     43 y=tf.placeholder(tf.float32,[None,10])
     44 
     45 # 定义模型变量(以正态分布的随机数初始化权重W,以常数0初始化偏置b)
     46 W=tf.Variable(tf.random_normal([784,10],mean=0.0,stddev=1.0))
     47 b=tf.Variable(tf.zeros([10]))
     48 
     49 # 前向计算
     50 forward=tf.matmul(x,W)+b
     51 #softmax分类
     52 pred=tf.nn.softmax(forward)
     53 # 定义交叉熵损失函数
     54 loss_function=tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))
     55 
     56 # 设置训练参数
     57 train_epochs=150 # 训练轮数
     58 batch_size=50 # 单次训练样本数(批次大小)
     59 total_batch=int(mnist.train.num_examples/batch_size) # 一轮训练的批次数
     60 display_step=1 # 显示粒度
     61 learning_rate=0.04 # 学习率
     62 
     63 # 分类模型构建与训练实践
     64 #选择优化器,梯度下降
     65 optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)
     66 
     67 # 定义准确率,检查预测类别tf.argmax(pred,1)与实际类别tf.argmax(y,1)的匹配情况
     68 correct_prediction=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
     69 # 准确率,将布尔值转化为浮点数,并计算平均值
     70 accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
     71 
     72 # 声明会话
     73 sess=tf.Session()
     74 init=tf.global_variables_initializer()
     75 sess.run(init)
     76 
     77 # 训练模型
     78 for epoch in range(train_epochs):
     79       for batch in range(total_batch):
     80             xs, ys = mnist.train.next_batch(batch_size)  # 读取批次数据
     81             sess.run(optimizer, feed_dict={x: xs, y: ys})  # 执行批次训练
     82 
     83       # total_batch批次训练完成之后,使用验证数据计算误差与准确率,验证集没有分批。
     84       loss, acc = sess.run([loss_function, accuracy],feed_dict={x: mnist.validation.images, y: mnist.validation.labels})
     85 
     86       # 打印训练过程中的详细信息
     87       if (epoch + 1)% display_step==0:
     88             print("train_epoch:", '%02d' % (epoch + 1), "loss=", "{:.9f}".format(loss),"accuracy=", '{:.4f}'.format(acc))
     89 print("train finished!")
     90 
     91 # 在测试集上评估模型准确率
     92 accu_test=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
     93 print("test accuracy:",accu_test)
     94 
     95 # 在验证集上评估模型准确率
     96 accu_validation=sess.run(accuracy,feed_dict={x:mnist.validation.images,y:mnist.validation.labels})
     97 print("validatin accuracy:",accu_validation)
     98 
     99 # 在训练集上评估模型准确率
    100 accu_train=sess.run(accuracy,feed_dict={x:mnist.train.images,y:mnist.train.labels})
    101 print("tarin accuracy:",accu_train)
    102 
    103 # 由于pred预测结果是one-hot编码格式,所以需要转化为0~9数字
    104 prediction_result=sess.run(tf.argmax(pred,1),feed_dict={x:mnist.test.images})
    105 
    106 # 查看结果中的前十项
    107 prediction_result[0:10]
    108 # 定义可视化函数
    109 
    110 def plt_images_labels_prediction(images,  # 图像列表
    111                                   labels,  # 标签列表
    112                                   prediction,  # 预测值列表
    113                                   index,  # 从第index个开始显示
    114                                   num=10):  # 缺省依次显示10副
    115       fig = plt.gcf()  # 获取当前图表,get current figure
    116       fig.set_size_inches(10, 12)  # 1英寸等于2.54cm
    117       if num > 25:
    118             num = 25  # 最多显示25个子图
    119       for i in range(0, num):
    120             ax = plt.subplot(5, 5, i + 1)  # 获取当前要处理的子图
    121 
    122             ax.imshow(np.reshape(images[index], (28, 28)),
    123                       cmap='binary')  # 显示第index个图像
    124             title = "labels=" + str(np.argmax(labels[index]))  # 构建该图上要显示的title信息
    125             if len(prediction) > 0:
    126                   title += ",predict=" + str(prediction[index])
    127 
    128             ax.set_title(title)  # 显示图上的title
    129             ax.set_xticks([])  # 不显示坐标轴
    130             ax.set_yticks([])
    131             index += 1
    132       plt.show()
    133 # 可视化预测结果
    134 plt_images_labels_prediction(mnist.test.images,
    135                              mnist.test.labels,
    136                              prediction_result,10,25)

    大概就结束了,相当于机器学习的一个helloworld

  • 相关阅读:
    Ajax省市区无刷新单表联动查询
    Hadoop2.0、YARN技术大数据视频教程
    零基础DNET B/S开发软件工程师培训视频教程
    零基础DNET CS开发视频教程
    HTML5开发框架PhoneGap实战视频教程
    Web前端开发视频教程
    FluentData 轻量级.NET ORM持久化技术详解
    前端 MVVM 框架KnockOut.JS深入浅出视频教程
    ASP.NET Web开发项目实战视频教程
    零基础到CS开发高手通用权限管理系统全程实录
  • 原文地址:https://www.cnblogs.com/hly97/p/12853115.html
Copyright © 2011-2022 走看看