zoukankan      html  css  js  c++  java
  • 二月五号博客

    今天学了TensorFlow文件读取操作

    一,读取图片文件

    def read_picture():
        """
        读取狗图片案例
        :return:
        """
        # 1、构造文件名队列
        # 构造文件名列表
        filename_list = os.listdir("./dog")
        # 给文件名加上路径
        file_list = [os.path.join("./dog/", i) for i in filename_list]
        # print("file_list:
    ", file_list)
        # print("filename_list:
    ", filename_list)
        file_queue = tf.train.string_input_producer(file_list)
    
        # 2、读取与解码
        # 读取
        reader = tf.WholeFileReader()
        key, value = reader.read(file_queue)
        print("key:
    ", key)
        print("value:
    ", value)
    
        # 解码
        image_decoded = tf.image.decode_jpeg(value)
        print("image_decoded:
    ", image_decoded)
    
        # 将图片缩放到同一个大小
        image_resized = tf.image.resize_images(image_decoded, [200, 200])
        print("image_resized_before:
    ", image_resized)
        # 更新静态形状
        image_resized.set_shape([200, 200, 3])
        print("image_resized_after:
    ", image_resized)
    
    
        # 3、批处理队列
        image_batch = tf.train.batch([image_resized], batch_size=100, num_threads=2, capacity=100)
        print("image_batch:
    ", image_batch)
    
        # 开启会话
        with tf.Session() as sess:
            # 开启线程
            # 构造线程协调器
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    
            # 运行
            filename, sample, image, n_image = sess.run([key, value, image_resized, image_batch])
            print("filename:
    ", filename)
            print("sample:
    ", sample)
            print("image:
    ", image)
            print("n_image:
    ", n_image)
    
            coord.request_stop()
            coord.join(threads)
    
    
        return None

    二,读取二进制文件

    class Cifar():
    
        def __init__(self):
    
            # 设置图像大小
            self.height = 32
            self.width = 32
            self.channel = 3
    
            # 设置图像字节数
            self.image = self.height * self.width * self.channel
            self.label = 1
            self.sample = self.image + self.label
    
    
        def read_binary(self):
            """
            读取二进制文件
            :return:
            """
            # 1、构造文件名队列
            filename_list = os.listdir("./cifar-10-batches-bin")
            # print("filename_list:
    ", filename_list)
            file_list = [os.path.join("./cifar-10-batches-bin/", i) for i in filename_list if i[-3:]=="bin"]
            # print("file_list:
    ", file_list)
            file_queue = tf.train.string_input_producer(file_list)
    
            # 2、读取与解码
            # 读取
            reader = tf.FixedLengthRecordReader(self.sample)
            # key文件名 value样本
            key, value = reader.read(file_queue)
    
            # 解码
            image_decoded = tf.decode_raw(value, tf.uint8)
            print("image_decoded:
    ", image_decoded)
    
            # 切片操作
            label = tf.slice(image_decoded, [0], [self.label])
            image = tf.slice(image_decoded, [self.label], [self.image])
            print("label:
    ", label)
            print("image:
    ", image)
    
            # 调整图像的形状
            image_reshaped = tf.reshape(image, [self.channel, self.height, self.width])
            print("image_reshaped:
    ", image_reshaped)
    
            # 三维数组的转置
            image_transposed = tf.transpose(image_reshaped, [1, 2, 0])
            print("image_transposed:
    ", image_transposed)
    
            # 3、构造批处理队列
            image_batch, label_batch = tf.train.batch([image_transposed, label], batch_size=100, num_threads=2, capacity=100)
    
            # 开启会话
            with tf.Session() as sess:
    
                # 开启线程
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    
                label_value, image_value = sess.run([label_batch, image_batch])
                print("label_value:
    ", label_value)
                print("image:
    ", image_value)
    
                coord.request_stop()
                coord.join(threads)
    
            return image_value, label_value

    三,读取TFRecords文件

        def read_tfrecords(self):
            """
            读取TFRecords文件
            :return:
            """
            # 1、构造文件名队列
            file_queue = tf.train.string_input_producer(["cifar10.tfrecords"])
    
            # 2、读取与解码
            # 读取
            reader = tf.TFRecordReader()
            key, value = reader.read(file_queue)
    
            # 解析example
            feature = tf.parse_single_example(value, features={
                "image": tf.FixedLenFeature([], tf.string),
                "label": tf.FixedLenFeature([], tf.int64)
            })
            image = feature["image"]
            label = feature["label"]
            print("read_tf_image:
    ", image)
            print("read_tf_label:
    ", label)
    
            # 解码
            image_decoded = tf.decode_raw(image, tf.uint8)
            print("image_decoded:
    ", image_decoded)
            # 图像形状调整
            image_reshaped = tf.reshape(image_decoded, [self.height, self.width, self.channel])
            print("image_reshaped:
    ", image_reshaped)
    
            # 3、构造批处理队列
            image_batch, label_batch = tf.train.batch([image_reshaped, label], batch_size=100, num_threads=2, capacity=100)
            print("image_batch:
    ", image_batch)
            print("label_batch:
    ", label_batch)
    
            # 开启会话
            with tf.Session() as sess:
    
                # 开启线程
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    
                image_value, label_value = sess.run([image_batch, label_batch])
                print("image_value:
    ", image_value)
                print("label_value:
    ", label_value)
    
                # 回收资源
                coord.request_stop()
                coord.join(threads)
    
            return None
  • 相关阅读:
    MongDB简单介绍
    Docker的简单介绍
    maven简单介绍
    粗谈Springboot框架,众所周知Springboot是有spring推出的微服务框架,什么是微服务框架呢!
    Springboot打包问题,打包的话是通过
    SpringBoot注解及swagger注解使用及规范
    properties配置
    日志配置
    c++几个面试题
    c++四种强制类型转化的区别
  • 原文地址:https://www.cnblogs.com/goubb/p/12267256.html
Copyright © 2011-2022 走看看