zoukankan      html  css  js  c++  java
  • Tensorflow踩坑系列---TFRECORD文件读写

    补充:TFRECORD文件学习

    https://blog.csdn.net/briblue/article/details/80789608

    import tensorflow as tf
    import os
    import random
    import sys
    # Number of tfrecord shard files to generate
    _NUM_BLOCK = 2
    # Directory containing the source images
    DATASET_DIR = "./Images/SourceImgs/"
    # Output directory for the generated tfrecord files
    GOALSET_DIR = "./Images/TFImgs/"
    # Output directory for images recovered by reading the tfrecords directly
    GOALTFSET_DIR = "./Images/TFSourceImgs/"
    # Output directory for images recovered via the file-name queue pipeline
    QUEUE_DIR = "./Images/QueueImgs/"

    一:生成TFRECORD文件

    (一)获取图片信息

    # Collect the paths of all images in the dataset directory.
    def get_file_info(dataSet_dir = DATASET_DIR):
        """Return the full path of every entry in *dataSet_dir*."""
        return [os.path.join(dataSet_dir, entry) for entry in os.listdir(dataSet_dir)]

    (二)写入TFRECORD文件

    # Shard the source images into _NUM_BLOCK tfrecord files.
    with tf.Session() as sess:
        files_list = get_file_info()
        # Images per shard (integer division; the remainder goes to the last shard).
        num_per_block = len(files_list)//_NUM_BLOCK
        for _id in range(_NUM_BLOCK):
            tfr_name = "image_%d.tfrecord"%(_id+1)
            tfr_dir = os.path.join(GOALSET_DIR,tfr_name)
            with tf.python_io.TFRecordWriter(tfr_dir) as writer:
                start_idx = _id*num_per_block
                # BUG FIX: the last shard must absorb the remainder; the original
                # min((_id+1)*num_per_block, len(files_list)) silently dropped
                # len(files_list) % _NUM_BLOCK trailing images.
                if _id == _NUM_BLOCK - 1:
                    end_idx = len(files_list)
                else:
                    end_idx = (_id+1)*num_per_block

                for i in range(start_idx,end_idx):
                    try:
                        # \r rewinds to line start so the progress line updates in place.
                        sys.stdout.write("\r>>Converting images %d/%d to block %d"%(i+1,len(files_list),_id+1))
                        sys.stdout.flush()
                        # Read the raw (already encoded) image bytes.
                        image_data = tf.gfile.FastGFile(files_list[i],'rb').read()
                        # Label = file name without extension (e.g. "cat" from ".../cat.jpg").
                        label = files_list[i].split("/")[-1].split(".")[0]

                        example = _format_record(image_data,label)

                        # BUG FIX: use SerializeToString, not SerializePartialToString —
                        # the partial variant silently permits missing required fields.
                        writer.write(example.SerializeToString())
                    except IOError as e:
                        print("Could not read:",files_list[i])
                        print("Error",e)
                        print("Skip it\n")
        sys.stdout.write("\n")
        sys.stdout.flush()

    二:直接读取TFRECORD文件

    (一)解析文件

    def _parse_record(example_proto):
        """Parse one serialized Example into its 'label' and 'data' string features."""
        feature_spec = {
            'label': tf.FixedLenFeature((), tf.string),
            'data': tf.FixedLenFeature((), tf.string),
        }
        return tf.parse_single_example(example_proto, features=feature_spec)

    (二)读取所有文件

    # Read every tfrecord shard with tf.data and write the images back to disk.
    with tf.Session() as sess:
        tf_files = []
        for fn in os.listdir(GOALSET_DIR):
            tf_files.append(os.path.join(GOALSET_DIR,fn))

        dataSet = tf.data.TFRecordDataset(tf_files) # a single dataset can read all tfrecord files at once
        dataSet = dataSet.map(_parse_record) # parse each serialized Example

        iterator = dataSet.make_one_shot_iterator()
        # BUG FIX: build get_next() ONCE outside the loop. Calling it inside the
        # loop adds a new op to the graph on every iteration (memory leak and
        # progressive slowdown in TF1 graph mode).
        next_element = iterator.get_next()

        sess.run(tf.local_variables_initializer())
        while True:
            try:
                Singledata = sess.run(next_element)
            except tf.errors.OutOfRangeError:
                # BUG FIX: catch only end-of-sequence. The original BaseException
                # reported every failure (even KeyboardInterrupt) as "finished".
                print("Read finish!!!")
                break
            label = Singledata['label'].decode()
            image_data = Singledata['data']
            # Context manager ensures the file handle is closed promptly
            # (the original leaked an open GFile per image).
            with tf.gfile.GFile(os.path.join(GOALTFSET_DIR,"%s.jpg"%label),"wb") as f:
                f.write(image_data)

    三:使用文件队列读取多个tfrecord文件

    # Read multiple tfrecord files through the classic TF1 file-name queue pipeline.
    tf_files = []
    for fn in os.listdir(GOALSET_DIR):
        tf_files.append(os.path.join(GOALSET_DIR,fn))

    # string_input_producer builds the file-name queue (shuffled, 3 epochs).
    filename_queue = tf.train.string_input_producer(tf_files,shuffle=True,num_epochs=3)

    # The reader dequeues file names and yields (key, serialized Example) pairs.
    reader = tf.TFRecordReader()
    key,value = reader.read(filename_queue)
    features = tf.parse_single_example(value,features={
            'label':tf.FixedLenFeature((),tf.string),
            'data':tf.FixedLenFeature((),tf.string)
        })
    img_data = features['data']
    label = features['label']

    image_batch,label_batch = tf.train.shuffle_batch([img_data,label],batch_size=8,num_threads=2,allow_smaller_final_batch=True,
                                                    capacity=500,min_after_dequeue=100)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # num_epochs=3 creates a LOCAL variable, so local init is required too.
        sess.run(tf.local_variables_initializer())

        coord = tf.train.Coordinator()
        # Queues only start filling after start_queue_runners is called.
        threads = tf.train.start_queue_runners(sess=sess,coord=coord)
        j = 1
        try:
            while not coord.should_stop():
                images_data,labels_data = sess.run([image_batch,label_batch])
                for i in range(len(images_data)):
                    # os.path.join instead of fragile string concatenation.
                    out_path = os.path.join(QUEUE_DIR,"%s-%d.jpg"%(labels_data[i].decode(),j))
                    with open(out_path,"wb") as f:
                        f.write(images_data[i])
                    j+=1
        except tf.errors.OutOfRangeError:
            # BUG FIX: catch only queue exhaustion (epochs finished). The original
            # BaseException hid genuine errors behind "read all files".
            print("read all files")
        finally:
            coord.request_stop() # ask the reader threads to stop
            coord.join(threads)  # wait for the reader threads to exit

  • 相关阅读:
    泛型的内部原理:类型擦除以及类型擦除带来的问题
    Redis的那些最常见面试问题
    线程池全面解析
    对线程调度中Thread.sleep(0)的深入理解
    集群环境下Redis分布式锁
    3.8
    3.7
    3.6任务
    3.5任务
    3.4
  • 原文地址:https://www.cnblogs.com/ssyfj/p/13976920.html
Copyright © 2011-2022 走看看