zoukankan      html  css  js  c++  java
  • tensorflow二进制文件读取与tfrecords文件读取

    1、知识点

    """
    TFRecords介绍:
        TFRecords是Tensorflow设计的一种内置文件格式,是一种二进制文件,它能更好的利用内存,
        更方便复制和移动,为了将二进制数据和标签(训练的类别标签)数据存储在同一个文件中
    
    CIFAR-10批处理结果存入tfrecords流程:
        1、构造存储器
             a)TFRecord存储器API:tf.python_io.TFRecordWriter(path) 写入tfrecords文件
                参数:   
                    path: TFRecords文件的路径
                    return:写文件
                方法:
                    write(record):向文件中写入一个字符串记录
                        record:字符串为一个序列化的Example,Example.SerializeToString()
                    close():关闭文件写入器
    
        2、构造每一个样本的Example协议块
             a)tf.train.Example(features=None)写入tfrecords文件
                    features:tf.train.Features类型的特征实例
                    return:example格式协议块
    
             b)tf.train.Features(feature=None)构建每个样本的信息键值对
                    feature:字典数据,key为要保存的名字,
                    value为tf.train.Feature实例
                    return:Features类型
    
             c)tf.train.Feature(**options)
                    **options:例如
                        bytes_list=tf.train.BytesList(value=[Bytes])
                        int64_list=tf.train.Int64List(value=[Value])
                    数据类型:
                        tf.train.Int64List(value=[Value])
                        tf.train.BytesList(value=[Bytes]) 
                        tf.train.FloatList(value=[value]) 
    
        3、写入序列化的Example
             writer.write(example.SerializeToString())
       
    报错: 
            1、ValueError: Protocol message Feature has no "Bytes_list" field.
                    因为没有Bytes_list属性字段,只有bytes_list字段
                    
    读取tfrecords流程:
        1、构建文件队列
            file_queue = tf.train.string_input_producer([FLAGS.cifar_tfrecords])
        2、构造TFRecords阅读器
            reader = tf.TFRecordReader()
        3、解析Example,获取数据
            a) tf.parse_single_example(serialized,features=None,name=None)解析TFRecords的example协议内存块
                serialized:标量字符串Tensor,一个序列化的Example
                features:dict字典数据,键为读取的名字,值为FixedLenFeature
                return:一个键值对组成的字典,键为读取的名字
            b)tf.FixedLenFeature(shape,dtype) 类型只能是float32,int64,string
                shape:输入数据的形状,一般不指定,为空列表
                dtype:输入数据类型,与存储进文件的类型要一致     
        4、转换格式,bytes解码
            image = tf.decode_raw(features["image"],tf.uint8)
            #固定图像大小,有利于批处理操作
            image_reshape = tf.reshape(image,[self.height,self.width,self.channel])
            label = tf.cast(features["label"],tf.int32)
        5、批处理
            image_batch , label_batch = tf.train.batch([image_reshape,label],batch_size=5,num_threads=1,capacity=20)
    
    报错:
        1、ValueError: Shape () must have rank at least 1
            
    """

    2、代码

    # coding = utf-8
    import tensorflow as tf
    import  os
    
    FLAGS = tf.app.flags.FLAGS
    tf.app.flags.DEFINE_string("cifar_dir","./cifar10/", "文件的目录")
    tf.app.flags.DEFINE_string("cifar_tfrecords", "./tfrecords/cifar.tfrecords", "存进tfrecords的文件")
    class CifarRead(object):
        """
        完成读取二进制文件,写进tfrecords,读取tfrecords
        """
        def __init__(self,file_list):
            self.file_list = file_list
            #图片属性
            self.height = 32
            self.width = 32
            self.channel = 3
    
            #二进制字节
            self.label_bytes = 1
            self.image_bytes = self.height*self.width*self.channel
            self.bytes = self.label_bytes + self.image_bytes
    
    
        def read_and_encode(self):
            """
            读取二进制文件,并进行解码操作
            :return:
            """
            #1、创建文件队列
            file_quque = tf.train.string_input_producer(self.file_list)
            #2、创建阅读器,读取二进制文件
            reader = tf.FixedLengthRecordReader(self.bytes)
            key, value = reader.read(file_quque)#key为文件名,value为文件内容
            #3、解码操作
            label_image = tf.decode_raw(value,tf.uint8)
    
            #分割图片和标签数据, tf.cast(),数据类型转换   tf.slice()tensor数据进行切片
            label = tf.cast(tf.slice(label_image, [0], [self.label_bytes]), tf.int32)
            image = tf.slice(label_image,[self.label_bytes],[self.image_bytes])
    
            #对图像进行形状改变
            image_reshape = tf.reshape(image,[self.height,self.width,self.channel])
    
            # 4、批处理操作
            image_batch , label_batch = tf.train.batch([image_reshape,label],batch_size=5,num_threads=1,capacity=20)
            print(image_batch,label_batch)
            return image_batch,label_batch
    
        def write_ro_tfrecords(self,image_batch,label_batch):
            """
            将读取的二进制文件写入 tfrecords文件中
            :param image_batch: 图像 (32,32,3)
            :param label_batch: 标签
            :return:
            """
            # 1、构造存储器
            writer = tf.python_io.TFRecordWriter(FLAGS.cifar_tfrecords)
    
            #循环写入
            for i in range(5):
                image = image_batch[i].eval().tostring()
                label = int(label_batch[i].eval()[0])
                # 2、构造每一个样本的Example
                example = tf.train.Example(features=tf.train.Features(feature={
                    "image":tf.train.Feature(bytes_list = tf.train.BytesList(value = [image])) ,
                    "label":tf.train.Feature(int64_list = tf.train.Int64List(value = [label])) ,
                }))
    
                # 3、写入序列化的Example
                writer.write(example.SerializeToString())
    
            #关闭流
            writer.close()
            return None
    
        def read_from_tfrecords(self):
            """
            从tfrecords文件读取数据
            :return:
            """
            #1、构建文件队列
            file_queue = tf.train.string_input_producer([FLAGS.cifar_tfrecords])
            #2、构造TFRecords阅读器
            reader = tf.TFRecordReader()
            key , value = reader.read(file_queue)
            #3、解析Example
            features = tf.parse_single_example(value,features={
                "image":tf.FixedLenFeature([],tf.string),
                "label":tf.FixedLenFeature([],tf.int64)
            })
            #4、解码内容, 如果读取的内容格式是string需要解码, 如果是int64,float32不需要解码
            image = tf.decode_raw(features["image"],tf.uint8)
            #固定图像大小,有利于批处理操作
            image_reshape = tf.reshape(image,[self.height,self.width,self.channel])
            label = tf.cast(features["label"],tf.int32)
    
            #5、批处理
            image_batch , label_batch = tf.train.batch([image_reshape,label],batch_size=5,num_threads=1,capacity=20)
            return image_batch,label_batch
    
    
    if __name__ == '__main__':
        #################二进制文件读取###############
        # file_name = os.listdir(FLAGS.cifar_dir)
        # file_list = [os.path.join(FLAGS.cifar_dir, file) for file in file_name if file[-3:] == "bin"]
        # cf = CifarRead(file_list)
        # image_batch, label_batch = cf.read_and_encode()
        # with tf.Session() as sess:
        #     # 创建协调器
        #     coord = tf.train.Coordinator()
        #     # 开启线程
        #     threads = tf.train.start_queue_runners(sess, coord=coord)
        #
        #     print(sess.run([image_batch, label_batch]))
        #     # 回收线程
        #     coord.request_stop()
        #     coord.join(threads)
        #############################################
    
        #####二进制文件读取,并写入tfrecords文件######
        # file_name = os.listdir(FLAGS.cifar_dir)
        # file_list = [os.path.join(FLAGS.cifar_dir, file) for file in file_name if file[-3:] == "bin"]
        # cf = CifarRead(file_list)
        # image_batch, label_batch = cf.read_and_encode()
        # with tf.Session() as sess:
        #     # 创建协调器
        #     coord = tf.train.Coordinator()
        #     # 开启线程
        #     threads = tf.train.start_queue_runners(sess, coord=coord)
        #     #########保存文件到tfrecords##########
        #     cf.write_ro_tfrecords(image_batch, label_batch)
        #     #########保存文件到tfrecords##########
        #
        #     print(sess.run([image_batch, label_batch]))
        #     # 回收线程
        #     coord.request_stop()
        #     coord.join(threads)
        ##############################################
    
        #############从tfrecords文件读取###############
        file_name = os.listdir(FLAGS.cifar_dir)
        file_list = [os.path.join(FLAGS.cifar_dir, file) for file in file_name if file[-3:] == "bin"]
        cf = CifarRead(file_list)
        image_batch, label_batch = cf.read_from_tfrecords()
        with tf.Session() as sess:
            # 创建协调器
            coord = tf.train.Coordinator()
            # 开启线程
            threads = tf.train.start_queue_runners(sess, coord=coord)
    
            print(sess.run([image_batch, label_batch]))
            # 回收线程
            coord.request_stop()
            coord.join(threads)
        ##############################################
  • 相关阅读:
    从安装.net Core 到helloWord(Mac上)
    阿里云-对象储存OSS
    图片处理
    项目中 添加 swift代码 真机调试 错误
    iOS面试总结
    IOS APP配置.plist汇总
    cocoapods安装问题
    iOS8使用UIVisualEffectView实现模糊效果
    ios键盘回收终极版
    ?C++ 缺少参数的默认参数
  • 原文地址:https://www.cnblogs.com/ywjfx/p/10919461.html
Copyright © 2011-2022 走看看