import sys, os
sys.path.append('F:mlDLsource-code') #导入此路径中
from dataset.mnist import load_mnist
from PIL import Image
import numpy as np
(x_train, t_train), (x_test, t_test) = load_mnist(flatten = True, normalize = False, one_hot_label = False)
#flatten参数为True的含义是展开输入图像(变成784个元素构成的一维数组)如果设置成False则为1*28*28的三维数组。 normalize为True时表示将输入图像正规化为0.0~0.1的值,为False时会保持原来的像素0~255.one_hot_label设置是否将标签保存为onehot表示(one-hot representation)。 one-hot表示是仅正确解标签为1,其余 皆为0的数组,就像[0,0,1,0,0,0,0,0,0,0]这样。当one_hot_label为False时, 只是像7、2这样简单保存正确解标签;当one_hot_label为True时,标签则 保存为one-hot表示。
def img_show(img):
pil_img = Image.fromarray(np.uint8(img)) #将numpy数组的形状保存的图像转化为PIL用的数据对象。
pil_img.show()
img = x_train[0]
label = t_train[0]
print(label)
print(img.shape)
img = img.reshape(28, 28)
print(img.shape)
img_show(img)
5
(784,)
(28, 28)
![](https://img2018.cnblogs.com/blog/1313448/201909/1313448-20190916155702137-1429663275.png)