MNIST 数据集简介 3
# Quick tour of the MNIST dataset: load it with one-hot labels, report the
# train/test sizes, display a few random training images together with their
# labels, and draw one mini-batch to inspect the array shapes.
import numpy as np
import tensorflow as tf  # not used below; kept for the tutorial's later parts
import matplotlib.pyplot as plt
import input_data

print("packs loaded")
print("Download and Extract MNIST dataset")
# one_hot=True encodes each label as a 0/1 indicator vector rather than an
# integer class id, which is why np.argmax is used below to recover the digit.
mnist = input_data.read_data_sets('data/', one_hot=True)
# A bare `print` is a no-op expression under Python 3; print("") emits the
# intended blank line under both Python 2 and Python 3.
print("")
print("type of 'mnist' is %s" % (type(mnist)))
print("number of train data is %d" % (mnist.train.num_examples))
print("number of test data is %d" % (mnist.test.num_examples))

trainimg = mnist.train.images
trainlabel = mnist.train.labels
testimg = mnist.test.images
testlabel = mnist.test.labels

# Take a first look at what the data looks like.
nsample = 5
# Sample without replacement so the same image is never shown twice.
randidx = np.random.choice(trainimg.shape[0], size=nsample, replace=False)
for i in randidx:
    # Each row of trainimg is a flattened image; restore the 28x28 grid.
    curr_img = np.reshape(trainimg[i, :], (28, 28))
    curr_label = np.argmax(trainlabel[i, :])
    plt.matshow(curr_img, cmap=plt.get_cmap('gray'))
    # Show the decoded label so the image can be checked against it.
    plt.title("Label: %d" % curr_label)
    plt.show()

# Mini-batch learning: pull one batch to inspect the shapes it yields.
batch_size = 100
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
print("shape of 'batch_xs' is %s" % (batch_xs.shape,))
print("shape of 'batch_ys' is %s" % (batch_ys.shape,))