zoukankan      html  css  js  c++  java
  • TensorFlow读写数据

    前言

    只有光头才能变强。

    文本已收录至我的GitHub仓库,欢迎Star:https://github.com/ZhongFuCheng3y/3y

    回顾前面:

    众所周知,要训练出一个模型,首先我们得有数据。我们第一个例子中,直接使用dataset的api去加载mnist的数据。(minst的数据要么我们是提前下载好,放在对应的目录上,要么就根据他给的url直接从网上下载)。

    一般来说,我们使用TensorFlow是从TFRecord文件中读取数据的。

    TFRecord 文件格式是一种面向记录的简单二进制格式,很多 TensorFlow 应用采用此格式来训练数据

    所以,这篇文章来聊聊怎么读取TFRecord文件的数据。

    一、入门对数据集的数据进行读和写

    首先,我们来体验一下怎么造一个TFRecord文件,怎么从TFRecord文件中读取数据,遍历(消费)这些数据。

    1.1 造一个TFRecord文件

    现在,我们还没有TFRecord文件,我们可以自己简单写一个:

    def write_sample_to_tfrecord():
        gmv_values = np.arange(10)
        click_values = np.arange(10)
        label_values = np.arange(10)
    
        with tf.python_io.TFRecordWriter("/Users/zhongfucheng/data/fashin/demo.tfrecord", options=None) as writer:
            for _ in range(10):
                feature_internal = {
                    "gmv": tf.train.Feature(float_list=tf.train.FloatList(value=[gmv_values[_]])),
                    "click": tf.train.Feature(int64_list=tf.train.Int64List(value=[click_values[_]])),
                    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label_values[_]]))
                }
                features_extern = tf.train.Features(feature=feature_internal)
    
                # 使用tf.train.Example将features编码数据封装成特定的PB协议格式
                # example = tf.train.Example(features=tf.train.Features(feature=features_extern))
                example = tf.train.Example(features=features_extern)
    
                # 将example数据系列化为字符串
                example_str = example.SerializeToString()
    
                # 将系列化为字符串的example数据写入协议缓冲区
                writer.write(example_str)
    
    
    if __name__ == '__main__':
        write_sample_to_tfrecord()
    

    我相信大家代码应该是能够看得懂的,其实就是分了几步:

    • 生成TFRecord Writer
    • tf.train.Feature生成协议信息
    • 使用tf.train.Example将features编码数据封装成特定的PB协议格式
    • 将example数据系列化为字符串
    • 将系列化为字符串的example数据写入协议缓冲区

    参考资料:

    ok,现在我们就有了一个TFRecord文件啦。

    1.2 读取TFRecord文件

    • 其实就是通过tf.data.TFRecordDataset这个api来读取到TFRecord文件,生成处dataset对象

    • 对dataset进行处理(shape处理,格式处理...等等)

    • 使用迭代器对dataset进行消费(遍历)

    demo代码如下:

    import tensorflow as tf
    
    
    def read_tensorflow_tfrecord_files():
        # 定义消费缓冲区协议的parser,作为dataset.map()方法中传入的lambda:
        def _parse_function(single_sample):
            features = {
                "gmv": tf.FixedLenFeature([1], tf.float32),
                "click": tf.FixedLenFeature([1], tf.int64),  # ()或者[]没啥影响
                "label": tf.FixedLenFeature([1], tf.int64)
            }
            parsed_features = tf.parse_single_example(single_sample, features=features)
    
            # 对parsed 之后的值进行cast.
            gmv = tf.cast(parsed_features["gmv"], tf.float64)
            click = tf.cast(parsed_features["click"], tf.float64)
            label = tf.cast(parsed_features["label"], tf.float64)
    
            return gmv, click, label
    
        # 开始定义dataset以及解析tfrecord格式
        filenames = tf.placeholder(tf.string, shape=[None])
    
        # 定义dataset 和 一些列trasformation method
        dataset = tf.data.TFRecordDataset(filenames)
        parsed_dataset = dataset.map(_parse_function)  # 消费缓冲区需要定义在dataset 的map 函数中
        batchd_dataset = parsed_dataset.batch(3)
    
        # 创建Iterator
        sample_iter = batchd_dataset.make_initializable_iterator()
        # 获取next_sample
        gmv, click, label = sample_iter.get_next()
        training_filenames = [
            "/Users/zhongfucheng/data/fashin/demo.tfrecord"]
        with tf.Session() as session:
            # 初始化带参数的Iterator
            session.run(sample_iter.initializer, feed_dict={filenames: training_filenames})
            # 读取文件
            print(session.run(gmv))
    
    
    if __name__ == '__main__':
        read_tensorflow_tfrecord_files()
    
    

    无意外的话,我们可以输出这样的结果:

    [[0.]
     [1.]
     [2.]]
    

    ok,现在我们已经大概知道怎么写一个TFRecord文件,以及怎么读取TFRecord文件的数据,并且消费这些数据了。

    二、epoch和batchSize术语解释

    我在学习TensorFlow翻阅资料时,经常看到一些机器学习的术语,由于自己没啥机器学习的基础,所以很多时候看到一些专业名词就开始懵逼了。

    2.1epoch

    当一个完整的数据集通过了神经网络一次并且返回了一次,这个过程称为一个epoch

    这可能使我们跟dataset.repeat()方法联系起来,这个方法可以使当前数据集重复一遍。比如说,原有的数据集是[1,2,3,4,5],如果我调用dataset.repeat(2)的话,那么我们的数据集就变成了[1,2,3,4,5],[1,2,3,4,5]

    • 所以会有个说法:假设原先的数据是一个epoch,使用repeat(5)就可以将之变成5个epoch

    2.2batchSize

    一般来说我们的数据集都是比较大的,无法一次性将整个数据集的数据喂进神经网络中,所以我们会将数据集分成好几个部分。每次喂多少条样本进神经网络,这个叫做batchSize。

    在TensorFlow也提供了方法给我们设置:dataset.batch(),在API中是这样介绍batchSize的:

    representing the number of consecutive elements of this dataset to combine in a single batch
    

    我们一般在每次训练之前,会将整个数据集的顺序打乱,提高我们模型训练的效果。这里我们用到的api是:dataset.shffle();

    三、再来聊聊dataset

    我从官网的介绍中截了一个dataset的方法图(部分):

    dataset的方法图

    dataset的功能主要有以下三种:

    • 创建dataset实例
      • 通过文件创建(比如TFRecord)
      • 通过内存创建
    • 对数据集的数据进行变换
      • 比如上面的batch(),常见的map(),flat_map(),zip(),repeat()等等
      • 文档中一般都有给出例子,跑一下一般就知道对应的意思了。
    • 创建迭代器,遍历数据集的数据

    3.1 聊聊迭代器

    迭代器可以分为四种:

    • 单次。对数据集进行一次迭代,不支持参数化
    • 可初始化迭代
      • 使用前需要进行初始化,支持传入参数。面向的是同一个DataSet
    • 可重新初始化:同一个Iterator从不同的DataSet中读取数据
      • DataSet的对象具有相同的结构,可以使用tf.data.Iterator.from_structure来进行初始化
      • 问题:每次 Iterator 切换时,数据都从头开始打印了
    • 可馈送(也是通过对象相同的结果来创建的迭代器)
      • 可让您在两个数据集之间切换的可馈送迭代器
      • 通过一个string handler来实现。
      • 可馈送的 Iterator 在不同的 Iterator 切换的时候,可以做到不从头开始

    简单总结:

    • 1、 单次 Iterator ,它最简单,但无法重用,无法处理数据集参数化的要求。
    • 2、 可以初始化的 Iterator ,它可以满足 Dataset 重复加载数据,满足了参数化要求。
    • 3、可重新初始化的 Iterator,它可以对接不同的 Dataset,也就是可以从不同的 Dataset 中读取数据。
    • 4、可馈送的 Iterator,它可以通过 feeding 的方式,让程序在运行时候选择正确的 Iterator,它和可重新初始化的 Iterator 不同的地方就是它的数据在不同的 Iterator 切换时,可以做到不重头开始读取数据

    string handler(可馈送的 Iterator)这种方式是最常使用的,我当时也写了一个Demo来使用了一下,代码如下:

    def read_tensorflow_tfrecord_files():
        # 开始定义dataset以及解析tfrecord格式.
        train_filenames = tf.placeholder(tf.string, shape=[None])
        vali_filenames = tf.placeholder(tf.string, shape=[None])
    
        # 加载train_dataset   batch_inputs这个方法每个人都不一样的,这个方法我就不给了。
        train_dataset = batch_inputs([
            train_filenames], batch_size=5, type=False,
            num_epochs=2, num_preprocess_threads=3)
        # 加载validation_dataset  batch_inputs这个方法每个人都不一样的,这个方法我就不给了。
        validation_dataset = batch_inputs([vali_filenames
                                           ], batch_size=5, type=False,
                                          num_epochs=2, num_preprocess_threads=3)
    
        # 创建出string_handler()的迭代器(通过相同数据结构的dataset来构建)
        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.data.Iterator.from_string_handle(
            handle, train_dataset.output_types, train_dataset.output_shapes)
    
        # 有了迭代器就可以调用next方法了。
        itemid = iterator.get_next()
    
        # 指定哪种具体的迭代器,有单次迭代的,有初始化的。
        training_iterator = train_dataset.make_initializable_iterator()
        validation_iterator = validation_dataset.make_initializable_iterator()
    
        # 定义出placeholder的值
        training_filenames = [
            "/Users/zhongfucheng/tfrecord_test/data01aa"]
        validation_filenames = ["/Users/zhongfucheng/tfrecord_validation/part-r-00766"]
    
        with tf.Session() as sess:
            # 初始化迭代器
            training_handle = sess.run(training_iterator.string_handle())
            validation_handle = sess.run(validation_iterator.string_handle())
    
            for _ in range(2):
                sess.run(training_iterator.initializer, feed_dict={train_filenames: training_filenames})
                print("this is training iterator ----")
    
                for _ in range(5):
                    print(sess.run(itemid, feed_dict={handle: training_handle}))
    
                sess.run(validation_iterator.initializer,
                         feed_dict={vali_filenames: validation_filenames})
    
                print("this is validation iterator ")
                for _ in range(5):
                    print(sess.run(itemid, feed_dict={vali_filenames: validation_filenames, handle: validation_handle}))
    
    
    if __name__ == '__main__':
        read_tensorflow_tfrecord_files()
    
    

    参考资料:

    3.2 dataset参考资料

    在翻阅资料时,发现写得不错的一些博客:

    最后

    乐于输出干货的Java技术公众号:Java3y。公众号内有200多篇原创技术文章、海量视频资源、精美脑图,不妨来关注一下!

    下一篇文章打算讲讲如何理解axis~

    帅的人都关注了

    觉得我的文章写得不错,不妨点一下

  • 相关阅读:
    FPM
    Docker记录
    阿里云ECS发送企业邮件
    git操作
    vscode+vagrant+xdebug调试
    Spring Security开发安全的REST服务
    559. Maximum Depth of N-ary Tree
    《算法图解》之散列表
    766. Toeplitz Matrix
    893. Groups of Special-Equivalent Strings
  • 原文地址:https://www.cnblogs.com/Java3y/p/10543426.html
Copyright © 2011-2022 走看看