zoukankan      html  css  js  c++  java
  • TensorFlow 下 mnist 数据集的操作及可视化

    from tensorflow.examples.tutorials.mnist import input_data

    首先需要连网下载数据集:

    mnsit = input_data.read_data_sets(train_dir='./MNIST_DATA', one_hot=True)
        # 如果当前文件夹下没有 MNIST_DATA,会首先创建该文件夹,然后下载 mnist 数据集

    训练集与测试集的划分:

    X_train, y_train = mnist.train.images, mnist.train.labels
            # 返回的 X_train 是 numpy 下的 多维数组,(55000, 784)
    X_test, y_test = mnist.test.images, mnist.test.labels
            # (10000, 784)
    X_valid, y_valid = mnist.valid.images, mnist.valid.labels
            # (5000, 784)

    当然可以通过迭代的形式以一定 batch_size 读取数据:

    mnist.train.next_batch(100)
    • mnist.train.next_batch() ⇒ 返回两个值,一个是图像数据,一个是图像数据对应的类别信息。

      >> X_batch, y_batch = mnist.train.next_batch(100)
      >> X_batch.shape
      (100, 784)
      >> y_batch.shape
      (100, 10)                 # one hot 编码

    1. 可视化

    # images:9*(28*28) 的 numpy.ndarray
    # y_ 表示其真实的标签信息
    def plot_mnist_3_3(images, y_, y=None):
        assert images.shape[0] == len(y_)
        fig, axes = plt.subplots(3, 3)
        for i, ax in enumerate(axes.flat):
            ax.imshow(images[i].reshape(image_shp), cmap='binary')
            if y is None:
                xlabel = 'True: {}'.format(y_[i])
            else:
                xlabel = 'True: {0}, Pred: {1}'.format(y_[i], y[i])
            ax.set_xlabel(xlabel)
            ax.set_xticks([])
            ax.set_yticks([])
        plt.show()
  • 相关阅读:
    Android 工程师进阶 34 讲
    300分钟搞定数据结构与算法
    即学即用的Spark实战44讲
    42讲轻松通关 Flink
    Webpack原理与实践
    大数据运维实战
    ZooKeeper源码分析与实战
    前端高手进阶
    重学数据结构与算法
    ElementUI中el-upload怎样上传文件并且传递额外参数给Springboot后台进行接收
  • 原文地址:https://www.cnblogs.com/mtcnn/p/9421990.html
Copyright © 2011-2022 走看看