zoukankan      html  css  js  c++  java
  • tf.data

    以往的TensorFLow模型数据的导入方法可以分为两个主要方法,一种是使用feed_dict另外一种是使用TensorFlow中的Queues。前者使用起来比较灵活,可以利用Python处理各种输入数据,劣势也比较明显,就是程序运行效率较低;后面一种方法的效率较高,但是使用起来较为复杂,灵活性较差。

    Dataset作为新的API,比以上两种方法的速度都快,并且使用难度要远远低于使用Queues。tf.data中包含了两个用于TensorFLow程序的接口:Dataset和Iterator。

    Dataset(数据集) API 在 TensorFlow 1.4版本中已经从tf.contrib.data迁移到了tf.data之中,增加了对于Python的生成器的支持,官方强烈建议使用Dataset API 为 TensorFlow模型创建输入管道,原因如下:


    Dataset

    Dataset表示一个元素的集合,可以看作函数式编程中的 lazy list, 元素是tensor tuple。创建Dataset的方式可以分为两种,分别是:

    Source

    Apply transformation
    Source
    这里 source 指的是从tf.Tensor对象创建Dataset,常见的方法又如下几种:

    tf.data.Dataset.from_tensors((features, labels))
    tf.data.Dataset.from_tensor_slices((features, labels))
    tf.data.TextLineDataset(filenames)
    tf.data.TFRecordDataset(filenames)

    作用分别为:

      1.从一个tensor tuple创建一个单元素的dataset;

      2.从一个tensor tuple创建一个包含多个元素的dataset;

      3.读取一个文件名列表,将每个文件中的每一行作为一个元素,构成一个dataset;

      4.读取硬盘中的TFRecord格式文件,构造dataset。

    Apply transformation

    第二种方法就是通过转化已有的dataset来得到新的dataset,TensorFLow tf.data.Dataset支持很多中变换,在这里介绍常见的几种:

    dataset.map(lambda x: tf.decode_jpeg(x))
    dataset.repeat(NUM_EPOCHS)    
    dataset.batch(BATCH_SIZE)

    以上三种方式分别表示了:使用map对dataset中的每个元素进行处理,这里的例子是对图片数据进行解码;将dataset重复一定数目的次数用于多个epoch的训练;将原来的dataset中的元素按照某个数量叠在一起,生成mini batch。

    将以上代码组合起来,我们可以得到一个常用的代码片段:

    # 从一个文件名列表读取 TFRecord 构成 dataset
    dataset = TFRecordDataset(["file1.tfrecord", "file2.tfrecord"])
    # 处理 string,将 string 转化为 tf.Tensor 对象
    dataset = dataset.map(lambda record: tf.parse_single_example(record))
    # buffer 大小设置为 10000,打乱 dataset
    dataset = dataset.shuffle(10000)
    # dataset 将被用来训练 100 个 epoch
    dataset = dataset.repeat(100)
    # 设置 batch size 为 128
    dataset = dataset.batch(128)

    Iterator

    定义好了数据集以后可以通过Iterator接口来访问数据集中的tensor tuple,iterator保持了数据在数据集中的位置,提供了访问数据集中数据的方法。

    可以通过调用 dataset 的 make iterator 方法来构建 iterator。

    替换了place_holder,直接在原来开始的x,y处使用.get_next(),然后在sess.run时加个while true,在try里面放sess.run,exception 放OutofRangeError:

    X, y = dataset.get_next()
    
    while True:
        try:
            sess.run(accuracy)
        except tf.errors.OutOfRangeError:
            break

    API 支持以下四种 iterator,复杂程度递增:

    • one-shot
    • initializable
    • reinitializable
    • feedable

    one-shot

    one-shot iterator 谁最简单的一种 iterator,仅支持对整个数据集访问一遍,不需要显式的初始化。one-shot iterator 不支参数化。以下代码使用tf.data.Dataset.range生成数据集,作用与 python 中的 range 类似。

    dataset = tf.data.Dataset.range(100)
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    
    for i in range(100):
      value = sess.run(next_element)
      assert i == value

    initializable

    Initializable iterator 要求在使用之前显式的通过调用iterator.initializer操作初始化,这使得在定义数据集时可以结合tf.placeholder传入参数,如:

    max_value = tf.placeholder(tf.int64, shape=[])
    dataset = tf.data.Dataset.range(max_value)
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()
    
    # Initialize an iterator over a dataset with 10 elements.
    sess.run(iterator.initializer, feed_dict={max_value: 10})
    for i in range(10):
      value = sess.run(next_element)
      assert i == value
    
    # Initialize the same iterator over a dataset with 100 elements.
    sess.run(iterator.initializer, feed_dict={max_value: 100})
    for i in range(100):
      value = sess.run(next_element)
      assert i == value

    reinitializable

    reinitializable iterator 可以被不同的 dataset 对象初始化,比如对于训练集进行了shuffle的操作,对于验证集则没有处理,通常这种情况会使用两个具有相同结构的dataset对象,如:

    # Define training and validation datasets with the same structure.
    training_dataset = tf.data.Dataset.range(100).map(
        lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
    validation_dataset = tf.data.Dataset.range(50)
    
    # A reinitializable iterator is defined by its structure. We could use the
    # `output_types` and `output_shapes` properties of either `training_dataset`
    # or `validation_dataset` here, because they are compatible.
    iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
                                       training_dataset.output_shapes)
    next_element = iterator.get_next()
    
    training_init_op = iterator.make_initializer(training_dataset)
    validation_init_op = iterator.make_initializer(validation_dataset) # 如果后面初始化的是这个,那么就将循环这个数据集
    
    # Run 20 epochs in which the training dataset is traversed, followed by the
    # validation dataset.
    for _ in range(20):
      # Initialize an iterator over the training dataset.
      sess.run(training_init_op)
      for _ in range(100):
        sess.run(next_element)
    
      # Initialize an iterator over the validation dataset.
      sess.run(validation_init_op) # 替换init_op,相当于替换数据集
      for _ in range(50):
        sess.run(next_element)

    feedable

    feedable iterator 可以通过和tf.placeholder结合在一起,同通过feed_dict机制来选择在每次调用tf.Session.run的时候选择哪种Iterator。它提供了与 reinitilizable iterator 类似的功能,并且在切换数据集的时候不需要在开始的时候初始化iterator,还是上面的例子,通过tf.data.Iterator.from_string_handle来定义一个 feedable iterator,达到切换数据集的目的:

    # Define training and validation datasets with the same structure.
    training_dataset = tf.data.Dataset.range(100).map(
        lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
    validation_dataset = tf.data.Dataset.range(50)
    
    # A feedable iterator is defined by a handle placeholder and its structure. We
    # could use the `output_types` and `output_shapes` properties of either
    # `training_dataset` or `validation_dataset` here, because they have
    # identical structure.
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, training_dataset.output_types, training_dataset.output_shapes)
    next_element = iterator.get_next()
    
    # You can use feedable iterators with a variety of different kinds of iterator
    # (such as one-shot and initializable iterators).
    training_iterator = training_dataset.make_one_shot_iterator()
    validation_iterator = validation_dataset.make_initializable_iterator()
    
    # The `Iterator.string_handle()` method returns a tensor that can be evaluated
    # and used to feed the `handle` placeholder.
    training_handle = sess.run(training_iterator.string_handle())
    validation_handle = sess.run(validation_iterator.string_handle())
    
    # Loop forever, alternating between training and validation.
    while True:
      # Run 200 steps using the training dataset. Note that the training dataset is
      # infinite, and we resume from where we left off in the previous `while` loop
      # iteration.
      for _ in range(200):
        sess.run(next_element, feed_dict={handle: training_handle})
    
      # Run one pass over the validation dataset.
      sess.run(validation_iterator.initializer)
      for _ in range(50):
        sess.run(next_element, feed_dict={handle: validation_handle})

    使用实例: 

    def get_encodes(x):
        # x is `batch_size` of lines, each of which is a json object
        samples = [json.loads(l) for l in x]
        text = [s['fact'] for s in samples]
        # get a client from available clients
        bc_client = bc_clients.pop()
        features = bc_client.encode(text)
        # after use, put it back
        bc_clients.append(bc_client)
        labels = [0 for _ in text]
        return features, labels
    
    
    data_node = (tf.data.TextLineDataset(train_fp).batch(batch_size)
                 .map(lambda x: tf.py_func(get_encodes, [x], [tf.float32, tf.int64], name='bert_client'), num_parallel_calls=num_parallel_calls)
                 .map(lambda x, y: {'feature': x, 'label': y})
                 .make_one_shot_iterator().get_next())
  • 相关阅读:
    关于QQ秀
    c#重点知识解答(五) 选择自 masterall 的 Blog
    c#.net常用函数和方法集 选择自 fineflyak 的 Blog
    JavaScript 中 substr 和 substring的区别
    C#重点知识详解(二) 选择自 masterall 的 Blog
    c#重点知识详解(六) 选择自 masterall 的 Blog
    转:一个男孩的自白
    win2003端口映射2003的路由与远程访问,做端口映射(转)
    渗透笔记(转载)
    win下配置的Apache+PHP+MySQL绿色版本(转)
  • 原文地址:https://www.cnblogs.com/callyblog/p/10169860.html
Copyright © 2011-2022 走看看