zoukankan      html  css  js  c++  java
  • 神经网络中batch_size参数的含义及设置方法

    本文作者Key,博客园主页:https://home.cnblogs.com/u/key1994/

    本内容为个人原创作品,转载请注明出处或联系:zhengzha16@163.com

    在进行神经网络训练时,batch_size是一个必须进行设置的参数。以前在用BP神经网络进行预测时,由于模型结构很简单,所以对的batch_size值的设置没太在意。最近在做YOLO这样的深度网络,模型结构本身比较复杂,且训练样本量较大,在训练时损失函数降得较慢。看网上有些文章说可以改变batch_size的值来提升训练效果,我尝试改变bz值之后发现确实有用,但对其中的原理依然不太懂,所以特地查了一些资料来搞明白这个问题,并在这里简单进行整理。

    Batch一般被翻译为批量,设置batch_size的目的让模型在训练过程中每次选择批量的数据来进行处理。一般机器学习或者深度学习训练过程中的目标函数可以简单理解为在每个训练集样本上得到的目标函数值的求和,然后根据目标函数的值进行权重值的调整,大部分时候是根据梯度下降法来进行参数更新的。

    为什么要引入batch_size:

    如果没有引入batch_size这一参数,那么在训练过程中所有的训练数据直接输入到网络,经过计算之后得到网络输出值及目标函数值,并以此来调整网络参数使目标函数取极小值。批量梯度下降算法(Batch Gradient Descent)正是这样的原理,注意这里的batch和batch_size中的batch并无直接联系,当然此时也可以理解为batch_size的值恰好等于训练集样本数量。

    那么这样做的缺点是什么呢?

    首先,当训练集样本非常多时,直接将这些数据输入到神经网络的话会导致计算量非常大,同时对极端集的内存要求也比较高。

    其次,当所有样本同时输入到网络中时,往往很难确定一个全局最优学习率使得训练效果最佳。

    另外一种极端情况是:每次只读取一个样本作为输入,这种方法称为随机梯度下降算法(Stochastic Gradient Descent, SGD)。这种情况下,可以充分考虑每一个样本的特殊性。但是其缺点同样非常明显:

    在每个训练样本上得到的目标函数值差别可能较大,因此最后通过求和或者求平均值的方法而得到的目标函数值不足以代表每个样本。也就是说,这种方法得到的模型对样本的泛化能力差。

    为了对两种极端情况进行折衷处理,就有了mini batches这一概念。也就是说每次只输入一定数量的训练样本对模型进行训练,这个数量就是batch_size的大小。

    这样做的优点主要有:

    • 可以充分利用计算机的并行运算结构,提高数据处理速度;
    • 考虑了一定数量的样本数据,可以比较准确得代表梯度下降方向
    • 跑完一次 epoch(全数据集)所需的迭代次数减少,对于相同数据量的处理速度进一步加快。

    但是,batch_size的大小不能无限增大,如果取过大的batch_size,会导致每个epoch迭代的次数减小,要想取得更好的训练效果,需要更多的epoch,会增大总体运算量和运算时间;此外,每次处理多张图片时,虽然可以发挥计算机并行计算的优势,但是也要充分考虑计算机内存大小的限制。

    另外,在对样本数据进行批量处理时还会产生另外一个问题:当我们采用样本增强技术或者训练样本中重复样本过多时,如果按顺序对输入样本进行批量,可能导致同一批数据相关性较高,即使采取了设置了batch_size值也只能代表某一小部分的数据,因此获取batch_size的数据时需要随机抽取。在实际中,可以通过对训练集样本进行shuffle来实现这一操作。

    当然这里还有一个问题我还没搞明白,就是一般我们在选取batch_size的值时往往采取2的幂数,常见的如16,32,64,128等。取这些值是为了充分发挥计算机的数据处理能力,哪位大神能给讲一下其中的原因?或者给推荐一些相关资料来搞明白这个问题?

    参考文献:

    1.  深度学习训练过程中BatchSize的设置

    2. 神经网络中Batch Size的理解

    3. 伊恩·古德费洛等,《深度学习》

  • 相关阅读:
    Vue 2.x windows环境下安装
    VSCODE官网下载缓慢或下载失败 解决办法
    angular cli 降级
    Win10 VS2019 设置 以管理员身份运行
    XSHELL 连接 阿里云ECS实例
    Chrome浏览器跨域设置
    DBeaver 执行 mysql 多条语句报错
    DBeaver 连接MySql 8.0 报错 Public Key Retrieval is not allowed
    DBeaver 连接MySql 8.0报错 Unable to load authentication plugin 'caching_sha2_password'
    Linux系统分区
  • 原文地址:https://www.cnblogs.com/key1994/p/11898304.html
Copyright © 2011-2022 走看看