zoukankan      html  css  js  c++  java
  • TensorFlow中读取图像数据的三种方式

      本文面对三种常常遇到的情况,总结三种读取数据的方式,分别用于处理单张图片、大量图片,和TFRecorder读取方式。并且还补充了功能相近的tf函数。

    1、处理单张图片

      我们训练完模型之后,常常要用图片测试,有的时候,我们并不需要对很多图像做测试,可能就是几张甚至一张。这种情况下没有必要用队列机制。

    import tensorflow as tf
    import matplotlib.pyplot as plt
    
    def read_image(file_name):
        img = tf.read_file(filename=file_name)     # 默认读取格式为uint8
        print("img 的类型是",type(img));
        img = tf.image.decode_jpeg(img,channels=0) # channels 为1得到的是灰度图,为0则按照图片格式来读
        return img
    
    def main( ):
        with tf.device("/cpu:0"):
             # img_path是文件所在地址包括文件名称,地址用相对地址或者绝对地址都行 
                img_path='./1.jpg'
                img=read_image(img_path)
                with tf.Session() as sess:
                image_numpy=sess.run(img)
                print(image_numpy)
                print(image_numpy.dtype)
                print(image_numpy.shape)
                plt.imshow(image_numpy)
                plt.show()
    
    if __name__=="__main__":
        main()
    
    """
    输出结果为:
    
    img 的类型是 <class 'tensorflow.python.framework.ops.Tensor'>
    [[[196 219 209]
      [196 219 209]
      [196 219 209]
      ...
    
     [[ 71 106  42]
      [ 59  89  39]
      [ 34  63  19]
      ...
      [ 21  52  46]
      [ 15  45  43]
      [ 22  50  53]]]
    uint8
    (675, 1200, 3)
    """

       和tf.read_file用法相似的函数还有tf.gfile.FastGFile  tf.gfile.GFile,只是要指定读取方式是'r' 还是'rb' 。

    2、需要读取大量图像用于训练

      这种情况就需要使用Tensorflow队列机制。首先是获得每张图片的路径,把他们都放进一个list里面,然后用string_input_producer创建队列,再用tf.WholeFileReader读取。具体请看下例:

    def get_image_batch(data_file,batch_size):
        data_names=[os.path.join(data_file,k) for k in os.listdir(data_file)]
     
        #这个num_epochs函数在整个Graph是local Variable,所以在sess.run全局变量的时候也要加上局部变量。  
        filenames_queue=tf.train.string_input_producer(data_names,num_epochs=50,shuffle=True,capacity=512)
        reader=tf.WholeFileReader()
        _,img_bytes=reader.read(filenames_queue)
        image=tf.image.decode_png(img_bytes,channels=1)    #读取的是什么格式,就decode什么格式
        #解码成单通道的,并且获得的结果的shape是[?, ?,1],也就是Graph不知道图像的大小,需要set_shape
        image.set_shape([180,180,1])   #set到原本已知图像的大小。或者直接通过tf.image.resize_images,tf.reshape()
        image=tf.image.convert_image_dtype(image,tf.float32)
        #预处理  下面的一句代码可以换成自己想使用的预处理方式
        #image=tf.divide(image,255.0)   
        return tf.train.batch([image],batch_size) 

      这里的date_file是指文件夹所在的路径,不包括文件名。第一句是遍历指定目录下的文件名称,存放到一个list中。当然这个做法有很多种方法,比如glob.glob,或者tf.train.match_filename_once

    全部代码如下:

    import tensorflow as tf
    import os
    def read_image(data_file,batch_size):
        data_names=[os.path.join(data_file,k) for k in os.listdir(data_file)]
        filenames_queue=tf.train.string_input_producer(data_names,num_epochs=5,shuffle=True,capacity=30)
        reader=tf.WholeFileReader()
        _,img_bytes=reader.read(filenames_queue)
        image=tf.image.decode_jpeg(img_bytes,channels=1)
        image=tf.image.resize_images(image,(180,180))
    
        image=tf.image.convert_image_dtype(image,tf.float32)
        return tf.train.batch([image],batch_size)
    
    def main( ):
        img_path=r'F:dataSetWIDERWIDER_trainimages6--Funeral'  #本地的一个数据集目录,有足够的图像
        img=read_image(img_path,batch_size=10)
        image=img[0]  #取出每个batch的第一个数据
        print(image)
        init=[tf.global_variables_initializer(),tf.local_variables_initializer()]
        with tf.Session() as sess:
            sess.run(init)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess,coord=coord)
            try:
                while not coord.should_stop():
                    print(image.shape)
            except tf.errors.OutOfRangeError:
                print('read done')
            finally:
                coord.request_stop()
            coord.join(threads)
    
    
    if __name__=="__main__":
        main()
    
    """
    输出如下:
    (180, 180, 1)
    (180, 180, 1)
    (180, 180, 1)
    (180, 180, 1)
    (180, 180, 1)
    """

      这段代码可以说写的很是规整了。注意到init里面有对local变量的初始化,并且因为用到了队列,当然要告诉电脑什么时候队列开始, tf.train.Coordinator 和 tf.train.start_queue_runners 就是两个管理队列的类,用法如程序所示。

      与 tf.train.string_input_producer相似的函数是 tf.train.slice_input_producer。 tf.train.slice_input_producer和tf.train.string_input_producer的第一个参数形式不一样。等有时间再做一个二者比较的博客

     3、对TFRecorder解码获得图像数据

      其实这块和上一种方式差不多的,更重要的是怎么生成TFRecorder文件,这一部分我会补充到另一篇博客上。

      仍然使用 tf.train.string_input_producer。

    import tensorflow as tf
    import matplotlib.pyplot as plt
    import os
    import cv2
    import  numpy as np
    import glob
    
    def read_image(data_file,batch_size):
        files_path=glob.glob(data_file)
        queue=tf.train.string_input_producer(files_path,num_epochs=None)
        reader = tf.TFRecordReader()
        print(queue)
        _, serialized_example = reader.read(queue)
        features = tf.parse_single_example(
            serialized_example,
            features={
                'image_raw': tf.FixedLenFeature([], tf.string),
                'label_raw': tf.FixedLenFeature([], tf.string),
            })
        image = tf.decode_raw(features['image_raw'], tf.uint8)
        image = tf.cast(image, tf.float32)
        image.set_shape((12*12*3))
        label = tf.decode_raw(features['label_raw'], tf.float32)
        label.set_shape((2))
        # 预处理部分省略,大家可以自己根据需要添加
        return tf.train.batch([image,label],batch_size=batch_size,num_threads=4,capacity=5*batch_size)
    
    def main( ):
        img_path=r'F:pythonMTCNN_by_myselfprepare_datapnet*.tfrecords'  #本地的几个tf文件
        img,label=read_image(img_path,batch_size=10)
        image=img[0]
        init=[tf.global_variables_initializer(),tf.local_variables_initializer()]
        with tf.Session() as sess:
            sess.run(init)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess,coord=coord)
            try:
                while not coord.should_stop():
                    print(image.shape)
            except tf.errors.OutOfRangeError:
                print('read done')
            finally:
                coord.request_stop()
            coord.join(threads)
    
    
    if __name__=="__main__":
        main()

      在read_image函数中,先使用glob函数获得了存放tfrecord文件的列表,然后根据TFRecord文件是如何存的就如何parse,再set_shape;这里有必要提醒下parse的方式。我们看到这里用的是tf.decode_raw ,因为做TFRecord是将图像数据string化了,数据是串行的,丢失了空间结果。从features中取出image和label的数据,这时就要用 tf.decode_raw  解码,得到的结果当然也是串行的了,所以set_shape 成一个串行的,再reshape。这种方式是取决于你的编码TFRecord方式的。

    再举一种例子:

    reader=tf.TFRecordReader()
    _,serialized_example=reader.read(file_name_queue)
    features = tf.parse_single_example(serialized_example, features={
        'data': tf.FixedLenFeature([256,256], tf.float32), ###
        'label': tf.FixedLenFeature([], tf.int64),
        'id': tf.FixedLenFeature([], tf.int64)
    })
    img = features['data']
    label =features['label']
    id = features['id']

      这个时候就不需要任何解码了。因为做TFRecord的方式就是直接把图像数据append进去了。

    参考链接:

      https://blog.csdn.net/qq_34914551/article/details/86286184

  • 相关阅读:
    缓存三大问题及解决方案
    布隆过滤器
    maven 详解
    Netty是什么?
    select、poll、epoll简介
    IO
    Java并发编程:Synchronized及其实现原理
    Java并发之AQS详解
    原子更新字段类
    AtomicStampedReference解决ABA问题
  • 原文地址:https://www.cnblogs.com/nbk-zyc/p/13168313.html
Copyright © 2011-2022 走看看