zoukankan      html  css  js  c++  java
  • [TensorFlow 2] [Keras] fit()、fit_generator() 和 train_on_batch() 分析与应用

    前言

    是的,除了水报错文,我也来写点其他的。本文主要介绍Keras中以下三个函数的用法:

    fit()
    fit_generator()
    train_on_batch()
    当然,与上述三个函数相似的evaluate、predict、test_on_batch、predict_on_batch、evaluate_generator和predict_generator等就不详细说了,举一反三嘛。

    环境

    本文的代码是在以下环境下进行测试的:

    Windows 10
    Python 3.6
    TensorFlow 2.0 Alpha
    异同

    大家用Keras也就图个简单快捷,但是在享受简单快捷的时候,也常常需要些定制化需求,除了model.fit(),有时候model.fit_generator()和model.train_on_batch()也很重要。那么,这三个函数有什么异同呢?Adrian Rosebrock [1] 有如下总结:

    当你使用.fit()函数时,意味着如下两个假设:

    训练数据可以 完整地 放入到内存(RAM)里
    数据已经不需要再进行任何处理了
    这两个原因解释的非常好,之前我运行程序的时候,由于数据集太大(实际中的数据集显然不会都像 TensorFlow 官方教程里经常使用的 MNIST 数据集那样小),一次性加载训练数据到fit()函数里根本行不通:

    history = model.fit(train_data, train_label) // Bomb!!!
    1
    于是我想,能不能先加载一个batch训练,然后再加载一个batch,如此往复。于是我就注意到了fit_generator()函数。什么时候该使用fit_generator函数呢?Adrian Rosebrock 的总结道:

    内存不足以一次性加载整个训练数据的时候
    需要一些数据预处理(例如旋转和平移图片、增加噪音、扩大数据集等操作)
    在生成batch的时候需要更多的处理
    对于我自己来说,除了数据集太大的缘故之外,我需要在生成batch的时候,对输入数据进行padding,所以fit_generator()就派上了用场。下面介绍如何使用这三种函数。

    fit()函数

    fit()函数其实没什么好说的,大家在看TensorFlow教程的时候已经见识过了。此外插一句话,tf.data.Dataset对不规则的序列数据真是不友好。

    import tensorflow as tf

    model = tf.keras.models.Sequential([
    ... // 你的模型
    ])

    model.fit(train_x, // 训练输入
    train_y, // 训练标签
    epochs=5 // 训练5轮
    )
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    fit_generator()函数

    fit_generator()函数就比较重要了,也是本文讨论的重点。fit_generator()与fit()的主要区别就在一个generator上。之前,我们把整个训练数据都输入到fit()里,我们也不需要考虑batch的细节;现在,我们使用一个generator,每次生成一个batch送给fit_generator()训练。

    def generator(x, y, b_size):
    ... // 处理函数

    model.fit_generator(generator(train_x, train_y, batch_size),
    step_per_epochs=np.ceil(len(train_x)/batch_size),
    epochs=5
    )
    1
    2
    3
    4
    5
    6
    7
    从上述代码中,我们发现有两处不同:

    一个我们自定义的generator()函数,作为fit_generator()函数的第一个参数;
    fit_generator()函数的step_per_epochs参数
    自定义的generator()函数

    该函数即是我们数据的生成器,在训练的时候,fit_generator()函数会不断地执行generator()函数,获取一个个的batch。

    def generator(x, y, b_size):
    """Generates batch and batch and batch then feed into models.
    Args:
    x: input data;
    y: input labels;
    b_size: batch_size.
    Yield:
    (batch_x, batch_label): batched x and y.
    """
    while 1: // 死循环
    idx = ...
    batch_x = ...
    batch_y = ...
    ... // 任何你想要对这个`batch`中的数据执行的操作
    yield (batch_x, batch_y)
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    需要注意的是,不要使用return或者exit。

    step_per_epochs参数

    由于generator()函数的循环没有终止条件,fit_generator也不知道一个epoch什么时候结束,所以我们需要手动指定step_per_epochs参数,一般的数值即为len(y)//batch_size。如果数据集大小不能整除batch_size,而且你打算使用最后一个batch的数据(该batch比batch_size要小),此时使用np.ceil(len(y)/batch_size)。

    keras.utils.Sequence类(2019年6月10日更新)

    除了写generator()函数,我们还可以利用keras.utils.Sequence类来生成batch。先扔代码:

    class Generator(keras.utils.Sequence):
    def __init__(self, x, y, b_size):
    self.x, self.y = x, y
    self.batch_size = b_size

    def __len__(self):
    return math.ceil(len(self.y)/self.batch_size

    def __getitem__(self, idx):
    b_x = self.x[idx*self.batch_size:(idx+1)*self.batch_size]
    b_y = self.y[idx*self.batch_size:(idx+1)*self.batch_size]
    ... // 对`batch`的其余操作
    return np.array(b_x), np.array(b_y)

    def on_epoch_end(self):
    """执行完一个`epoch`之后,还可以做一些其他的事情!"""
    ...
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    我们首先定义__init__函数,读取训练集数据,然后定义__len__函数,返回一个epoch中需要执行的step数(此时在fit_generator()函数中就不需要指定steps_per_epoch参数了),最后定义__getitem__函数,返回一个batch的数据。代码如下:

    train_generator = Generator(train_x, train_y, batch_size)
    val_generator = Generator(val_x, val_y, batch_size)

    model.fit_generator(generator=train_generator,
    epochs=3197747,
    validation_data=val_generator
    )
    1
    2
    3
    4
    5
    6
    7
    根据官方 [2] 的说法,使用Sequence类可以保证在多进程的情况下,每个epoch中的样本只会被训练一次。总之,使用keras.utils.Sequence也是很方便的啦!

    train_on_batch()函数

    train_on_batch()函数接受一个batch的输入和标签,然后开始反向传播,更新参数等。大部分情况下你都不需要用到train_on_batch()函数,除非你有着充足的理由去定制化你的模型的训练流程。

    结语

    本文到此结束啦!也不知道讲清楚没有,如果有疑问或者有错误,还请读者不吝赐教啦!

    Reference

    A. Rosebrock. (December 24, 2018). How to use Keras fit and fit_generator (a hands-on tutorial). Retrieved from https://www.pyimagesearch.com/2018/12/24/how-to-use-keras-fit-and-fit_generator-a-hands-on-tutorial/
    tf.keras.utils.Sequence. (July 10, 2019). Retrieved from https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/utils/Sequence
    ---------------------

  • 相关阅读:
    二分查找
    Linux下查找目录中所有文件中含有某个字符串,并且只打印出文件名
    编码规范
    Cookie和Session的选择,以及如何解决分布式系统下各个服务器之间Session不统一的问题
    Mac VMware Fusion 11.5 虚拟机带密钥
    快速排序
    python装饰器
    商品详情页
    hadoop跑wordcount报expected org.apache.hadoop.io.Text, received org.apache.hadoop.io.LongWritable
    CentOS7设置共享文件夹不显示问题
  • 原文地址:https://www.cnblogs.com/ly570/p/11198596.html
Copyright © 2011-2022 走看看