zoukankan      html  css  js  c++  java
  • 第十一节 TFRecords读取

    import tensorflow as tf
    import os
    
    """
    TFRecords数据格式:是TensorFlow设计的一种内置文件格式,是一种二进制文件,它能更好的利用内存,更方便的复制和移动
    为了解决将二进制数据和标签(训练的类别标签)数据存储在不同同一个文件中的问题,TFRecords会将目标值和特征值合并在一个样本中
    文件格式:*.tfrecords
    写入文件内容:Example协议块,是一种类字典格式
    """
    # 训练数据连接:http://www.cs.toronto.edu/~kriz/cifar.html
    # 定义cifar的数据命令行参数,注意路径要写绝对路径,不然可能报错
    FLAGS = tf.app.flags.FLAGS
    tf.app.flags.DEFINE_string("cifar_dir", r"C:UsersAdministratorPycharmProjectslearntest	ensodatacifar10\", "文件的目录")
    tf.app.flags.DEFINE_string("cifar_tfrecords", r".	mpcifar.tfrecords", "存入tfrecords的文件")
    
    class CifarRead(object):
        """读取二进制文件,写入tfrecords,读取tfrecords"""
        def __init__(self, filelist):
            # 文件列表
            self.filelist = filelist
    
            # 定义读取图片的一些属性,cifar下载的文件默认是32*32像素,彩色通道3,目标值1比特
            self.height = 32
            self.weight = 32
            self.channel = 3
            self.label_bytes = 1
            # 二进制文件每张图片的字节
            self.bytes = self.height * self.weight * self.channel + self.label_bytes
    
        def read_and_decode(self):
            # 1.构造文件队列
            file_queue = tf.train.string_input_producer(self.filelist)
    
            # 2.构造二进制文件读取器
            reader = tf.FixedLengthRecordReader(self.bytes)
            key, value = reader.read(file_queue)
    
            # 3.二进制文件内容解码
            label_image = tf.decode_raw(value, tf.uint8)
    
            # 4.将label_image中的特征值和目标值分割开来,cast目标值是0-9的整数所以转换成int32类型,特征值将用于计算,转换成float32类型
            label = tf.cast(tf.slice(label_image, [0], [self.label_bytes]), tf.int32)
            image = tf.cast(tf.slice(label_image, [self.label_bytes], [self.bytes - self.label_bytes]), tf.float32)
            # print(label, image)  # 返回结果Tensor("Slice:0", shape=(1,), dtype=uint8) Tensor("Slice_1:0", shape=(3072,), dtype=uint8)
    
            # 5.可以对图片特征数据进行形状改变[3072] ==> [32, 32, 3]
            image_reshape = tf.reshape(image, [self.height, self.weight, self.channel])
    
            # 6.进行批处理
            image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=1, num_threads=1, capacity=10)
    
            return image_batch, label_batch
    
        def write_to_tfrecords(self, image_batch, label_batch):
            """将图片目标值的特征值存入tfrecords"""
            # 1.建立tfrecords存储器
            writer = tf.python_io.TFRecordWriter(FLAGS.cifar_tfrecords)
    
            # 2.循环将所有样本写入文件,没长图片样本都要构造example协议块
            for i in range(100):
                # 取出第i个图片数据特征值和目标值,这里调用了eval()方法,所有这个实例方法一定要在sess上下文管理器中调用
                image = image_batch[i].eval().tostring()
                label = int(label_batch[i].eval()[0])
    
                # 构造一个样本的example
                example = tf.train.Example(features=tf.train.Features(feature={
                    "image":tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                    "label":tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
                }))
    
                # 单独写入样本,不使用SerializeToString将会变成json数据
                writer.write(example.SerializeToString())
            # 关闭
            writer.close()
    
        def read_from_tfrecords(self):
            """读取tfrecords文件"""
            # 1.构造文件队列
            file_queue = tf.train.string_input_producer([FLAGS.cifar_tfrecords])
    
            # 2.构造文件阅读器,读取内容example,value一个example的序列化
            reader = tf.TFRecordReader()
            key, value = reader.read(file_queue)
    
            # 3.解析example,这里的数据类型没有int32,可以在解码的时候再转换成int32
            features = tf.parse_single_example(value, features={
                "image":tf.FixedLenFeature([], tf.string),
                "label":tf.FixedLenFeature([], tf.int64),
            })
            # print(features["image"])  # 返回Tensor("ParseSingleExample/Squeeze_image:0", shape=(), dtype=string)
    
            # 4.解码内容,如果读取的内容格式是string需要解码,如果是int64,float32不需要解码
            image = tf.decode_raw(features["image"], tf.uint8)
    
            # 固定图片形状,方便批处理
            image_reshape = tf.reshape(image, [self.height, self.weight, self.channel])
            label = tf.cast(features["label"], tf.int32)
    
            # 进行批处理
            image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)
    
            return image_batch, label_batch
    
    if __name__ == "__main__":
        # 构造文件列表
        file_name = os.listdir(FLAGS.cifar_dir)
        filelist = [os.path.join(FLAGS.cifar_dir, file) for file in file_name if file[-3:] == "bin"]
        cf = CifarRead(filelist)
        # image_batch, label_batch = cf.read_and_decode()
        image_batch, label_batch = cf.read_from_tfrecords()
    
        # 开启会话
        with tf.Session() as sess:
    
            # 定义线程协调器
            coord = tf.train.Coordinator()
    
            # 开启读取文件的线程
            thd = tf.train.start_queue_runners(sess, coord=coord)
    
            # 存进tfrecords文件
            # cf.write_to_tfrecords(image_batch, label_batch)
    
            # 打印读取内容
            print(sess.run([image_batch, label_batch]))
    
            # 回收子线程
            coord.request_stop()
    
            coord.join(thd)
  • 相关阅读:
    性能优化方法
    JSM的topic和queue的区别
    关于分布式事务、两阶段提交协议、三阶提交协议
    大型网站系统与Java中间件实践读书笔记
    Kafka设计解析:Kafka High Availability
    kafka安装和部署
    String和intern()浅析
    JAVA中native方法调用
    Java的native方法
    happens-before俗解
  • 原文地址:https://www.cnblogs.com/kogmaw/p/12599423.html
Copyright © 2011-2022 走看看