zoukankan      html  css  js  c++  java
  • 『TensorFlow』TFR数据预处理探究以及框架搭建

    一、TFRecord文件书写效率对比(单线程和多线程对比)

    1、准备工作

    # Author : Hellcat
    # Time   : 18-1-15
    
    '''
    import os
    os.environ["CUDA_VISIBLE_DEVICES"]="-1" 
    '''
    
    import os
    import glob
    import numpy as np 
    import tensorflow as tf
    import matplotlib.pyplot as plt
    
    # Print arrays in full instead of eliding the middle with "...".
    np.set_printoptions(threshold=np.inf)
    # Create the session shared by the rest of the script; allow_growth is set
    # so GPU memory is claimed on demand rather than all at once.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    
    def _int64_feature(value):
        """Wrap one integer as a ``tf.train.Feature`` backed by an ``Int64List``."""
        int64_list = tf.train.Int64List(value=[value])
        return tf.train.Feature(int64_list=int64_list)
    
    
    def _bytes_feature(value):
        """Wrap one bytes value as a ``tf.train.Feature`` backed by a ``BytesList``."""
        bytes_list = tf.train.BytesList(value=[value])
        return tf.train.Feature(bytes_list=bytes_list)

    2、单线程TFR文件写入

    def image2TFR_single_thread(path='./Data_Set/cartoon_faces',with_label=False):
        """Write every image found directly under ``path`` into TFRecord shards.

        Each file is read with ``plt.imread`` and the raw bytes of the resulting
        array are stored under the ``'image'`` key of a ``tf.train.Example``.
        Shards are written one after another to
        ``./TFRecord_Output/<dir name>.tfrecords_<i>_of_<n>_st``.

        :param path: directory whose immediate files are all treated as images
        :param with_label: placeholder switch for label support; the labelling
            rule is still a TODO, so no label is written either way
        """
        # next(os.walk(path)) yields (dirpath, dirnames, filenames) for `path`
        # itself, so [2] is the list of files directly inside it (no recursion).
        image_names = next(os.walk(path))[2]
        num_file = len(image_names)

        # Number of examples stored per shard file.
        instances_per_shard = 10000
        # Ceiling division: `n // k + 1` would add an empty trailing shard
        # whenever n is an exact multiple of k. Keep at least one shard so an
        # empty directory still produces a (empty) output file as before.
        num_shards = max(1, -(-num_file // instances_per_shard))

        # normpath first so a trailing '/' on `path` does not yield an empty name.
        dataset_name = os.path.basename(os.path.normpath(path))
        # TFRecordWriter does not create missing directories.
        os.makedirs('./TFRecord_Output', exist_ok=True)

        for file_i in range(num_shards):
            # Shard naming rule: <dataset>.tfrecords_<index>_of_<total>_st
            file_name = './TFRecord_Output/{0}.tfrecords_{1}_of_{2}_st'.format(
                dataset_name, file_i+1, num_shards)
            shard_start = file_i * instances_per_shard
            shard_names = image_names[shard_start:shard_start + instances_per_shard]

            writer = tf.python_io.TFRecordWriter(file_name)
            try:
                for image_name in shard_names:
                    image_data = plt.imread(os.path.join(path, image_name))
                    if with_label:
                        pass
                        # TODO: derive the label here when labels are available
                        # (a plain integer class id, not one-hot).
                        # label = ...
                    # tobytes() is the supported spelling of the deprecated tostring().
                    image_raw = image_data.tobytes()
                    example = tf.train.Example(features=tf.train.Features(feature={
                        'image': _bytes_feature(image_raw),
                        # 'label': _int64_feature(label)
                    }))
                    writer.write(example.SerializeToString())
            finally:
                # Close the shard even if an image fails to load.
                writer.close()

    3、多线程TFR文件写入

    def image2TFR_multiple_threads(path='./Data_Set/cartoon_faces',with_label=False):
        """Write every image found directly under ``path`` into TFRecord shards
        using two worker threads.

        The images are split into shards up front; each worker repeatedly takes
        one whole shard from a queue and writes it to
        ``./TFRecord_Output/<dir name>.tfrecords_<i>_of_<n>_mt``. A shard is
        therefore written by exactly one thread and every image is written
        exactly once.

        :param path: directory whose immediate files are all treated as images
        :param with_label: placeholder switch for label support; the labelling
            rule is still a TODO, so no label is written either way
        """
        # Imported here so the function works without relying on imports made
        # elsewhere in the file.
        import queue
        import threading

        # Files directly inside `path` (no recursion).
        image_names = next(os.walk(path))[2]
        num_file = len(image_names)

        # Number of examples stored per shard file.
        instances_per_shard = 10000
        # Ceiling division (no empty trailing shard), but at least one shard.
        num_shards = max(1, -(-num_file // instances_per_shard))

        # normpath first so a trailing '/' on `path` does not yield an empty name.
        dataset_name = os.path.basename(os.path.normpath(path))
        # TFRecordWriter does not create missing directories.
        os.makedirs('./TFRecord_Output', exist_ok=True)

        # One task per shard: (output file, names of the images it holds).
        # Workers never mutate a list that is being iterated, which previously
        # caused entries to be skipped and raced between the two threads.
        tasks = queue.Queue()
        for file_i in range(num_shards):
            file_name = './TFRecord_Output/{0}.tfrecords_{1}_of_{2}_mt'.format(
                dataset_name, file_i+1, num_shards)
            shard_start = file_i * instances_per_shard
            tasks.put((file_name,
                       image_names[shard_start:shard_start + instances_per_shard]))

        def _TFR_write():
            """Worker: write shards taken from the queue until none are left."""
            while True:
                try:
                    file_name, shard_names = tasks.get_nowait()
                except queue.Empty:
                    return
                writer = tf.python_io.TFRecordWriter(file_name)
                try:
                    for image_name in shard_names:
                        image_data = plt.imread(os.path.join(path, image_name))
                        if with_label:
                            pass
                            # TODO: derive the label here when labels are
                            # available (a plain integer class id, not one-hot).
                            # label = ...
                        # tobytes() replaces the deprecated tostring().
                        image_raw = image_data.tobytes()
                        example = tf.train.Example(features=tf.train.Features(feature={
                            'image': _bytes_feature(image_raw),
                            # 'label': _int64_feature(label)
                        }))
                        writer.write(example.SerializeToString())
                finally:
                    # Close the shard even if an image fails to load.
                    writer.close()

        threads = [
            threading.Thread(target=_TFR_write,
                             name='resize_img_thread:{0}'.format(thread_i))
            for thread_i in range(2)
        ]

        for t in threads:
            t.start()
        for t in threads:
            t.join()

    4、测试部分

    if __name__=='__main__':
        # Benchmark: time the multi-threaded writer against the single-threaded
        # one, 15 rounds, printing both durations per round.
        import datetime
        # NOTE: this binds `threading` at module scope.
        import threading
        for _ in range(15):
            round_start = datetime.datetime.now()
            image2TFR_multiple_threads()
            multi_done = datetime.datetime.now()
            image2TFR_single_thread()
            single_done = datetime.datetime.now()
            print('mul:', multi_done - round_start)
            print('sin:', single_done - multi_done)
            print('_*_'*10)

    5、部分输出

    mul: 0:00:25.779139
    sin: 0:00:26.312438
    _*__*__*__*__*__*__*__*__*__*_
    mul: 0:00:27.203649
    sin: 0:00:27.982487
    _*__*__*__*__*__*__*__*__*__*_
    mul: 0:00:31.193418
    sin: 0:00:28.735610
    _*__*__*__*__*__*__*__*__*__*_
    mul: 0:00:28.414592
    sin: 0:00:30.207631
    _*__*__*__*__*__*__*__*__*__*_
    mul: 0:00:27.999488
    sin: 0:00:29.683136
    _*__*__*__*__*__*__*__*__*__*_
    mul: 0:00:28.659919
    sin: 0:00:28.534984
    _*__*__*__*__*__*__*__*__*__*_
    mul: 0:00:30.366691
    sin: 0:00:31.014559
    _*__*__*__*__*__*__*__*__*__*_
    mul: 0:00:28.288918
    sin: 0:00:29.142247
    _*__*__*__*__*__*__*__*__*__*_
    mul: 0:00:29.861579
    sin: 0:00:29.329732
    _*__*__*__*__*__*__*__*__*__*_
    mul: 0:00:28.854213
    sin: 0:00:33.794422
    _*__*__*__*__*__*__*__*__*__*_
    mul: 0:00:28.010327
    sin: 0:00:29.163616
    _*__*__*__*__*__*__*__*__*__*_
    mul: 0:00:27.773299
    sin: 0:00:29.312738
    _*__*__*__*__*__*__*__*__*__*_
    mul: 0:00:27.815851
    sin: 0:00:28.715579
    _*__*__*__*__*__*__*__*__*__*_
    mul: 0:00:27.889409
    sin: 0:00:28.157235
    _*__*__*__*__*__*__*__*__*__*_
    mul: 0:00:28.143782
    sin: 0:00:28.988136
    _*__*__*__*__*__*__*__*__*__*_
    mul: 0:00:27.533430
    sin: 0:00:30.000925
    _*__*__*__*__*__*__*__*__*__*_
    mul: 0:00:28.158601
    sin: 0:00:29.448665
    _*__*__*__*__*__*__*__*__*__*_
    mul: 0:00:27.839638
    sin: 0:00:28.908899
    _*__*__*__*__*__*__*__*__*__*_
    mul: 0:00:27.922513
    sin: 0:00:28.757721
    _*__*__*__*__*__*__*__*__*__*_
    mul: 0:00:31.227687
    sin: 0:00:29.576041
    _*__*__*__*__*__*__*__*__*__*_

    可能是数据量不够大的原因,多线程没有明显的优势,可能写入文件数增加会更好,但个人感觉由于涉及到写入文件句柄操作这不是个适合使用多线程加速的任务。

    二、TFRecord实际使用框架

    总的原则,把可以修改的超参数、路径等单独提出来,不要写死在程序中,否则使用时想要修改会极其繁琐,且易出错

    1、包导入以及超参数设定

    # Author : Hellcat
    # Time   : 18-1-15
    
    """
    import os
    os.environ["CUDA_VISIBLE_DEVICES"]="-1" 
    """
    
    import os
    import glob
    import numpy as np 
    import tensorflow as tf
    from scipy.misc import imread, imresize
    
    # Print arrays in full instead of eliding the middle with "...".
    np.set_printoptions(threshold=np.inf)
    # Create the session shared by the rest of the script; allow_growth is set
    # so GPU memory is claimed on demand rather than all at once.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    
    # Number of passes over the data files when reading
    NUM_EPOCHS = 1
    # Image size stored in the TFRecord files
    IMAGE_HEIGHT = 227
    IMAGE_WIDTH = 227
    IMAGE_DEPTH = 3
    # Training batch size
    BATCH_SIZE = 2
    # Number of examples written to each TFRecord file
    INSTANCES_PER_SHARD = 10000
    # Directory containing the image files
    IMAGE_PATH = './Data_Set/cartoon_faces'
    # File that stores the list of image paths and their labels
    IMAGE_LABEL_LIST = 'images_&_labels.txt'
    # Directory where the TFRecord files are saved
    TFR_PATH = './TFRecord_Output'

    2、文件清单生成

    def filename_list(path=IMAGE_PATH):
        """
        Generate the image/label manifest file.

        Writes IMAGE_LABEL_LIST with one line per file found directly under
        ``path``: ``<path>/<file name>``, a single space, the class label and a
        line break. The label is currently the constant ``1`` for every image.

        :param path: image directory; the images sit directly inside it
        :return: None (the result is the manifest file on disk)
        """
        # next(os.walk(path)) yields (dirpath, dirnames, filenames) for `path`
        # itself, so [2] is the list of files directly inside it (no recursion).
        file_names = next(os.walk(path))[2]
        with open(IMAGE_LABEL_LIST, 'w') as f:
            # The line terminator must be the escape sequence '\n'; a literal
            # line break inside the quotes is a syntax error.
            f.writelines(path+'/'+file_name+' '+'1'+'\n' for file_name in file_names)

    3、TFR文件生成

    def image_to_TFR(image_and_label=IMAGE_LABEL_LIST,
                     image_height=IMAGE_HEIGHT,
                     image_width=IMAGE_WIDTH):
        """
        Read the image manifest and write the listed images into TFRecord shards.

        Every manifest line is split on single spaces: the first field is the
        image path and the last field is its integer label. Each image is
        resized to (image_height, image_width) and stored, together with its
        label, in ``TFR_PATH/<image dir name>.tfrecords_<i>_of_<n>``, with at
        most INSTANCES_PER_SHARD examples per file.

        :param image_and_label: path of the txt manifest
        :param image_height: height of the images stored in the TFRecord files
        :param image_width: width of the images stored in the TFRecord files
        """
        def _int64_feature(value):
            """Wrap one integer as an Int64List feature."""
            return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

        def _bytes_feature(value):
            """Wrap one bytes value as a BytesList feature."""
            return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

        # Read the file once and derive both lists from the same lines. Calling
        # f.readlines() a second time returns nothing because the file position
        # is already at the end after the first call.
        with open(image_and_label, 'r') as f:
            lines = [line.strip('\n') for line in f]
        # Ignore blank lines (e.g. a trailing empty line in the manifest).
        lines = [line for line in lines if line.strip()]
        image_paths = [line.split(' ')[0] for line in lines]
        labels = [line.split(' ')[-1] for line in lines]

        num_file = len(image_paths)
        if num_file == 0:
            # Nothing to write; image_paths[0] below would raise IndexError.
            print('No images listed in', image_and_label)
            return

        # Ceiling division: `n // k + 1` would add an empty trailing shard
        # whenever n is an exact multiple of k.
        num_shards = -(-num_file // INSTANCES_PER_SHARD)
        # Name of the directory that holds the images, taken from the first entry.
        dataset_name = image_paths[0].split('/')[-2]
        # TFRecordWriter does not create missing directories.
        os.makedirs(TFR_PATH, exist_ok=True)

        for file_i in range(num_shards):
            # Shard naming rule: <dataset>.tfrecords_<index>_of_<total>
            file_name = os.path.join(
                TFR_PATH,
                '{0}.tfrecords_{1}_of_{2}'.format(dataset_name, file_i+1, num_shards))
            print('正在生成文件: ', file_name)

            shard_start = file_i * INSTANCES_PER_SHARD
            shard_stop = shard_start + INSTANCES_PER_SHARD
            writer = tf.python_io.TFRecordWriter(file_name)
            try:
                # Slice paths and labels with the same bounds so each image keeps
                # its own label; indexing `labels` with the position inside the
                # shard would reuse the first shard's labels for every shard.
                for image_path, label in zip(image_paths[shard_start:shard_stop],
                                             labels[shard_start:shard_stop]):
                    image_data = imread(image_path)
                    image_data = imresize(image_data, (image_height, image_width))
                    # tobytes() replaces the deprecated tostring().
                    image_raw = image_data.tobytes()
                    example = tf.train.Example(features=tf.train.Features(feature={
                        'image': _bytes_feature(image_raw),
                        'label': _int64_feature(int(label))
                    }))
                    writer.write(example.SerializeToString())
            finally:
                # Close the shard even if an image fails to load.
                writer.close()

    4、读取TFR文件并生成batch数据

    本函数最后的images和labels可以作为return,直接送入网络参与训练

    def batch_from_TFR(image_height=IMAGE_HEIGHT,
                       image_width=IMAGE_WIDTH,
                       image_depth=IMAGE_DEPTH):
        """Build the queue-based pipeline that reads the TFRecord shards and
        produces shuffled batches of decoded uint8 images and int32 labels.

        :param image_height: height of the stored images
        :param image_width: width of the stored images
        :param image_depth: number of channels of the stored images
        """

        if not os.path.exists(TFR_PATH):
            os.makedirs(TFR_PATH)

        # --- Read the records and restore them to uint8 images ---
        shard_pattern = os.path.join(TFR_PATH, '{0}.tfrecords_*_of_*').format(
            IMAGE_PATH.split('/')[-1])
        shard_files = glob.glob(shard_pattern)
        shard_queue = tf.train.string_input_producer(
            shard_files, num_epochs=NUM_EPOCHS, shuffle=True)

        _, record = tf.TFRecordReader().read(shard_queue)
        feature_spec = {
            'image': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64)
        }
        parsed = tf.parse_single_example(record, features=feature_spec)

        flat_pixels = tf.decode_raw(parsed['image'], tf.uint8)
        # decode_raw yields a 1-D tensor, so the static shape set here must be
        # 1-D as well; the reshape then restores height x width x depth.
        flat_pixels.set_shape([image_height*image_width*image_depth])
        picture = tf.reshape(flat_pixels, [image_height, image_width, image_depth])
        class_id = tf.cast(parsed['label'], tf.int32)

        # --- Image preprocessing (none yet) ---

        # --- Assemble batches ---
        # Draw BATCH_SIZE random (image, label) pairs from a shuffling queue.
        images, labels = tf.train.shuffle_batch(
            [picture, class_id],
            batch_size=BATCH_SIZE,
            num_threads=1,
            capacity=1000 + 3 * BATCH_SIZE,  # maximum queue capacity
            min_after_dequeue=1000)

    5、包含在上面batch函数中的测试模块

        # Test section: pull one batch through the pipeline and show its first image.
        print(images)
        sess.run(tf.global_variables_initializer())
        # Local variables must be initialised too: string_input_producer keeps
        # its epoch counter in a local variable when num_epochs is set.
        sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        import matplotlib.pyplot as plt
        try:
            img = sess.run(images)[0]
            plt.imshow(img)
        finally:
            # Always stop and join the queue-runner threads, even if fetching the
            # batch fails; otherwise they are left running.
            coord.request_stop()
            coord.join(threads)
    

     测试结果,

    6、启动部分

    if __name__ == '__main__':

        import datetime
        started_at = datetime.datetime.now()
        # The two preparation steps below are disabled; only batch reading runs.
        # filename_list()
        # image_to_TFR()
        batch_from_TFR()
        elapsed = datetime.datetime.now() - started_at
        print(elapsed)
    

     从测试部分的运行注意到,涉及tf的队列操作时(这里string_input_producer设置了num_epochs),局部变量初始化sess.run(tf.local_variables_initializer())是必须的,否则会报错(『TensorFlow』问题整理)。

  • 相关阅读:
    开发中的问题
    页面重定向Redirect时产生错误
    项目管理的几个阶段及分工
    让你的CSS像Jquery一样做筛选
    项目中的几个SQL程序
    SharePoint2010人员搜索配置心得
    TroubleShoot:该搜索请求无法连接到搜索服务
    转:软件架构师应该知道的97件事
    通用动态生成静态HTML页方法
    简单的正则表达式过滤网址
  • 原文地址:https://www.cnblogs.com/hellcat/p/8287831.html
Copyright © 2011-2022 走看看