zoukankan      html  css  js  c++  java
  • tensorflow2.0-------------读取二进制文件和tfrecords

    #-*- coding: utf-8 -*-
    # coding:unicode_escape
    #@Time : 2021/3/3 15:10
    #@Author : 杨晓
    #@File : binary_read.py
    #@Software: PyCharm
    import tensorflow as tf
    import os
    # The queue-based input pipeline used below (string_input_producer,
    # train.batch, start_queue_runners) is not supported under eager execution,
    # so switch TensorFlow 2.x back to graph mode before building any ops.
    tf.compat.v1.disable_eager_execution()
    
    class Cifar(object):
        """Read CIFAR-10 binary batches and round-trip them through a TFRecords file.

        Each binary record is ``label_bytes`` label byte(s) followed by
        ``image_bytes`` pixel bytes stored channel-major (channels, height, width).
        All methods build TF1-style queue pipelines and therefore require graph
        mode (eager execution disabled).
        """

        def __init__(self):
            # Geometry of a single sample.
            self.height = 32
            self.width = 32
            self.channels = 3
            # Sizes, in bytes, of the pieces that make up one fixed-length record.
            self.image_bytes = self.height * self.width * self.channels
            self.label_bytes = 1
            self.all_bytes = self.image_bytes + self.label_bytes

        def read_binary(self, data_dir="../tmp/data/cifar-10-batches-bin", batch_size=100):
            """Read one batch of samples from the ``*bin`` files in ``data_dir``.

            :param data_dir: directory containing the CIFAR-10 binary batch files.
            :param batch_size: number of samples to dequeue in the returned batch.
            :return: ``(image_value, label_value)`` NumPy arrays; images are uint8
                with shape (batch_size, height, width, channels), labels are uint8
                with shape (batch_size, label_bytes).
            """
            # Build the list of data files (same suffix test as before: last
            # three characters are "bin").
            file_list = [
                os.path.join(data_dir, name)
                for name in os.listdir(data_dir)
                if name.endswith("bin")
            ]
            # Filename queue feeding the reader.
            file_queue = tf.compat.v1.train.string_input_producer(file_list)
            # Read one fixed-length record (label + image) at a time.
            reader = tf.compat.v1.FixedLengthRecordReader(self.all_bytes)
            key, value = reader.read(file_queue)
            # Decode the raw record into a flat uint8 vector.
            decoded = tf.compat.v1.decode_raw(value, tf.uint8)
            # Split the record into target (label) and feature (image) parts.
            label = tf.slice(decoded, [0], [self.label_bytes])
            image = tf.slice(decoded, [self.label_bytes], [self.image_bytes])
            # Pixels are stored channel-major, so reshape to (C, H, W) first ...
            image_reshape = tf.reshape(image, shape=[self.channels, self.height, self.width])
            # ... then reorder to the conventional (H, W, C) layout.
            image_transpose = tf.transpose(image_reshape, [1, 2, 0])
            # Batch the samples; the queue capacity follows the batch size so a
            # full batch can always be dequeued.
            label_batch, image_batch = tf.compat.v1.train.batch(
                [label, image_transpose],
                batch_size=batch_size,
                num_threads=1,
                capacity=batch_size,
            )
            print("image_bath:\n", image_batch)
            with tf.compat.v1.Session() as sess:
                # Start the queue-runner threads under a coordinator.
                coord = tf.compat.v1.train.Coordinator()
                threads = tf.compat.v1.train.start_queue_runners(sess=sess, coord=coord)
                try:
                    label_value, image_value = sess.run([label_batch, image_batch])
                    print("label_new:\n", label_value)
                    print("image_new:\n", image_value)
                finally:
                    # Always stop and join the worker threads, even if the run
                    # failed, so they do not outlive the session.
                    coord.request_stop()
                    coord.join(threads=threads)
            return image_value, label_value

        def write_to_tfrecords(self, image_batch, label_batch, tfrecords_path="cifar10.tfrecords"):
            """Write the samples' features (images) and targets (labels) to TFRecords.

            :param image_batch: sequence of NumPy image arrays, one per sample.
            :param label_batch: sequence of label arrays; element ``[0]`` of each
                entry is used as the integer label.
            :param tfrecords_path: output TFRecords file path.
            :return: None
            """
            with tf.compat.v1.python_io.TFRecordWriter(tfrecords_path) as writer:
                # Build one Example per sample and append it to the file.  Iterate
                # over the actual batch instead of assuming exactly 100 samples.
                for image_array, label_array in zip(image_batch, label_batch):
                    # tobytes() replaces the removed ndarray.tostring() alias.
                    image = image_array.tobytes()
                    # Int64List expects plain Python ints, not NumPy scalars.
                    label = int(label_array[0])
                    print("records_image:\n", image)
                    print("records_label:\n", label)
                    example = tf.compat.v1.train.Example(features=tf.compat.v1.train.Features(feature={
                        "image": tf.compat.v1.train.Feature(bytes_list=tf.compat.v1.train.BytesList(value=[image])),
                        "label": tf.compat.v1.train.Feature(int64_list=tf.compat.v1.train.Int64List(value=[label])),
                    }))
                    # Serialize the Example and write it as one record.
                    writer.write(example.SerializeToString())
            return None

        def read_tfrecords(self, tfrecords_path="cifar10.tfrecords", batch_size=100):
            """Read one batch back from a TFRecords file and print it.

            :param tfrecords_path: TFRecords file written by ``write_to_tfrecords``.
            :param batch_size: number of samples to dequeue in the printed batch.
            :return: None
            """
            # Filename queue feeding the reader.
            file_queue = tf.compat.v1.train.string_input_producer([tfrecords_path])

            # Read one serialized Example at a time.
            reader = tf.compat.v1.TFRecordReader()
            key, value = reader.read(file_queue)
            print("key:\n", key)
            print("value:\n", value)
            # Parse the serialized Example back into its two features.
            feature = tf.compat.v1.parse_single_example(value, features={
                "image": tf.compat.v1.FixedLenFeature([], tf.string),
                "label": tf.compat.v1.FixedLenFeature([], tf.int64)
            })
            image = feature["image"]
            label = feature["label"]
            print("read_tf_image:\n", image)
            print("read_tf_label:\n", label)

            # Decode the image bytes into a flat uint8 vector.
            image_decode = tf.compat.v1.decode_raw(image, tf.uint8)
            print("image_decode:\n", image_decode)
            # Restore the (H, W, C) shape the image had when it was written.
            image_reshape = tf.reshape(image_decode, [self.height, self.width, self.channels])
            print("image_reshape:\n", image_reshape)
            # Batch the samples; the queue capacity follows the batch size.
            image_batch, label_batch = tf.compat.v1.train.batch(
                [image_reshape, label],
                batch_size=batch_size,
                num_threads=2,
                capacity=batch_size,
            )
            print("image_batch:\n", image_batch)
            print("label_batch:\n", label_batch)
            with tf.compat.v1.Session() as sess:
                # Start the queue-runner threads under a coordinator.
                coord = tf.compat.v1.train.Coordinator()
                threads = tf.compat.v1.train.start_queue_runners(sess=sess, coord=coord)
                try:
                    image_value, label_value = sess.run([image_batch, label_batch])
                    print("image_value:\n", image_value)
                    print("label_value:\n", label_value)
                finally:
                    # Always release the worker threads, even on failure.
                    coord.request_stop()
                    coord.join(threads)
            return None
    
    if __name__ == "__main__":
        # Script entry point: by default only the TFRecords read-back demo runs.
        # To regenerate cifar10.tfrecords from the binary batches first, run:
        #   image_value, label_value = reader.read_binary()
        #   reader.write_to_tfrecords(image_value, label_value)
        reader = Cifar()
        reader.read_tfrecords()
    
  • 相关阅读:
    第六周总结
    《构建之法》读后感二
    移动端疫情展示
    第五周
    用python爬取疫情数据
    第四周
    疫情图表展示和时间查询
    wpf datagrid row height 行高自动计算使每行行高自适应文本
    c# 实现mysql事务
    c# 简单实现 插件模型 反射方式
  • 原文地址:https://www.cnblogs.com/yangxiao-/p/14476526.html
Copyright © 2011-2022 走看看