这是个分类应用入门:使用softmax分类,简单来说就是概论转化为0-1区间的一个数字
读取数据集
1 # 导入相关库 2 import tensorflow as tf 3 from tensorflow.examples.tutorials.mnist import input_data 4 mnist=input_data.read_data_sets("D:/MNIST",one_hot=True)
独热编码(one hot encoding)
一种稀疏向量,其中:一个元素设为1,所有其他元素均设为0
独热编码常用于表示拥有有限个可能值的字符串或标识符
工作流程:
1 import tensorflow as tf 2 import matplotlib.pyplot as plt 3 import numpy as np 4 import tensorflow.examples.tutorials.mnist.input_data as input_data 5 mnist=input_data.read_data_sets("MNIST_data",one_hot=True) #读取数据 6 import os #可加可不加,屏蔽通知消息 7 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 8 9 print('训练集 train 数量:',mnist.train.num_examples, 10 ',验证集 validation 数量:',mnist.validation.num_examples, 11 ',测试集 test 数量:',mnist.test.num_examples) 12 13 # print('train images shape:',mnist.train.images.shape, 14 # 'labels shaple:',mnist.train.labels.shape) 15 16 # print((mnist.train.images[0].reshape(28,28))) 17 # print(len(mnist.train.images[0].shape)) 18 19 # def plot_image(image): 20 # plt.imshow(image.reshape(28,28)) 21 # plt.show() 22 # 23 # plot_image(mnist.train.images[1]) 24 # plt.imshow(mnist.train.images[20000].reshape(14,56)) 25 # plt.show() 26 27 # print(mnist.train.labels[1]) 28 # print(np.argmax(mnist.train.labels[1])) 29 # mnist_no_one_hot=input_data.read_data_sets("MNIST_data",one_hot=False) 30 # print(mnist_no_one_hot.train.labels[0:10]) 31 # 32 # print('validation images:',mnist.validation.images.shape,'labels:',mnist.validation.labels.shape) 33 # 34 # print('test images:',mnist.test.images.shape,'labels:',mnist.test.labels.shape) 35 # 36 # batch_images_xs,batch_labels_ys=mnist.train.next_batch(batch_size=10) 37 # print(mnist.train.labels[0:10]) 38 # # print(batch_labels_ys) 39 40 # mnist中每张图片共有28*28=784个像素点 41 x=tf.placeholder(tf.float32,[None,784]) 42 # 0-9一共10个数字->10个类别 43 y=tf.placeholder(tf.float32,[None,10]) 44 45 # 定义模型变量(以正态分布的随机数初始化权重W,以常数0初始化偏置b) 46 W=tf.Variable(tf.random_normal([784,10],mean=0.0,stddev=1.0)) 47 b=tf.Variable(tf.zeros([10])) 48 49 # 前向计算 50 forward=tf.matmul(x,W)+b 51 #softmax分类 52 pred=tf.nn.softmax(forward) 53 # 定义交叉熵损失函数 54 loss_function=tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1)) 55 56 # 设置训练参数 57 train_epochs=150 # 训练轮数 58 batch_size=50 # 单次训练样本数(批次大小) 59 total_batch=int(mnist.train.num_examples/batch_size) # 一轮训练的批次数 60 display_step=1 # 显示粒度 61 learning_rate=0.04 # 学习率 62 63 # 分类模型构建与训练实践 64 #选择优化器,梯度下降 65 optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function) 66 67 # 定义准确率,检查预测类别tf.argmax(pred,1)与实际类别tf.argmax(y,1)的匹配情况 68 correct_prediction=tf.equal(tf.argmax(pred,1),tf.argmax(y,1)) 69 # 准确率,将布尔值转化为浮点数,并计算平均值 70 accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) 71 72 # 声明会话 73 sess=tf.Session() 74 init=tf.global_variables_initializer() 75 sess.run(init) 76 77 # 训练模型 78 for epoch in range(train_epochs): 79 for batch in range(total_batch): 80 xs, ys = mnist.train.next_batch(batch_size) # 读取批次数据 81 sess.run(optimizer, feed_dict={x: xs, y: ys}) # 执行批次训练 82 83 # total_batch批次训练完成之后,使用验证数据计算误差与准确率,验证集没有分批。 84 loss, acc = sess.run([loss_function, accuracy],feed_dict={x: mnist.validation.images, y: mnist.validation.labels}) 85 86 # 打印训练过程中的详细信息 87 if (epoch + 1)% display_step==0: 88 print("train_epoch:", '%02d' % (epoch + 1), "loss=", "{:.9f}".format(loss),"accuracy=", '{:.4f}'.format(acc)) 89 print("train finished!") 90 91 # 在测试集上评估模型准确率 92 accu_test=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}) 93 print("test accuracy:",accu_test) 94 95 # 在验证集上评估模型准确率 96 accu_validation=sess.run(accuracy,feed_dict={x:mnist.validation.images,y:mnist.validation.labels}) 97 print("validatin accuracy:",accu_validation) 98 99 # 在训练集上评估模型准确率 100 accu_train=sess.run(accuracy,feed_dict={x:mnist.train.images,y:mnist.train.labels}) 101 print("tarin accuracy:",accu_train) 102 103 # 由于pred预测结果是one-hot编码格式,所以需要转化为0~9数字 104 prediction_result=sess.run(tf.argmax(pred,1),feed_dict={x:mnist.test.images}) 105 106 # 查看结果中的前十项 107 prediction_result[0:10] 108 # 定义可视化函数 109 110 def plt_images_labels_prediction(images, # 图像列表 111 labels, # 标签列表 112 prediction, # 预测值列表 113 index, # 从第index个开始显示 114 num=10): # 缺省依次显示10副 115 fig = plt.gcf() # 获取当前图表,get current figure 116 fig.set_size_inches(10, 12) # 1英寸等于2.54cm 117 if num > 25: 118 num = 25 # 最多显示25个子图 119 for i in range(0, num): 120 ax = plt.subplot(5, 5, i + 1) # 获取当前要处理的子图 121 122 ax.imshow(np.reshape(images[index], (28, 28)), 123 cmap='binary') # 显示第index个图像 124 title = "labels=" + str(np.argmax(labels[index])) # 构建该图上要显示的title信息 125 if len(prediction) > 0: 126 title += ",predict=" + str(prediction[index]) 127 128 ax.set_title(title) # 显示图上的title 129 ax.set_xticks([]) # 不显示坐标轴 130 ax.set_yticks([]) 131 index += 1 132 plt.show() 133 # 可视化预测结果 134 plt_images_labels_prediction(mnist.test.images, 135 mnist.test.labels, 136 prediction_result,10,25)
大概就结束了,相当于机器学习的一个helloworld