zoukankan      html  css  js  c++  java
  • [TensorFlow 2] [Keras] fit()、fit_generator() 和 train_on_batch() 分析与应用

    前言

    是的,除了水报错文,我也来写点其他的。本文主要介绍Keras中以下三个函数的用法:

    fit()
    fit_generator()
    train_on_batch()
    当然,与上述三个函数相似的evaluate、predict、test_on_batch、predict_on_batch、evaluate_generator和predict_generator等就不详细说了,举一反三嘛。

    环境

    本文的代码是在以下环境下进行测试的:

    Windows 10
    Python 3.6
    TensorFlow 2.0 Alpha
    异同

    大家用Keras也就图个简单快捷,但是在享受简单快捷的时候,也常常需要些定制化需求,除了model.fit(),有时候model.fit_generator()和model.train_on_batch()也很重要。那么,这三个函数有什么异同呢?Adrian Rosebrock [1] 有如下总结:

    当你使用.fit()函数时,意味着如下两个假设:

    训练数据可以 完整地 放入到内存(RAM)里
    数据已经不需要再进行任何处理了
    这两个原因解释的非常好,之前我运行程序的时候,由于数据集太大(实际中的数据集显然不会都像 TensorFlow 官方教程里经常使用的 MNIST 数据集那样小),一次性加载训练数据到fit()函数里根本行不通:

    history = model.fit(train_data, train_label) // Bomb!!!
    1
    于是我想,能不能先加载一个batch训练,然后再加载一个batch,如此往复。于是我就注意到了fit_generator()函数。什么时候该使用fit_generator函数呢?Adrian Rosebrock 的总结道:

    内存不足以一次性加载整个训练数据的时候
    需要一些数据预处理(例如旋转和平移图片、增加噪音、扩大数据集等操作)
    在生成batch的时候需要更多的处理
    对于我自己来说,除了数据集太大的缘故之外,我需要在生成batch的时候,对输入数据进行padding,所以fit_generator()就派上了用场。下面介绍如何使用这三种函数。

    fit()函数

    fit()函数其实没什么好说的,大家在看TensorFlow教程的时候已经见识过了。此外插一句话,tf.data.Dataset对不规则的序列数据真是不友好。

    import tensorflow as tf

    model = tf.keras.models.Sequential([
    ... // 你的模型
    ])

    model.fit(train_x, // 训练输入
    train_y, // 训练标签
    epochs=5 // 训练5轮
    )
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    fit_generator()函数

    fit_generator()函数就比较重要了,也是本文讨论的重点。fit_generator()与fit()的主要区别就在一个generator上。之前,我们把整个训练数据都输入到fit()里,我们也不需要考虑batch的细节;现在,我们使用一个generator,每次生成一个batch送给fit_generator()训练。

    def generator(x, y, b_size):
    ... // 处理函数

    model.fit_generator(generator(train_x, train_y, batch_size),
    step_per_epochs=np.ceil(len(train_x)/batch_size),
    epochs=5
    )
    1
    2
    3
    4
    5
    6
    7
    从上述代码中,我们发现有两处不同:

    一个我们自定义的generator()函数,作为fit_generator()函数的第一个参数;
    fit_generator()函数的step_per_epochs参数
    自定义的generator()函数

    该函数即是我们数据的生成器,在训练的时候,fit_generator()函数会不断地执行generator()函数,获取一个个的batch。

    def generator(x, y, b_size):
    """Generates batch and batch and batch then feed into models.
    Args:
    x: input data;
    y: input labels;
    b_size: batch_size.
    Yield:
    (batch_x, batch_label): batched x and y.
    """
    while 1: // 死循环
    idx = ...
    batch_x = ...
    batch_y = ...
    ... // 任何你想要对这个`batch`中的数据执行的操作
    yield (batch_x, batch_y)
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    需要注意的是,不要使用return或者exit。

    step_per_epochs参数

    由于generator()函数的循环没有终止条件,fit_generator也不知道一个epoch什么时候结束,所以我们需要手动指定step_per_epochs参数,一般的数值即为len(y)//batch_size。如果数据集大小不能整除batch_size,而且你打算使用最后一个batch的数据(该batch比batch_size要小),此时使用np.ceil(len(y)/batch_size)。

    keras.utils.Sequence类(2019年6月10日更新)

    除了写generator()函数,我们还可以利用keras.utils.Sequence类来生成batch。先扔代码:

    class Generator(keras.utils.Sequence):
    def __init__(self, x, y, b_size):
    self.x, self.y = x, y
    self.batch_size = b_size

    def __len__(self):
    return math.ceil(len(self.y)/self.batch_size

    def __getitem__(self, idx):
    b_x = self.x[idx*self.batch_size:(idx+1)*self.batch_size]
    b_y = self.y[idx*self.batch_size:(idx+1)*self.batch_size]
    ... // 对`batch`的其余操作
    return np.array(b_x), np.array(b_y)

    def on_epoch_end(self):
    """执行完一个`epoch`之后,还可以做一些其他的事情!"""
    ...
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    我们首先定义__init__函数,读取训练集数据,然后定义__len__函数,返回一个epoch中需要执行的step数(此时在fit_generator()函数中就不需要指定steps_per_epoch参数了),最后定义__getitem__函数,返回一个batch的数据。代码如下:

    train_generator = Generator(train_x, train_y, batch_size)
    val_generator = Generator(val_x, val_y, batch_size)

    model.fit_generator(generator=train_generator,
    epochs=3197747,
    validation_data=val_generator
    )
    1
    2
    3
    4
    5
    6
    7
    根据官方 [2] 的说法,使用Sequence类可以保证在多进程的情况下,每个epoch中的样本只会被训练一次。总之,使用keras.utils.Sequence也是很方便的啦!

    train_on_batch()函数

    train_on_batch()函数接受一个batch的输入和标签,然后开始反向传播,更新参数等。大部分情况下你都不需要用到train_on_batch()函数,除非你有着充足的理由去定制化你的模型的训练流程。

    结语

    本文到此结束啦!也不知道讲清楚没有,如果有疑问或者有错误,还请读者不吝赐教啦!

    Reference

    A. Rosebrock. (December 24, 2018). How to use Keras fit and fit_generator (a hands-on tutorial). Retrieved from https://www.pyimagesearch.com/2018/12/24/how-to-use-keras-fit-and-fit_generator-a-hands-on-tutorial/
    tf.keras.utils.Sequence. (July 10, 2019). Retrieved from https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/utils/Sequence
    ---------------------

  • 相关阅读:
    Hadoop 学习笔记 (十) hadoop2.2.0 生产环境部署 HDFS HA Federation 含Yarn部署
    hadoop 2.x 安装包目录结构分析
    词聚类
    Hadoop 学习笔记 (十一) MapReduce 求平均成绩
    Hadoop 学习笔记 (十) MapReduce实现排序 全局变量
    Hadoop 学习笔记 (九) hadoop2.2.0 生产环境部署 HDFS HA部署方法
    Visual Studio Code 快捷键大全(Windows)
    Eclipse安装教程 ——史上最详细安装Java &Python教程说明
    jquery操作select(取值,设置选中)
    $.ajax 中的contentType
  • 原文地址:https://www.cnblogs.com/ly570/p/11198596.html
Copyright © 2011-2022 走看看