zoukankan      html  css  js  c++  java
  • 导出MNIST的数据集

    在TensorFlow的官方入门课程中,多次用到mnist数据集。

    mnist数据集是一个数字手写体图片库,但它的存储格式并非常见的图片格式,所有的图片都集中保存在四个扩展名为idx3-ubyte的二进制文件。

    如果我们想要知道大名鼎鼎的mnist手写体数字都长什么样子,就需要从mnist数据集中导出手写体数字图片。了解这些手写体的总体形状,也有助于加深我们对TensorFlow入门课程的理解。

    下面先给出通过TensorFlow api接口导出mnist手写体数字图片的python代码,再对代码进行分析。代码在win7下测试通过,linux环境也可以参考本处代码。

    (非常良心的注释和打印有木有)

    [python] view plain copy
     
    1. #!/usr/bin/python3.5  
    2. # -*- coding: utf-8 -*-  
    3.   
    4. import os  
    5. import tensorflow as tf  
    6. from tensorflow.examples.tutorials.mnist import input_data  
    7.   
    8. from PIL import Image  
    9.   
    10. # 声明图片宽高  
    11. rows = 28  
    12. cols = 28  
    13.   
    14. # 要提取的图片数量  
    15. images_to_extract = 8000  
    16.   
    17. # 当前路径下的保存目录  
    18. save_dir = "./mnist_digits_images"  
    19.   
    20. # 读入mnist数据  
    21. mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)  
    22.   
    23. # 创建会话  
    24. sess = tf.Session()  
    25.   
    26. # 获取图片总数  
    27. shape = sess.run(tf.shape(mnist.train.images))  
    28. images_count = shape[0]  
    29. pixels_per_image = shape[1]  
    30.   
    31. # 获取标签总数  
    32. shape = sess.run(tf.shape(mnist.train.labels))  
    33. labels_count = shape[0]  
    34.   
    35. # mnist.train.labels是一个二维张量,为便于后续生成数字图片目录名,有必要一维化(后来发现只要把数据集的one_hot属性设为False,mnist.train.labels本身就是一维)  
    36. #labels = sess.run(tf.argmax(mnist.train.labels, 1))  
    37. labels = mnist.train.labels  
    38.   
    39. # 检查数据集是否符合预期格式  
    40. if (images_count == labels_count) and (shape.size == 1):  
    41.     print ("数据集总共包含 %s 张图片,和 %s 个标签" % (images_count, labels_count))  
    42.     print ("每张图片包含 %s 个像素" % (pixels_per_image))  
    43.     print ("数据类型:%s" % (mnist.train.images.dtype))  
    44.   
    45.     # mnist图像数据的数值范围是[0,1],需要扩展到[0,255],以便于人眼观看  
    46.     if mnist.train.images.dtype == "float32":  
    47.         print ("准备将数据类型从[0,1]转为binary[0,255]...")  
    48.         for i in range(0,images_to_extract):  
    49.             for n in range(pixels_per_image):  
    50.                 if mnist.train.images[i][n] != 0:  
    51.                     mnist.train.images[i][n] = 255  
    52.             # 由于数据集图片数量庞大,转换可能要花不少时间,有必要打印转换进度  
    53.             if ((i+1)%50) == 0:  
    54.                 print ("图像浮点数值扩展进度:已转换 %s 张,共需转换 %s 张" % (i+1, images_to_extract))  
    55.   
    56.     # 创建数字图片的保存目录  
    57.     for i in range(10):  
    58.         dir = "%s/%s/" % (save_dir,i)  
    59.         if not os.path.exists(dir):  
    60.             print ("目录 ""%s"" 不存在!自动创建该目录..." % dir)  
    61.             os.makedirs(dir)  
    62.   
    63.     # 通过python图片处理库,生成图片  
    64.     indices = [for x in range(0, 10)]  
    65.     for i in range(0,images_to_extract):  
    66.         img = Image.new("L",(cols,rows))  
    67.         for m in range(rows):  
    68.             for n in range(cols):  
    69.                 img.putpixel((n,m), int(mnist.train.images[i][n+m*cols]))  
    70.         # 根据图片所代表的数字label生成对应的保存路径  
    71.         digit = labels[i]  
    72.         path = "%s/%s/%s.bmp" % (save_dir, labels[i], indices[digit])  
    73.         indices[digit] += 1  
    74.         img.save(path)  
    75.         # 由于数据集图片数量庞大,保存过程可能要花不少时间,有必要打印保存进度  
    76.         if ((i+1)%50) == 0:  
    77.             print ("图片保存进度:已保存 %s 张,共需保存 %s 张" % (i+1, images_to_extract))  
    78.       
    79. else:  
    80.     print ("图片数量和标签数量不一致!")  

    上述代码的实现思路如下:

    1.读入mnist手写体数据;

    2.把数据的值从[0,1]浮点范围转化为黑白格式(背景为0-黑色,前景为255-白色);

    3.根据mnist.train.labels的内容,生成数字索引,也就是建立每一张图片和其所代表数字的关联,由此创建对应的保存目录;

    4.循环遍历mnist.train.images,把每张图片的像素数据赋值给python图片处理库PIL的Image类实例,再调用Image类的save方法把图片保存在第3步骤中创建的对应目录。

    在运行上述代码之前,你需要确保本地已经安装python的图片处理库PIL,pip安装命令如下:

    pip3 install Pillow

    或 pip install Pillow,取决于你的pip版本。

    上述python代码运行后,在当前目录下会生成mnist_digits_images目录,在该目录下,可以看到如下内容:

    可以看到,我们成功地生成了黑底白字的数字图片。

    如果仔细观察这些图片,会看到一些肉眼也难以分辨的数字,譬如:

    上面这几个数字是2。想不到吧?

    下面这两个是5(看起来更像6):

    这个是7:(7长这样?有句MMP不知当讲不当讲)

    猜猜下面这个是什么:

    这是大写的L?不是。

    有点像1,是1吗?也不是。

    倒立拉粑的7?sorry,又猜错了。

    实话告诉您,它是2!一开始我也是不相信的,知道真相的那一刻我下巴差点掉下来!

    这些手写图片,一般人用肉眼观察,识别率能达到98%就不错了,但是通过TensorFlow搭建的卷积神经网络识别率可以达到99%,非常地神奇!

  • 相关阅读:
    charles修改响应体
    charles重发网络请求&模拟慢速网络&过滤网络请求
    charles修改请求体内容
    monkeyrunner环境搭建以及实例(转)
    django模型中的抽象类(abstract)
    Linux启动/停止/重启Mysql数据库的方法
    ava.net.SocketException: Unrecognized Windows Sockets error: 0: JVM_Bind (解决思路)
    unix PS命令和JPS命令的区别
    mysql:表注释和字段注释
    mysql-关于Unix时间戳(unix_timestamp)
  • 原文地址:https://www.cnblogs.com/liuys635/p/11187921.html
Copyright © 2011-2022 走看看