zoukankan      html  css  js  c++  java
  • tensorflow 使用tfrecords创建自己数据集

    直接采用矩阵方式建立数据集见:https://www.cnblogs.com/WSX1994/p/10128338.html

    制作自己的数据集(使用tfrecords)

    为什么采用这个格式?

    TFRecords文件格式在图像识别中有很好的使用,其可以将二进制数据和标签数据(训练的类别标签)数据存储在同一个文件中,它可以在模型进行训练之前通过预处理步骤将图像转换为TFRecords格式,此格式最大的优点实践每幅输入图像和与之关联的标签放在同一个文件中.TFRecords文件是一种二进制文件,其不对数据进行压缩,所以可以被快速加载到内存中.格式不支持随机访问,因此它适合于大量的数据流,但不适用于快速分片或其他非连续存取。

    前戏:

    tf.train.Feature
    tf.train.Feature有三个属性为tf.train.bytes_list    tf.train.float_list    tf.train.int64_list,显然我们只需要根据上一步得到的值来设置tf.train.Feature的属性就可以了,如下所示:

    1 tf.train.Feature(int64_list=data_id)
    2 tf.train.Feature(bytes_list=data)

    tf.train.Features
    从名字来看,我们应该能猜出tf.train.Features是tf.train.Feature的复数,事实上tf.train.Features有属性为feature,这个属性的一般设置方法是传入一个字典,字典的key是字符串(feature名),而值是tf.train.Feature对象。因此,我们可以这样得到tf.train.Features对象:

    1 feature_dict = {
    2 "data_id": tf.train.Feature(int64_list=data_id),
    3 "data": tf.train.Feature(bytes_list=data)
    4 }
    5 features = tf.train.Features(feature=feature_dict)

    tf.train.Example
    终于到我们的主角了。tf.train.Example有一个属性为features,我们只需要将上一步得到的结果再次当做参数传进来即可。
    另外,tf.train.Example还有一个方法SerializeToString()需要说一下,这个方法的作用是把tf.train.Example对象序列化为字符串,因为我们写入文件的时候不能直接处理对象,需要将其转化为字符串才能处理。
    当然,既然有对象序列化为字符串的方法,那么肯定有从字符串反序列化到对象的方法,该方法是FromString(),需要传递一个tf.train.Example对象序列化后的字符串进去做为参数才能得到反序列化的对象。
    在我们这里,只需要构建tf.train.Example对象并序列化就可以了,这一步的代码为:

    1 example = tf.train.Example(features=features)
    2 example_str = example.SerializeToString()

    实例(高潮部分):

    首先看一下我们的文件夹路径:

    create_tfrecords.py中写我们的函数

    生成数据文件阶段代码如下:

     1 def creat_tf(imgpath):
     2     cwd = os.getcwd()  #获取当前路径
     3     classes = os.listdir(cwd + imgpath)  #获取到[1, 2]文件夹
     4     # 此处定义tfrecords文件存放
     5     writer = tf.python_io.TFRecordWriter("train.tfrecords")
     6     for index, name in enumerate(classes):   #循环获取俩文件夹(俩类别)
     7         class_path = cwd + imgpath + name + "/"
     8         if os.path.isdir(class_path):
     9             for img_name in os.listdir(class_path):
    10                 img_path = class_path + img_name
    11                 img = Image.open(img_path)
    12                 img = img.resize((224, 224))
    13                 img_raw = img.tobytes()
    14                 example = tf.train.Example(features=tf.train.Features(feature={
    15                     'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[int(name)])),
    16                     'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
    17                 }))
    18                 writer.write(example.SerializeToString())
    19                 print(img_name)
    20     writer.close()

    这段代码主要生成  train.tfrecords 文件。

    读取数据阶段代码如下:

     1 def read_and_decode(filename):
     2     # 根据文件名生成一个队列
     3     filename_queue = tf.train.string_input_producer([filename])
     4 
     5     reader = tf.TFRecordReader()
     6     _, serialized_example = reader.read(filename_queue)  # 返回文件名和文件
     7     features = tf.parse_single_example(serialized_example,
     8                                        features={
     9                                            'label': tf.FixedLenFeature([], tf.int64),
    10                                            'img_raw': tf.FixedLenFeature([], tf.string),
    11                                        })
    12 
    13     img = tf.decode_raw(features['img_raw'], tf.uint8)
    14     img = tf.reshape(img, [224, 224, 3])
    15     # 转换为float32类型,并做归一化处理
    16     img = tf.cast(img, tf.float32)  # * (1. / 255)
    17     label = tf.cast(features['label'], tf.int64)
    18     return img, label

    训练阶段我们获取数据的代码:

     1 images, labels = read_and_decode('./train.tfrecords')
     2 img_batch, label_batch = tf.train.shuffle_batch([images, labels],
     3                                                 batch_size=5,
     4                                                 capacity=392,
     5                                                 min_after_dequeue=200)
     6 init = tf.global_variables_initializer()
     7 with tf.Session() as sess:
     8     sess.run(init)
     9     coord = tf.train.Coordinator()  #线程协调器
    10     threads = tf.train.start_queue_runners(sess=sess,coord=coord)
    11     # 训练部分代码--------------------------------
    12     IMG, LAB = sess.run([img_batch, label_batch])
    13     print(IMG.shape)
    14 
    15     #----------------------------------------------
    16     coord.request_stop()  # 协调器coord发出所有线程终止信号
    17     coord.join(threads) #把开启的线程加入主线程,等待threads结束

    总结(流程):

    1. 生成tfrecord文件
    2. 定义record reader解析tfrecord文件
    3. 构造一个批生成器(batcher
    4. 构建其他的操作
    5. 初始化所有的操作
    6. 启动QueueRunner

    备注:关于tf.train.Coordinator 详见:

    https://blog.csdn.net/dcrmg/article/details/79780331

    TensorFlow的Session对象是支持多线程的,可以在同一个会话(Session)中创建多个线程,并行执行。在Session中的所有线程都必须能被同步终止,异常必须能被正确捕获并报告,会话终止的时候, 队列必须能被正确地关闭。

    1. 调用 tf.train.slice_input_producer,从 本地文件里抽取tensor,准备放入Filename Queue(文件名队列)中;
    2. 调用 tf.train.batch,从文件名队列中提取tensor,使用单个或多个线程,准备放入文件队列;
    3. 调用 tf.train.Coordinator() 来创建一个线程协调器,用来管理之后在Session中启动的所有线程;
    4. 调用tf.train.start_queue_runners, 启动入队线程,由多个或单个线程,按照设定规则,把文件读入Filename Queue中。函数返回线程ID的列表,一般情况下,系统有多少个核,就会启动多少个入队线程(入队具体使用多少个线程在tf.train.batch中定义);
    5. 文件从 Filename Queue中读入内存队列的操作不用手动执行,由tf自动完成;
    6. 调用sess.run 来启动数据出列和执行计算;
    7. 使用 coord.should_stop()来查询是否应该终止所有线程,当文件队列(queue)中的所有文件都已经读取出列的时候,会抛出一个 OutofRangeError 的异常,这时候就应该停止Sesson中的所有线程了;
    8. 使用coord.request_stop()来发出终止所有线程的命令,使用coord.join(threads)把线程加入主线程,等待threads结束。
  • 相关阅读:
    MINIX文件系统
    Cmd Markdown 语法
    asp.net mvc 4 json大数据异常 提示JSON字符长度超出限制的异常[转载]
    echart 拖拽
    搭建django开发环境
    Django 1.11.7+django_pyodbc_azure-1.11.0.0+pyodbc 连接mssql 数据库
    二、PyCharm 创建Django 第一个项目
    一、Django 安装
    python 连接各类主流数据库简单示例【转载】
    Python 3.6 连接mssql数据库(pymssql 方式)
  • 原文地址:https://www.cnblogs.com/WSX1994/p/10954925.html
Copyright © 2011-2022 走看看