zoukankan      html  css  js  c++  java
  • 越来越清晰的TFRecord处理图片的步骤

    # 首先是模块的导入
    """
    os模块是处理文件夹用的
    PIL模块是用来处理图片的
    """
    import tensorflow as tf
    import os
    from PIL import Image
    
    path = "tensorflow_application/jpg"  # 这是上述文件结构的主文件夹路径
    filename = os.listdir(path)  # 作用是遍历path文件夹下的文件,返回的是001和002文件夹构成的一个列表
    writer = tf.python_io.TFRecordWriter("tensorflow_application/train.tfrecords")  # 将TFRecordWriter实例化,用于文件的写操作。其中的路径是tfrecords文件的存放路径,这个路径并不需要实现建立,代码会自动生成
    
    for name in filename:
        class_path = path + os.sep + name  # 得到每一类的路径,即001文件夹和002文件夹的路径,其中的os.sep返回的是一个符号,即'//',这是路径中的一个符号而已,起到连接作用,构成此文件夹的完整路径
        for img_name in os.listdir(class_path):
            img_path = class_path + os.sep + img_name  # 同上,得到此文件夹下的每一张图片的完整路径,用于后续的图片提取并处理
            img = Image.open(img_path)  # 取出图片
            img = img.resize((500, 500))  # 改变图片大小,大小视具体的网络要求而定,不同的网络对输入图片的大小并不完全相同。这里我暂且将图片变为500*500的大小
            img_raw = img.tobytes()  # 这里将图片矩阵变为字符串形式进行存储,因为TFRecords能够保存的只能是二进制数据,因此需要将数组转换为二进制形式
            # 下面是关键的步骤,将数据填入到Example协议内存块中,最终生成TFRecords文件。TFRecords文件就是通过一个包含着二进制文件的数据文件,将特征和标签进行保存便于TensorFlow读取
            """
            一个tf.train.Example,即Example协议内存块,包含着若干数据特征(Features),而Features
            中又包含着Feature字典。任何一个Feature中又包含着FloatList, Int64List或BytesList,本例
            中使用到了其中两种数据格式,即Int64List和BytesList,需要注意的是value后跟的值需要为
            列表形式,所以加上了方括号
            """
            example = tf.train.Example(
                features = tf.train.Features(
                    feature={
                        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[name])),
                        "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))))
                        }
            serialized = example.SerializeToString()  # 先将样本进行序列化操作
            writer.write(serialized)  # 对序列化操作后的变量进行写操作,即生成最终的tfrecords文件
    

      接下来需要做的便是读取生成的tfrecords文件,在神经网络中,需要将tfrecords文件中的image和label读取出来,然后将其传递给图。

    # 使用的模块还是tensorflow
    import tensorflow as tf
    
    filename = "tensorflow_application/train.tfrecords"  # 这是上面生成的tfrecords文件
    filename_queue = tf.train.string_input_producer([filenname])  # 建立一个队列,其中的参数为tfrecords文件的路径
    
    reader = tf.TFRecordReader()  # 实例化读操作,建立读取器
    _, serialized_example = reader.read(filename_queue)  # 返回文件名和文件
    """
    通过parse_single_example解析器解析,将Example协议内存块解析为张量(Tensor),然后使用
    解码器tf.decode_raw解码
    """
    features = tf.parse_single_example(serialized_example, 
        features={
            "label": tf.FixedLenFeature([], tf.int64),
            "image": tf.FixedLenFeature([], tf.string)
            })
    img = tf.decode_raw(features["image"], tf.uint8)  # 使用tf.decode_raw解码
    img = tf.reshape(img, [500, 500, 3])  # 重构图片的大小为500*500*3
    
    img = tf.cast(img, tf.float32) * (1. / 128) - 0.5
    label = tf.cast(features["label"], tf.int32)
    
    """
    上面将img和label从tfrecords文件中读取了出来,但是如果需要将数据取出供
    图使用,还需要使用tf.train.shuffle_batch
    shuffle_batch的主要参数为:
    1. tensor: 入队队列,即上面得到的img和label,[img, label]
    2. batch_size: batch的大小
    3. capacity: 队列的最大容量
    4. num_threads: 线程数
    5. min_after_dequeue: 限制出队时队列中元素的最小个数
    """
    img_batch, label_batch = tf.train.shuffle_batch([img, label], batch_size=1,
                                                    capacity=24, min_after_dequeue=1)  # 将得到的img_batch, label_batch传递给需要进行递归的数据即可
    

      原文链接:https://blog.csdn.net/cl2227619761/article/details/80107208

  • 相关阅读:
    windows下命令行
    利用border画三角形
    正则
    flex布局
    css笔记
    W3C标准
    SEO相关
    左边固定,右边自适应(解决方案)
    容错性测试的测试点
    Charles安装及使用教程
  • 原文地址:https://www.cnblogs.com/lzq116/p/12030087.html
Copyright © 2011-2022 走看看