MNIST数据集简介（二）
# Load the MNIST dataset, print basic statistics, and display a few
# randomly chosen training images together with their labels.
import numpy as np
import tensorflow as tf  # NOTE(review): not used in this section; kept in case other parts of the file need it
import matplotlib.pyplot as plt
import input_data  # presumably the legacy TensorFlow MNIST tutorial helper module — confirm

print("packs loaded")
print("Download and Extract MNIST dataset")
# one_hot=True: labels are encoded as 0/1 (one-hot) vectors instead of digit indices.
mnist = input_data.read_data_sets('data/', one_hot=True)
# Blank separator line; a bare `print` (as originally written) is a no-op in Python 3.
print("")
print("type of 'mnist' is %s" % (type(mnist)))
print("number of train data is %d" % (mnist.train.num_examples))
print("number of test data is %d" % (mnist.test.num_examples))

trainimg = mnist.train.images
trainlabel = mnist.train.labels
testimg = mnist.test.images
testlabel = mnist.test.labels

# Take a first look at what the dataset looks like.
nsample = 5
# Sample without replacement so the same training image is never shown twice.
randidx = np.random.choice(trainimg.shape[0], size=nsample, replace=False)
for i in randidx:
    # Each row is reshaped to 28x28, i.e. it is expected to hold 784 pixel values.
    curr_img = np.reshape(trainimg[i, :], (28, 28))
    # Labels are one-hot vectors, so argmax recovers the digit class.
    curr_label = np.argmax(trainlabel[i, :])
    plt.matshow(curr_img, cmap=plt.get_cmap('gray'))
    # Show the sample index and its label, which were previously computed but unused.
    plt.title("Training sample %d, label is %d" % (i, curr_label))
    plt.show()