一、读取MNIST数据
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST splits through the tutorial helper, using "MNIST_data" as
# the data directory and requesting one-hot labels.
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

# Report the image/label array shapes of each split, in the order
# train, test, validation.
for split in (mnist.train, mnist.test, mnist.validation):
    print(split.images.shape, split.labels.shape)
import matplotlib.pyplot as plt
# Preview the first few training images side by side in a single figure.
n_samples = 5  # number of training images to display

# One row of n_samples panels; the figure width grows with the panel count.
plt.figure(figsize=(n_samples * 2, 3))
for index in range(n_samples):
    # Subplot positions are 1-based, hence index + 1.
    plt.subplot(1, n_samples, index + 1)
    # Reshape the stored image into a 28x28 grid so imshow can render it.
    sample_image = mnist.train.images[index].reshape(28, 28)
    plt.imshow(sample_image, cmap="binary")
    plt.axis("off")  # hide ticks and frame around each panel

# Show the figure once, after every panel has been drawn.
plt.show()