zoukankan      html  css  js  c++  java
  • 生成tfrecords格式数据和使用dataset API使用tfrecords数据

    TFRecords是TensorFlow中的设计的一种内置的文件格式,它是一种二进制文件,优点有如下几种:

    • 统一不同输入文件的框架
    • 它是更好的利用内存,更方便复制和移动(TFRecord压缩的二进制文件, protocal buffer序列化)
    • 是用于将二进制数据和标签(训练的类别标签)数据存储在同一个文件中

    一、将其他数据存储为TFRecords文件的时候,需要经过两个步骤:

    建立TFRecord存储器

      在tensorflow中使用下面语句来简历tfrecord存储器:

    tf.python_io.TFRecordWriter(path)

    path : 创建的TFRecords文件的路径

    方法: 

    • write(record):向文件中写入一个字符串记录(即一个样本)
    • close() : 在写入所有文件后,关闭文件写入器。

    注:此处的字符串为一个序列化的Example,通过Example.SerializeToString()来实现,它的作用是将Example中的map压缩为二进制,节约大量空间。

    构造每个样本的Example模块

    Example模块的定义如下:

    message Example {
      Features features = 1;
    };
    
    message Features {
      map<string, Feature> feature = 1;
    };
    
    message Feature {
      oneof kind {
        BytesList bytes_list = 1;
        FloatList float_list = 2;
        Int64List int64_list = 3;
      }
    };

    可以看到,Example中可以包括三种格式的数据:tf.int64,tf.float32和二进制类型。

    features是以键值对的形式保存的。示例代码如下:

    example = tf.train.Example(
                features=tf.train.Features(feature={
                    "label": tf.train.Feature(float_list=tf.train.FloatList(value=[string[1]])),
                    'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
                    'x1_offset':tf.train.Feature(float_list=tf.train.FloatList(value=[string[2]])),
                    'y1_offset': tf.train.Feature(float_list=tf.train.FloatList(value=[string[3]])),
                    'x2_offset': tf.train.Feature(float_list=tf.train.FloatList(value=[string[4]])),
                    'y2_offset': tf.train.Feature(float_list=tf.train.FloatList(value=[string[5]])),
                    'beta_det':tf.train.Feature(float_list=tf.train.FloatList(value=[string[6]])),
                    'beta_bbox':tf.train.Feature(float_list=tf.train.FloatList(value=[string[7]]))
                }))

    构造好了Example模块后,我们就可以将样本写入文件了:

    writer.write(example.SerializeToString())

    文件全部写入后不要忘记关闭文件写入器。

    二、创建好我们自己的tfrecords文件后,我们就可以在训练的时候使用它啦。tensorflow为我们提供了Dataset这个API以方便地使用tfrecords文件。

    首先,我们要定义一个解析tfrecords的函数,它用来将二进制文件解析为张量。示例代码如下:

    def pares_tf(example_proto):
        #定义解析的字典
        dics = {
            'label': tf.FixedLenFeature([], tf.float32),
            'img_raw': tf.FixedLenFeature([], tf.string),
            'x1_offset': tf.FixedLenFeature([], tf.float32),
            'y1_offset': tf.FixedLenFeature([], tf.float32),
            'x2_offset': tf.FixedLenFeature([], tf.float32),
            'y2_offset': tf.FixedLenFeature([], tf.float32),
            'beta_det': tf.FixedLenFeature([], tf.float32),
            'beta_bbox': tf.FixedLenFeature([], tf.float32)}
        #调用接口解析一行样本
        parsed_example = tf.parse_single_example(serialized=example_proto,features=dics)
        image = tf.decode_raw(parsed_example['img_raw'],out_type=tf.uint8)
        image = tf.reshape(image,shape=[12,12,3])
        #这里对图像数据做归一化
        image = (tf.cast(image,tf.float32)/255.0)
        label = parsed_example['label']
        label=tf.reshape(label,shape=[1])
        label = tf.cast(label,tf.float32)
        x1_offset=parsed_example['x1_offset']
        x1_offset = tf.reshape(x1_offset, shape=[1])
        y1_offset=parsed_example['y1_offset']
        y1_offset = tf.reshape(y1_offset, shape=[1])
        x2_offset=parsed_example['x2_offset']
        x2_offset = tf.reshape(x2_offset, shape=[1])
        y2_offset=parsed_example['y2_offset']
        y2_offset = tf.reshape(y2_offset, shape=[1])
        beta_det=parsed_example['beta_det']
        beta_det=tf.reshape(beta_det,shape=[1])
        beta_bbox=parsed_example['beta_bbox']
        beta_bbox=tf.reshape(beta_bbox,shape=[1])
    
        return image,label,x1_offset,y1_offset,x2_offset,y2_offset,beta_det,beta_bbox

    接下来,我们需要使用tf.data.TFRecordDataset(filenames)读入tfrecords文件。

    一个Dataset通过Transformation变成一个新的Dataset。通常我们可以通过Transformation完成数据变换,打乱,组成batch,生成epoch等一系列操作。

    常用的Transformation有:map、batch、shuffle、repeat。

    map: 

      map接收一个函数,Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的Dataset

    batch:

      batch就是将多个元素组合成batch

    repeat:

      repeat的功能就是将整个序列重复多次,主要用来处理机器学习中的epoch,假设原先的数据是一个epoch,使用repeat(5)就可以将之变成5个epoch

    shuffle:

      shuffle的功能为打乱dataset中的元素,它有一个参数buffersize,表示打乱时使用的大小。

    示例代码:

    dataset = tf.data.TFRecordDataset(filenames=[filename])
    dataset = dataset.map(pares_tf)
    dataset = dataset.batch(16).repeat(1)#整个序列只使用一次,每次使用16个样本组成一个批次

    现在这一个批次的样本做好了,如何将它取出以用于训练呢?答案是使用迭代器,在tensorflow中的语句如下:

    iterator = dataset.make_one_shot_iterator()

    所谓one_shot意味着只能从头到尾读取一次,那如何在每一个训练轮次中取出不同的样本呢?iterator的get_netxt()方法可以实现这一点。需要注意的是,这里使用get_next()得到的只是一个tensor,并不是一个具体的值,在训练的时候要使用这个值的话,我们需要在session里面来取得。

    使用dataset读取tfrecords文件的完整代码如下:

    def pares_tf(example_proto):
        #定义解析的字典
        dics = {
            'label': tf.FixedLenFeature([], tf.float32),
            'img_raw': tf.FixedLenFeature([], tf.string),
            'x1_offset': tf.FixedLenFeature([], tf.float32),
            'y1_offset': tf.FixedLenFeature([], tf.float32),
            'x2_offset': tf.FixedLenFeature([], tf.float32),
            'y2_offset': tf.FixedLenFeature([], tf.float32),
            'beta_det': tf.FixedLenFeature([], tf.float32),
            'beta_bbox': tf.FixedLenFeature([], tf.float32)}
        #调用接口解析一行样本
        parsed_example = tf.parse_single_example(serialized=example_proto,features=dics)
        image = tf.decode_raw(parsed_example['img_raw'],out_type=tf.uint8)
        image = tf.reshape(image,shape=[12,12,3])
        #这里对图像数据做归一化
        image = (tf.cast(image,tf.float32)/255.0)
        label = parsed_example['label']
        label=tf.reshape(label,shape=[1])
        label = tf.cast(label,tf.float32)
        x1_offset=parsed_example['x1_offset']
        x1_offset = tf.reshape(x1_offset, shape=[1])
        y1_offset=parsed_example['y1_offset']
        y1_offset = tf.reshape(y1_offset, shape=[1])
        x2_offset=parsed_example['x2_offset']
        x2_offset = tf.reshape(x2_offset, shape=[1])
        y2_offset=parsed_example['y2_offset']
        y2_offset = tf.reshape(y2_offset, shape=[1])
        beta_det=parsed_example['beta_det']
        beta_det=tf.reshape(beta_det,shape=[1])
        beta_bbox=parsed_example['beta_bbox']
        beta_bbox=tf.reshape(beta_bbox,shape=[1])
    
        return image,label,x1_offset,y1_offset,x2_offset,y2_offset,beta_det,beta_bbox
    
    dataset = tf.data.TFRecordDataset(filenames=[filename])
    dataset = dataset.map(pares_tf)
    dataset = dataset.batch(16).repeat(1)
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    
    with tf.Session() as sess:
            
        img, label, x1_offset, y1_offset, x2_offset, y2_offset, beta_det, beta_bbox = sess.run(fetches=next_element)

     

  • 相关阅读:
    mac 通过SSH连接服务器aws和github
    Android开发 View与Activity的生命周期[转载]
    Android开发 APP闪退Fragment重叠泄露问题
    3月3日 一堆一堆事
    杭州.net俱乐部 新开qq群
    招聘 .net 开发工程师
    852009
    872009
    01背包和完全背包
    8142009
  • 原文地址:https://www.cnblogs.com/puheng/p/9576521.html
Copyright © 2011-2022 走看看