zoukankan      html  css  js  c++  java
  • 机器学习: TensorFlow 的数据读取与TFRecords 格式

    最近学习tensorflow,发现其读取数据的方式看起来有些不同,所以又重新系统地看了一下文档,总得来说,tensorflow 有三种主流的数据读取方式:
    1) 传送 (feeding): Python 可以在程序的运行过程中,将数据传送进定义好的 tensor 变量中
    2) 从文件读取 (reading from files): 一个输入流从文件中直接读取数据
    3) 预加载数据 (preloaded data): 这个很好理解,就是将所有的数据一次性全部读进内存里。

    对于第三种方式,在数据量小的时候,是非常高效的,但是如果数据量很大的时候,这种方法显然非常耗内存,所以在数据量很大的时候,一般选择第二种读取方式,即从文件读取。在利用第二种方式读取的时候,我们常常会用到一种 TFRecords 的格式来保存读取的文件。TFRecords 是一种二进制文件。可以在TensorFlow 中方便的进行各种存取操作以及预处理。

    我们先来看看,如何将一张图片转换成字符流

    import os
    import tensorflow as tf
    import matplotlib.pyplot as plt
    import numpy as np
    import skimage.io as io
    
    dir_path = 'Face'
    file_list = os.listdir(dir_path)
    
    print file_list
    
    for f in file_list:
        print (dir_path + os.sep + f)
    
    img_1 = io.imread(dir_path + os.sep + file_list[0])
    
    #plt.imshow(img_1, cmap='gray')
    #plt.show()
    
    # 将图像转换成字符
    img_str = img_1.tostring()
    
    # 将字符流还原成图像
    img_rec_vec = np.fromstring(img_str, dtype=np.uint8)
    
    img_rec = img_rec_vec.reshape(img_1.shape)
    
    #plt.imshow(img_rec, cmap='gray')
    #plt.show()
    

    接下来,我们看看,如何生成 TFRecords 文件:

    
    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
    
    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
    
    tfrecords_filename = 'Face.tfrecords'
    
    writer = tf.python_io.TFRecordWriter(tfrecords_filename)
    
    for img_path in file_list:
        img = np.array(io.imread(dir_path + os.sep + img_path))
    
        # 从文件夹里读取图像
        # 获取图像的宽和高,图像的维度需要存入 TFRecords 文件中
        # 以方便后续的处理
        # 
        height = img.shape[0]
        width = img.shape[1]
    
        # 将图像转换成字符流
        img_raw = img.tostring()
    
        # 将字符流以及图像的尺度信息存入TFRecords 文件
        example = tf.train.Example(features=tf.train.Features(feature={
            'height': _int64_feature(height),
            'width': _int64_feature(width),
            'image_raw': _bytes_feature(img_raw)}))
    
        writer.write(example.SerializeToString())
    
    writer.close()

    最后,我们看看如何从 TFrecords 文件中读数据,并且做批处理:

    
    # 可以重新定义图像的宽和高,
    IMAGE_HEIGHT = 224
    IMAGE_WIDTH = 224
    
    # 定义读取与解码函数
    def read_and_decode(filename_queue):
    
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)
    
    # 获取 features,包含图像,以及图像宽和高
        features = tf.parse_single_example(
            serialized_example,
            features={
                'height': tf.FixedLenFeature([], tf.int64),
                'width': tf.FixedLenFeature([], tf.int64),
                'image_raw': tf.FixedLenFeature([], tf.string),
            })
    
    # 获取图像信息
        image = tf.decode_raw(features['image_raw'], tf.uint8)
    
        height = tf.cast(features['height'], tf.int32)
        width = tf.cast(features['width'], tf.int32)
    
    # 将图像转换成多维数组的形式
        image_shape = [height, width, 1]
        image = tf.reshape(image, image_shape)
    
    # 重新定义图像的尺度 
        image_size_const = tf.constant((IMAGE_HEIGHT, IMAGE_WIDTH, 1), dtype=tf.int32)
    
        # Random transformations can be put here: right before you crop images
        # to predefined size. To get more information look at the stackoverflow
        # question linked above.
    
    # 对图像进行预处理,包括裁剪,增边等
    
        resized_image = tf.image.resize_image_with_crop_or_pad(image=image,
                                                               target_height=IMAGE_HEIGHT,
                                                               target_width=IMAGE_WIDTH)
    
        return resized_image
    
    
    # 
    filename_queue = tf.train.string_input_producer(
        [tfrecords_filename], num_epochs=10)
    
    # Even when reading in multiple threads, share the filename
    # queue.
    train_images = read_and_decode(filename_queue)
    
    # 要注意 min_after_dequeue 不能超过 capacity
    image = tf.train.shuffle_batch([train_images],
                                    batch_size=1,
                                    capacity=5,
                                    num_threads=2,
                                    min_after_dequeue=1)
    
    # The op for initializing the variables.
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    
    with tf.Session() as sess:
        sess.run(init_op)
    
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
    
        # Let's read off 3 batches just for example
        for i in xrange(1):
            img = sess.run([image])
            img_batch = img[0]
            img_1 = tf.reshape(img_batch[0, :, :, :], [IMAGE_HEIGHT, IMAGE_WIDTH])
            print (img_1.shape)
            plt.imshow(sess.run(img_1), cmap='gray')
    #    coord.request_stop()
    #    coord.join(threads)
    
    plt.show()
    print 'all is well'

    参考来源:

    http://codecloud.net/16485.html

    http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/

    https://www.tensorflow.org/programmers_guide/reading_data

  • 相关阅读:
    分布式跟踪工具pinpoint
    python调用阿里云产品接口实现自动发现异常访问ip并禁用2小时
    centos病毒
    Google Earth API开发者指南
    在vs中使用ZedGraph控件的一些记录
    A flexible charting library for .NET
    ZedGraph.dll
    WPF 动态模拟CPU 使用率曲线图
    C#调用GoogleEarth COM API开发
    使用WeifenLuo.WinFormsUI.Docking界面布局中的保存配置
  • 原文地址:https://www.cnblogs.com/mtcnn/p/9412422.html
Copyright © 2011-2022 走看看