zoukankan      html  css  js  c++  java
  • 第十节 二进制文件读取

    import tensorflow as tf
    import os
    
    # dtype notes:
    #   uint8   - storage; saves space (e.g. raw image bytes before decoding).
    #   float32 - matrix computation / precision (e.g. image data after decoding).

    # Training data: http://www.cs.toronto.edu/~kriz/cifar.html
    # Command-line flag giving the CIFAR-10 data directory.
    FLAGS = tf.app.flags.FLAGS
    # NOTE: machine-specific default (backslashes restored; the published copy had
    # them stripped and "\t" turned into a TAB). Override with --cifar_dir=<path>.
    tf.app.flags.DEFINE_string("cifar_dir", r"C:\Users\Administrator\PycharmProjects\learntest\tenso\data\cifar10", "文件的目录")
    
    class CifarRead(object):
        """Read CIFAR-10 binary files into batched image/label tensors.

        Each record in a CIFAR-10 ``*.bin`` file is 1 label byte followed by
        3072 pixel bytes stored channel-major: 1024 red, 1024 green, then 1024
        blue values, each plane in row-major order
        (http://www.cs.toronto.edu/~kriz/cifar.html, "Binary version").
        """

        def __init__(self, filelist):
            """
            Args:
                filelist: list of paths to CIFAR-10 binary (``.bin``) files.
            """
            # Files fed to the input queue.
            self.filelist = filelist

            # CIFAR-10 geometry: 32x32 pixels, 3 colour channels, 1-byte label.
            self.height = 32
            self.width = 32
            # Backward-compatible alias for the original (misspelled) attribute name.
            self.weight = self.width
            self.channel = 3
            self.label_bytes = 1
            # Bytes per record in the binary file: label + pixels.
            self.bytes = self.height * self.width * self.channel + self.label_bytes

        def read_and_decode(self, batch_size=10, num_threads=1, capacity=10):
            """Build the queue-based input pipeline and return batched tensors.

            Args:
                batch_size: examples per batch (default 10).
                num_threads: threads enqueuing into the batch queue (default 1).
                capacity: maximum number of elements in the batch queue (default 10).

            Returns:
                (image_batch, label_batch): float32 images shaped
                [batch_size, height, width, channel] and int32 labels shaped
                [batch_size, label_bytes].
            """
            # 1. Queue of input file names.
            file_queue = tf.train.string_input_producer(self.filelist)

            # 2. Reader returning one fixed-size record (label + image) per read.
            reader = tf.FixedLengthRecordReader(self.bytes)
            _, value = reader.read(file_queue)

            # 3. Decode the raw record into a 1-D uint8 tensor.
            label_image = tf.decode_raw(value, tf.uint8)

            # 4. Split label from pixels: labels are small integers -> int32;
            #    pixels are used in computation -> float32.
            label = tf.cast(tf.slice(label_image, [0], [self.label_bytes]), tf.int32)
            image = tf.cast(
                tf.slice(label_image, [self.label_bytes], [self.bytes - self.label_bytes]),
                tf.float32,
            )

            # 5. Pixel bytes are stored channel-major, so reshape to
            #    [channel, height, width] first and then transpose to
            #    [height, width, channel]. Reshaping straight to [H, W, C]
            #    would interleave the colour planes and scramble the image.
            image_chw = tf.reshape(image, [self.channel, self.height, self.width])
            image_reshape = tf.transpose(image_chw, [1, 2, 0])

            # 6. Batch examples.
            image_batch, label_batch = tf.train.batch(
                [image_reshape, label],
                batch_size=batch_size,
                num_threads=num_threads,
                capacity=capacity,
            )

            return image_batch, label_batch
    
    if __name__ == "__main__":
        # Collect the CIFAR-10 binary files (*.bin) in the data directory.
        # Sorted so the file list is deterministic (os.listdir order is arbitrary).
        filelist = [
            os.path.join(FLAGS.cifar_dir, name)
            for name in sorted(os.listdir(FLAGS.cifar_dir))
            if name.endswith(".bin")
        ]
        if not filelist:
            raise SystemExit("No .bin files found in %s" % FLAGS.cifar_dir)

        cf = CifarRead(filelist)
        image_batch, label_batch = cf.read_and_decode()

        with tf.Session() as sess:
            # Coordinator used to signal the queue-runner threads to stop.
            coord = tf.train.Coordinator()

            # Start the threads that fill the file and batch queues.
            threads = tf.train.start_queue_runners(sess, coord=coord, start=True)

            try:
                # Fetch and print one batch.
                print(sess.run([image_batch, label_batch]))
            finally:
                # Always stop and join the reader threads, even if sess.run fails,
                # so they are not left running when the session closes.
                coord.request_stop()
                coord.join(threads)
  • 相关阅读:
    Palindrome Linked List 解答
    Word Break II 解答
    Array vs Linked List
    Reverse Linked List II 解答
    Calculate Number Of Islands And Lakes 解答
    Sqrt(x) 解答
    Find Median from Data Stream 解答
    Majority Element II 解答
    Binary Search Tree DFS Template
    188. Best Time to Buy and Sell Stock IV
  • 原文地址:https://www.cnblogs.com/kogmaw/p/12597980.html
Copyright © 2011-2022 走看看