zoukankan      html  css  js  c++  java
  • TFRecord 的使用

    什么是 TFRecord 

                PS:这段内容摘自 http://wiki.jikexueyuan.com/project/tensorflow-zh/how_tos/reading_data.html

                一种保存记录的方法可以允许你讲任意的数据转换为TensorFlow所支持的格式, 这种方法可以使TensorFlow的数据集更容易与网络应用架构相匹配。这种建议的方法就是使用TFRecords文件,TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。你可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocolbuffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriterclass写入到TFRecords文件。tensorflow/g3doc/how_tos/reading_data/convert_to_records.py就是这样的一个例子。
                从TFRecords文件中读取数据, 可以使用tf.TFRecordReader的tf.parse_single_example解析器。这个parse_single_example操作可以将Example协议内存块(protocolbuffer)解析为张量。 MNIST的例子就使用了convert_to_records 所构建的数据。 请参看tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py, 

    代码

                adjust_pic.py

                    单纯的转换图片大小

    [python] view plain copy
     
    1. # -*- coding: utf-8 -*-  
    2.   
    3. import tensorflow as tf  
    4.   
    5. def resize(img_data, width, high, method=0):  
    6.     return tf.image.resize_images(img_data,[width, high], method)  

                    pic2tfrecords.py

                    将图片保存成TFRecord

    [python] view plain copy
     
    1. # -*- coding: utf-8 -*-  
    2. # 将图片保存成 TFRecord  
    3. import os.path  
    4. import matplotlib.image as mpimg  
    5. import tensorflow as tf  
    6. import adjust_pic as ap  
    7. from PIL import Image  
    8.   
    9.   
    10. SAVE_PATH = 'data/dataset.tfrecords'  
    11.   
    12.   
    13. def _int64_feature(value):  
    14.     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))  
    15.   
    16. def _bytes_feature(value):  
    17.     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))  
    18.   
    19. def load_data(datafile, width, high, method=0, save=False):  
    20.     train_list = open(datafile,'r')  
    21.     # 准备一个 writer 用来写 TFRecord 文件  
    22.     writer = tf.python_io.TFRecordWriter(SAVE_PATH)  
    23.   
    24.     with tf.Session() as sess:  
    25.         for line in train_list:  
    26.             # 获得图片的路径和类型  
    27.             tmp = line.strip().split(' ')  
    28.             img_path = tmp[0]  
    29.             label = int(tmp[1])  
    30.   
    31.             # 读取图片  
    32.             image = tf.gfile.FastGFile(img_path, 'r').read()  
    33.             # 解码图片(如果是 png 格式就使用 decode_png)  
    34.             image = tf.image.decode_jpeg(image)  
    35.             # 转换数据类型  
    36.             # 因为为了将图片数据能够保存到 TFRecord 结构体中,所以需要将其图片矩阵转换成 string,所以为了在使用时能够转换回来,这里确定下数据格式为 tf.float32  
    37.             image = tf.image.convert_image_dtype(image, dtype=tf.float32)  
    38.             # 既然都将图片保存成 TFRecord 了,那就先把图片转换成希望的大小吧  
    39.             image = ap.resize(image, width, high)  
    40.             # 执行 op: image  
    41.             image = sess.run(image)  
    42.               
    43.             # 将其图片矩阵转换成 string  
    44.             image_raw = image.tostring()  
    45.             # 将数据整理成 TFRecord 需要的数据结构  
    46.             example = tf.train.Example(features=tf.train.Features(feature={  
    47.                 'image_raw': _bytes_feature(image_raw),  
    48.                 'label': _int64_feature(label),  
    49.                 }))  
    50.   
    51.             # 写 TFRecord  
    52.             writer.write(example.SerializeToString())  
    53.   
    54.     writer.close()  
    55.   
    56.   
    57. load_data('train_list.txt_bak', 224, 224)  

                    tfrecords2data.py

                    从TFRecord中读取并保存成图片

    [python] view plain copy
     
    1. # -*- coding: utf-8 -*-  
    2. # 从 TFRecord 中读取并保存图片  
    3. import tensorflow as tf  
    4. import numpy as np  
    5.   
    6.   
    7. SAVE_PATH = 'data/dataset.tfrecords'  
    8.   
    9.   
    10. def load_data(width, high):  
    11.     reader = tf.TFRecordReader()  
    12.     filename_queue = tf.train.string_input_producer([SAVE_PATH])  
    13.   
    14.     # 从 TFRecord 读取内容并保存到 serialized_example 中  
    15.     _, serialized_example = reader.read(filename_queue)  
    16.     # 读取 serialized_example 的格式  
    17.     features = tf.parse_single_example(  
    18.         serialized_example,  
    19.         features={  
    20.             'image_raw': tf.FixedLenFeature([], tf.string),  
    21.             'label': tf.FixedLenFeature([], tf.int64),  
    22.         })  
    23.   
    24.     # 解析从 serialized_example 读取到的内容  
    25.     images = tf.decode_raw(features['image_raw'], tf.uint8)  
    26.     labels = tf.cast(features['label'], tf.int64)  
    27.   
    28.     with tf.Session() as sess:  
    29.         # 启动多线程  
    30.         coord = tf.train.Coordinator()  
    31.         threads = tf.train.start_queue_runners(sess=sess, coord=coord)  
    32.   
    33.         # 因为我这里只有 2 张图片,所以下面循环 2 次  
    34.         for i in range(2):  
    35.             # 获取一张图片和其对应的类型  
    36.             label, image = sess.run([labels, images])  
    37.             # 这里特别说明下:  
    38.             #   因为要想把图片保存成 TFRecord,那就必须先将图片矩阵转换成 string,即:  
    39.             #       pic2tfrecords.py 中 image_raw = image.tostring() 这行  
    40.             #   所以这里需要执行下面这行将 string 转换回来,否则会无法 reshape 成图片矩阵,请看下面的小例子:  
    41.             #       a = np.array([[1, 2], [3, 4]], dtype=np.int64) # 2*2 的矩阵  
    42.             #       b = a.tostring()  
    43.             #       # 下面这行的输出是 32,即: 2*2 之后还要再乘 8  
    44.             #       # 如果 tostring 之后的长度是 2*2=4 的话,那可以将 b 直接 reshape([2, 2]),但现在的长度是 2*2*8 = 32,所以无法直接 reshape  
    45.             #       # 同理如果你的图片是 500*500*3 的话,那 tostring() 之后的长度是 500*500*3 后再乘上一个数  
    46.             #       print len(b)  
    47.             #  
    48.             #   但在网上有很多提供的代码里都没有下面这一行,你们那真的能 reshape ?  
    49.             image = np.fromstring(image, dtype=np.float32)  
    50.             # reshape 成图片矩阵  
    51.             image = tf.reshape(image, [224, 224, 3])  
    52.             # 因为要保存图片,所以将其转换成 uint8  
    53.             image = tf.image.convert_image_dtype(image, dtype=tf.uint8)  
    54.             # 按照 jpeg 格式编码  
    55.             image = tf.image.encode_jpeg(image)  
    56.             # 保存图片  
    57.             with tf.gfile.GFile('pic_%d.jpg' % label, 'wb') as f:  
    58.                 f.write(sess.run(image))  
    59.   
    60.   
    61. load_data(224, 224)  


    train_list.txt_bak 中的内容如下:

    image_1093.jpg 13
    image_0805.jpg 10

  • 相关阅读:
    【python cookbook】找出序列中出现次数最多的元素
    2018/1/21 Netty通过解码处理器和编码处理器来发送接收POJO,Zookeeper深入学习
    读《风雨20年》小感
    两个知识点的回顾(const指针和动态链接库函数dlopen)
    小试牛刀
    chmod,chown和chgrp的区别
    node.js中使用node-schedule实现定时任务
    在 Node.js 上调用 WCF Web 服务
    nodejs发起HTTPS请求并获取数据
    openstack 之~云计算介绍
  • 原文地址:https://www.cnblogs.com/antflow/p/7299029.html
Copyright © 2011-2022 走看看