zoukankan      html  css  js  c++  java
  • Tensorflow 中(批量)读取数据的案列分析及TFRecord文件的打包与读取

    内容概要:

    单一数据读取方式:

      第一种:slice_input_producer()

    # 返回值可以直接通过 Session.run([images, labels])查看,且第一个参数必须放在列表中,如[...]
    [images, labels] = tf.train.slice_input_producer([images, labels], num_epochs=None, shuffle=True)

      第二种:string_input_producer()

    # 需要定义文件读取器,然后通过读取器中的 read()方法来获取数据(返回值类型 key,value),再通过 Session.run(value)查看
    file_queue = tf.train.string_input_producer(filename, num_epochs=None, shuffle=True)

    reader = tf.WholeFileReader() # 定义文件读取器
    key, value = reader.read(file_queue)    # key:文件名;value:文件中的内容

      !!!num_epochs=None,不指定迭代次数,这样文件队列中元素个数也不限定(None*数据集大小)。

      !!!如果不是None,则此函数创建本地计数器 epochs,需要使用local_variables_initializer()初始化局部变量

      !!!以上两种方法都可以生成文件名队列。

    (随机)批量数据读取方式:

    batchsize=2  # 每次读取的样本数量
    tf.train.batch(tensors, batch_size=batchsize)
    tf.train.shuffle_batch(tensors, batch_size=batchsize, capacity=batchsize*10, min_after_dequeue=batchsize*5) # capacity > min_after_dequeue

      !!!以上所有读取数据的方法,在Session.run()之前必须开启文件队列线程 tf.train.start_queue_runners()

     TFRecord文件的打包与读取

     一、单一数据读取方式

    第一种:slice_input_producer()

    def slice_input_producer(tensor_list, num_epochs=None, shuffle=True, seed=None, capacity=32, shared_name=None, name=None)

    案例1:

    import tensorflow as tf
    
    images = ['image1.jpg', 'image2.jpg', 'image3.jpg', 'image4.jpg']
    labels = [1, 2, 3, 4]
    
    # [images, labels] = tf.train.slice_input_producer([images, labels], num_epochs=None, shuffle=True)
    
    # 当num_epochs=2时,此时文件队列中只有 2*4=8个样本,所有在取第9个样本时会出错
    # [images, labels] = tf.train.slice_input_producer([images, labels], num_epochs=2, shuffle=True)
    
    data = tf.train.slice_input_producer([images, labels], num_epochs=None, shuffle=True)
    print(type(data))   # <class 'list'>
    
    with tf.Session() as sess:
        # sess.run(tf.local_variables_initializer())
        sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()  # 线程的协调器
        threads = tf.train.start_queue_runners(sess, coord)  # 开始在图表中收集队列运行器
    
        for i in range(10):
            print(sess.run(data))
    
        coord.request_stop()
        coord.join(threads)
    
    """
    运行结果:
    [b'image2.jpg', 2]
    [b'image1.jpg', 1]
    [b'image3.jpg', 3]
    [b'image4.jpg', 4]
    [b'image2.jpg', 2]
    [b'image1.jpg', 1]
    [b'image3.jpg', 3]
    [b'image4.jpg', 4]
    [b'image2.jpg', 2]
    [b'image3.jpg', 3]
    """

      !!!slice_input_producer() 中的第一个参数需要放在一个列表中,列表中的每个元素可以是 List 或 Tensor,如 [images,labels],

      !!!num_epochs设置

     第二种:string_input_producer()

    def string_input_producer(string_tensor, num_epochs=None, shuffle=True, seed=None, capacity=32, shared_name=None, name=None, cancel_op=None)

    文件读取器

      不同类型的文件对应不同的文件读取器,我们称为 reader对象

      该对象的 read 方法自动读取文件,并创建数据队列,输出key/文件名,value/文件内容;

    reader = tf.TextLineReader()      ### 一行一行读取,适用于所有文本文件

    reader = tf.TFRecordReader() ### A Reader that outputs the records from a TFRecords file
    reader = tf.WholeFileReader() ### 一次读取整个文件,适用图片

     案例2:读取csv文件

    iimport tensorflow as tf
    
    filename = ['data/A.csv', 'data/B.csv', 'data/C.csv']
    
    file_queue = tf.train.string_input_producer(filename, shuffle=True, num_epochs=2)   # 生成文件名队列
    reader = tf.WholeFileReader()           # 定义文件读取器(一次读取整个文件)
    # reader = tf.TextLineReader()            # 定义文件读取器(一行一行的读)
    key, value = reader.read(file_queue)    # key:文件名;value:文件中的内容
    print(type(file_queue))
    
    init = [tf.global_variables_initializer(), tf.local_variables_initializer()]
    with tf.Session() as sess:
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            while not coord.should_stop():
                for i in range(6):
                    print(sess.run([key, value]))
                break
        except tf.errors.OutOfRangeError:
            print('read done')
        finally:
            coord.request_stop()
        coord.join(threads)
    
    """
    reader = tf.WholeFileReader()           # 定义文件读取器(一次读取整个文件)
    运行结果:
    [b'data/C.csv', b'7.jpg,7
    8.jpg,8
    9.jpg,9
    ']
    [b'data/B.csv', b'4.jpg,4
    5.jpg,5
    6.jpg,6
    ']
    [b'data/A.csv', b'1.jpg,1
    2.jpg,2
    3.jpg,3
    ']
    [b'data/A.csv', b'1.jpg,1
    2.jpg,2
    3.jpg,3
    ']
    [b'data/B.csv', b'4.jpg,4
    5.jpg,5
    6.jpg,6
    ']
    [b'data/C.csv', b'7.jpg,7
    8.jpg,8
    9.jpg,9
    ']
    """
    """
    reader = tf.TextLineReader()           # 定义文件读取器(一行一行的读)
    运行结果:
    [b'data/B.csv:1', b'4.jpg,4']
    [b'data/B.csv:2', b'5.jpg,5']
    [b'data/B.csv:3', b'6.jpg,6']
    [b'data/C.csv:1', b'7.jpg,7']
    [b'data/C.csv:2', b'8.jpg,8']
    [b'data/C.csv:3', b'9.jpg,9']
    """

    案例3:读取图片(每次读取全部图片内容,不是一行一行)

    import tensorflow as tf
    
    filename = ['1.jpg', '2.jpg']
    filename_queue = tf.train.string_input_producer(filename, shuffle=False, num_epochs=1)
    reader = tf.WholeFileReader()              # 文件读取器
    key, value = reader.read(filename_queue)   # 读取文件 key:文件名;value:图片数据,bytes
    
    with tf.Session() as sess:
        tf.local_variables_initializer().run()
        coord = tf.train.Coordinator()      # 线程的协调器
        threads = tf.train.start_queue_runners(sess, coord)
    
        for i in range(filename.__len__()):
            image_data = sess.run(value)
            with open('img_%d.jpg' % i, 'wb') as f:
                f.write(image_data)
        coord.request_stop()
        coord.join(threads)

     二、(随机)批量数据读取方式:

      功能:shuffle_batch() 和 batch() 这两个API都是从文件队列中批量获取数据,使用方式类似;

    案例4:slice_input_producer() 与 batch()

    import tensorflow as tf
    import numpy as np
    
    images = np.arange(20).reshape([10, 2])
    label = np.asarray(range(0, 10))
    images = tf.cast(images, tf.float32)  # 可以注释掉,不影响运行结果
    label = tf.cast(label, tf.int32)     # 可以注释掉,不影响运行结果
    
    batchsize = 6   # 每次获取元素的数量
    input_queue = tf.train.slice_input_producer([images, label], num_epochs=None, shuffle=False)
    image_batch, label_batch = tf.train.batch(input_queue, batch_size=batchsize)
    
    # 随机获取 batchsize个元素,其中,capacity:队列容量,这个参数一定要比 min_after_dequeue 大
    # image_batch, label_batch = tf.train.shuffle_batch(input_queue, batch_size=batchsize, capacity=64, min_after_dequeue=10)
    
    with tf.Session() as sess:
        coord = tf.train.Coordinator()      # 线程的协调器
        threads = tf.train.start_queue_runners(sess, coord)     # 开始在图表中收集队列运行器
        for cnt in range(2):
            print("第{}次获取数据,每次batch={}...".format(cnt+1, batchsize))
            image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
            print(image_batch_v, label_batch_v, label_batch_v.__len__())
    
        coord.request_stop()
        coord.join(threads)
    
    """
    运行结果:
    第1次获取数据,每次batch=6...
    [[ 0.  1.]
     [ 2.  3.]
     [ 4.  5.]
     [ 6.  7.]
     [ 8.  9.]
     [10. 11.]] [0 1 2 3 4 5] 6
    第2次获取数据,每次batch=6...
    [[12. 13.]
     [14. 15.]
     [16. 17.]
     [18. 19.]
     [ 0.  1.]
     [ 2.  3.]] [6 7 8 9 0 1] 6
    """

     案例5:从本地批量的读取图片 --- string_input_producer() 与 batch()

     1 import tensorflow as tf
     2 import glob
     3 import cv2 as cv
     4 
     5 def read_imgs(filename, picture_format, input_image_shape, batch_size=1):
     6     """
     7     从本地批量的读取图片
     8     :param filename: 图片路径(包括图片的文件名),[]
     9     :param picture_format: 图片的格式,如 bmp,jpg,png等; string
    10     :param input_image_shape: 输入图像的大小; (h,w,c)或[]
    11     :param batch_size: 每次从文件队列中加载图片的数量; int
    12     :return: batch_size张图片数据, Tensor
    13     """
    14     global new_img
    15     # 创建文件队列
    16     file_queue = tf.train.string_input_producer(filename, num_epochs=1, shuffle=True)
    17     # 创建文件读取器
    18     reader = tf.WholeFileReader()
    19     # 读取文件队列中的文件
    20     _, img_bytes = reader.read(file_queue)
    21     # print(img_bytes)    # Tensor("ReaderReadV2_19:1", shape=(), dtype=string)
    22     # 对图片进行解码
    23     if picture_format == ".bmp":
    24         new_img = tf.image.decode_bmp(img_bytes, channels=1)
    25     elif picture_format == ".jpg":
    26         new_img = tf.image.decode_jpeg(img_bytes, channels=3)
    27     else:
    28         pass
    29     # 重新设置图片的大小
    30     # new_img = tf.image.resize_images(new_img, input_image_shape)
    31     new_img = tf.reshape(new_img, input_image_shape)
    32     # 设置图片的数据类型
    33     new_img = tf.image.convert_image_dtype(new_img, tf.uint8)
    34 
    35     # return new_img
    36     return tf.train.batch([new_img], batch_size)
    37 
    38 
    39 def main():
    40     image_path = glob.glob(r'F:demoFaceRecognition人脸库ORL*.bmp')
    41     image_batch = read_imgs(image_path, ".bmp", (112, 92, 1), 5)
    42     print(type(image_batch))
    43     # image_path = glob.glob(r'.*.jpg')
    44     # image_batch = read_imgs(image_path, ".jpg", (313, 500, 3), 1)
    45 
    46     sess = tf.Session()
    47     sess.run(tf.local_variables_initializer())
    48     tf.train.start_queue_runners(sess=sess)
    49 
    50     image_batch = sess.run(image_batch)
    51     print(type(image_batch))    # <class 'numpy.ndarray'>
    52 
    53     for i in range(image_batch.__len__()):
    54         cv.imshow("win_"+str(i), image_batch[i])
    55     cv.waitKey()
    56     cv.destroyAllWindows()
    57 
    58 def start():
    59     image_path = glob.glob(r'F:demoFaceRecognition人脸库ORL*.bmp')
    60     image_batch = read_imgs(image_path, ".bmp", (112, 92, 1), 5)
    61     print(type(image_batch))    # <class 'tensorflow.python.framework.ops.Tensor'>
    62 
    63 
    64     with tf.Session() as sess:
    65         sess.run(tf.local_variables_initializer())
    66         coord = tf.train.Coordinator()      # 线程的协调器
    67         threads = tf.train.start_queue_runners(sess, coord)     # 开始在图表中收集队列运行器
    68         image_batch = sess.run(image_batch)
    69         print(type(image_batch))    # <class 'numpy.ndarray'>
    70 
    71         for i in range(image_batch.__len__()):
    72             cv.imshow("win_"+str(i), image_batch[i])
    73         cv.waitKey()
    74         cv.destroyAllWindows()
    75 
    76         # 若使用 with 方式打开 Session,且没加如下2行语句,则会出错
    77         # ERROR:tensorflow:Exception in QueueRunner: Enqueue operation was cancelled;
    78         # 原因:文件队列线程还处于工作状态(队列中还有图片数据),而加载完batch_size张图片会话就会自动关闭,同时关闭文件队列线程
    79         coord.request_stop()
    80         coord.join(threads)
    81 
    82 
    83 if __name__ == "__main__":
    84     # main()
    85     start()
    从本地批量的读取图片案例

    案列6:TFRecord文件打包与读取

     1 def write_TFRecord(filename, data, labels, is_shuffler=True):
     2     """
     3     将数据打包成TFRecord格式
     4     :param filename: 打包后路径名,默认在工程目录下创建该文件;String
     5     :param data: 需要打包的文件路径名;list
     6     :param labels: 对应文件的标签;list
     7     :param is_shuffler:是否随机初始化打包后的数据,默认:True;Bool
     8     :return: None
     9     """
    10     im_data = list(data)
    11     im_labels = list(labels)
    12 
    13     index = [i for i in range(im_data.__len__())]
    14     if is_shuffler:
    15         np.random.shuffle(index)
    16 
    17     # 创建写入器,然后使用该对象写入样本example
    18     writer = tf.python_io.TFRecordWriter(filename)
    19     for i in range(im_data.__len__()):
    20         im_d = im_data[index[i]]    # im_d:存放着第index[i]张图片的路径信息
    21         im_l = im_labels[index[i]]  # im_l:存放着对应图片的标签信息
    22 
    23         # # 获取当前的图片数据  方式一:
    24         # data = cv2.imread(im_d)
    25         # # 创建样本
    26         # ex = tf.train.Example(
    27         #     features=tf.train.Features(
    28         #         feature={
    29         #             "image": tf.train.Feature(
    30         #                 bytes_list=tf.train.BytesList(
    31         #                     value=[data.tobytes()])), # 需要打包成bytes类型
    32         #             "label": tf.train.Feature(
    33         #                 int64_list=tf.train.Int64List(
    34         #                     value=[im_l])),
    35         #         }
    36         #     )
    37         # )
    38         # 获取当前的图片数据  方式二:相对于方式一,打包文件占用空间小了一半多
    39         data = tf.gfile.FastGFile(im_d, "rb").read()
    40         ex = tf.train.Example(
    41             features=tf.train.Features(
    42                 feature={
    43                     "image": tf.train.Feature(
    44                         bytes_list=tf.train.BytesList(
    45                             value=[data])), # 此时的data已经是bytes类型
    46                     "label": tf.train.Feature(
    47                         int64_list=tf.train.Int64List(
    48                             value=[im_l])),
    49                 }
    50             )
    51         )
    52 
    53         # 写入将序列化之后的样本
    54         writer.write(ex.SerializeToString())
    55     # 关闭写入器
    56     writer.close()
    TFRecord文件打包案列
     1 import tensorflow as tf
     2 import cv2
     3 
     4 def read_TFRecord(file_list, batch_size=10):
     5     """
     6     读取TFRecord文件
     7     :param file_list: 存放TFRecord的文件名,List
     8     :param batch_size: 每次读取图片的数量
     9     :return: 解析后图片及对应的标签
    10     """
    11     file_queue = tf.train.string_input_producer(file_list, num_epochs=None, shuffle=True)
    12     reader = tf.TFRecordReader()
    13     _, ex = reader.read(file_queue)
    14     batch = tf.train.shuffle_batch([ex], batch_size, capacity=batch_size * 10, min_after_dequeue=batch_size * 5)
    15 
    16     feature = {
    17         'image': tf.FixedLenFeature([], tf.string),
    18         'label': tf.FixedLenFeature([], tf.int64)
    19     }
    20     example = tf.parse_example(batch, features=feature)
    21 
    22     images = tf.decode_raw(example['image'], tf.uint8)
    23     images = tf.reshape(images, [-1, 32, 32, 3])
    24 
    25     return images, example['label']
    26 
    27 
    28 
    29 def main():
    30     # filelist = ['data/train.tfrecord']
    31     filelist = ['data/test.tfrecord']
    32     images, labels = read_TFRecord(filelist, 2)
    33     with tf.Session() as sess:
    34         sess.run(tf.local_variables_initializer())
    35         coord = tf.train.Coordinator()
    36         threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    37 
    38         try:
    39             while not coord.should_stop():
    40                 for i in range(1):
    41                     image_bth, _ = sess.run([images, labels])
    42                     print(_)
    43 
    44                     cv2.imshow("image_0", image_bth[0])
    45                     cv2.imshow("image_1", image_bth[1])
    46                 break
    47         except tf.errors.OutOfRangeError:
    48             print('read done')
    49         finally:
    50             coord.request_stop()
    51         coord.join(threads)
    52         cv2.waitKey(0)
    53         cv2.destroyAllWindows()
    54 
    55 if __name__ == "__main__":
    56     main()
    TFReord文件的读取案列
  • 相关阅读:
    第三天 moyax
    mkfs.ext3 option
    write file to stroage trigger kernel warning
    download fomat install rootfs script
    custom usb-seriel udev relus for compatible usb-seriel devices using kermit
    Wifi Troughput Test using iperf
    learning uboot switch to standby system using button
    learning uboot support web http function in qca4531 cpu
    learngin uboot design parameter recovery mechanism
    learning uboot auto switch to stanbdy system in qca4531 cpu
  • 原文地址:https://www.cnblogs.com/nbk-zyc/p/13159986.html
Copyright © 2011-2022 走看看