今天通过观看老师分享的视频教程，学习了 MNIST 数据集的加载方法以及对数据集的基本操作：
代码如下：
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST train/test splits via the TensorFlow tutorial helper, using
# 'data/' as the data directory and one-hot encoded labels.
# NOTE(review): tensorflow.examples.tutorials.mnist only ships with
# TensorFlow 1.x; this script will not import under TensorFlow 2.x.
mnist = input_data.read_data_sets('data/', one_hot=True)
trainimg, trainlabel = mnist.train.images, mnist.train.labels
testimg, testlabel = mnist.test.images, mnist.test.labels

# Report the container type and shape of every split, labels before images.
for split_name, split_arr in (
    ("trainlabel", trainlabel),
    ("trainimg", trainimg),
    ("testlabel", testlabel),
    ("testimg", testimg),
):
    print(split_name + ":", type(split_arr), "shape:", split_arr.shape)
# Display a few randomly chosen training images together with their labels.
nsample = 5  # how many random training examples to show
# Indices are drawn with replacement, so the same image may be shown twice.
randidx = np.random.randint(trainimg.shape[0], size=nsample)
for i in randidx:
    # Each row must hold 28*28 = 784 pixel values for this reshape to succeed.
    cur_img = np.reshape(trainimg[i, :], (28, 28))
    # Labels were loaded with one_hot=True, so argmax recovers the digit.
    cur_label = np.argmax(trainlabel[i, :])
    caption = "" + str(i) + "th Training Data " + "Label is " + str(cur_label)
    plt.matshow(cur_img, cmap=plt.get_cmap('gray'))
    plt.title(caption)
    print(caption)
    # Show each figure before moving on to the next sample.
    plt.show()
# Draw one mini-batch of training images/labels and report its layout.
batch_size = 100
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
for batch_name, batch_arr in (("batch_xs", batch_xs), ("batch_ys", batch_ys)):
    print(batch_name + ":", type(batch_arr), "shape:", batch_arr.shape)
输出结果截图：