zoukankan      html  css  js  c++  java
  • 【DeepLearning】优化算法:SGD、GD、mini-batch GD、Moment、RMSprob、Adam

    优化算法

    1 GD/SGD/mini-batch GD

    GD:Gradient Descent,就是传统意义上的梯度下降,也叫batch GD。

    SGD:随机梯度下降。一次只随机选择一个样本进行训练和梯度更新。

    mini-batch GD:小批量梯度下降。GD训练的每次迭代一定是向着最优方向前进,但SGD和mini-batch GD不一定,可能会”震荡“。把所有样本一次放进网络,占用太多内存,甚至内存容纳不下如此大的数据量,因此可以分批次训练。可见,SGD是mini-batch GD的特例。

    2 动量梯度下降

    一个有用的方法:指数加权平均

    假如有N天的数据,(a_1,a_2,a_3,...,a_N)

    如果想拟合一条比较平滑的曲线,怎么做呢。可以使用指数加权平均。概括的说,就是平均前m天的数据作为一个数据点:

    (b_0=0)

    (b_t = eta*b_{t-1} + (1-eta)*a_t)

    比如(eta)取0.9,那么就相当于是平均了10天的数据,这会比较平滑。

    为什么叫指数加权平均?是因为,将(b_t)计算公式中的(b_{t-1})展开,依次展开下去,会发现:

    (b_t = eta*(eta*b_{t-2}+(1-eta)*a_{t-1})+(1-eta)*a_t)

    合并一下,前面的数据乘以(eta)的m次方的平均值。因此叫做指数加权平均。

    指数加权平均的偏差修正:

    很显然,上面公式汇总,假如(eta=0.9),是要平均前10天的数据,则前9天的数据,做加权平均,并没有足够的数据,这会导致前九天数据偏小。举例来说,第一天(b_1)仅仅是(a_1)的十分之一。可以做偏差修正:

    (b_0=0)

    (b_t = eta*b_{t-1} + (1-eta)*a_t)

    (b_t = frac{b_t}{1-eta ^t})

    有了指数加权平均的基本知识,就可以讲Moment Gradient Descent。

    带动量的梯度下降,他是为了加快学习速度的优化算法。假如参数W方向希望学的快一点,b方向学的慢一点。普通的梯度下降的结果可能恰恰相反,使得b方向学的比较快,从而引发震荡。所以想要让b方向学的慢一点,这时可以用动量梯度下降。算法如下:

    For each epco (t):

    ​ cal (dW,dB) for each mini-batch.

    (V_{dW}=eta*V_{d_W}+(1-eta)*dW)

    (V_{db}=eta*V_{d_b}+(1-eta)*db)

    (W = W-alpha*V_{dW})

    (b=b-alpha*V_{db})

    可见,就是将梯度做了指数加权平均。至于为什么叫动量梯度下降,可能是因为(V_{dW})的计算式中,把(dW)看作加速度,把(V_{dW})看作速度。

    3 RMSprob

    Root Mean Square prob。

    这是另一种加快学习的方法。其基本思想也是让b学的慢一点,让W学的快一点,从而更快更准的趋向于最优点。

    For each epco (t):

    ​ cal (dW,dB) for each mini-batch.

    (S_{dW}=eta*S_{dW}+(1-eta)*(dW)^2)

    (S_{db}=eta*S_{db}+(1-eta)*(db)^2)

    (W = W-alpha*frac{dW}{sqrt{S_{dW}}})

    (b = b-alpha*frac{db}{sqrt{S_{db}}})

    可见,当梯度较大,则会使得(S)较大,从而使得更新变缓慢。

    4 Adam

    Adaptive Moment Estimation

    这是比较普遍的适用于各种网络的一种方法。称作自适应的动量梯度下降。这是上面动量梯度下降和RMSprob的结合版本,效果比较好。两者做加权指数平均的时候,都做了修正。

    For each epco (t):

    ​ cal (dW,dB) for each mini-batch.

    (V_{dW}=eta_1*V_{d_W}+(1-eta_1)*dW)

    (V_{db}=eta_1*V_{d_b}+(1-eta_1)*db)

    (S_{dW}=eta_2*S_{dW}+(1-eta_2)*(dW)^2)

    (S_{db}=eta_2*S_{db}+(1-eta_2)*(db)^2)

    (V_{dW}^{correction} = frac{V_{dW}}{1-eta_1^t})

    (V_{db}^{correction} = frac{V_{db}}{1-eta1^t})

    (S_{dW}^{correction} = frac{S_{dW}}{1-eta_2^t})

    (S_{db}^{correction} = frac{S_{db}}{1-eta_2^t})

    (W = W-alpha*frac{V_{dW}^{correction}}{sqrt{S_{dW}^{correcti}}+varepsilon})

    (b = b-alpha*frac{V_{db}^{correction}}{sqrt{S_{db}^{correcti}}+varepsilon})

    可见,就是在RMSprob中,用动量梯度下降中的梯度指数加权代替梯度,同时所有指数加权项都做了偏差修正。另外,分母加了(varepsilon),这是为了防止除以很小的数造成不稳定。公式中共有4个超参数,他们的取值经验是:

    (alpha):学习率,需要调试

    (eta_1):取0.9

    (eta_2):Adam的作者建议取0.999

    (varepsilon):取(10^{-8}),并不影响学习效果。

    另外,值得注意的是,学习过程中可能有两种窘境,一是困在局部最优,另一个是遇平缓区学习的过慢。采用mini-batch下前者出现的概率很小,即使困在最优也能跳出来,前提是数据量足够。但后者比较棘手,Adam是个比较好的优化算法,一定程度上解决了这个问题。

    5 学习率衰减

    训练过程中,随着迭代次数,学习率相应减少。

    在FB的一篇论文中,使用了不同时间段(迭代次数区间段)采用不同的学习率的方法,效果比较好。

  • 相关阅读:
    C#趣味程序---车牌号推断
    使用 C# 开发智能手机软件:推箱子(十四)
    【Oracle错误集锦】:ORA-12154: TNS: 无法解析指定的连接标识符
    java中你确定用对单例了吗?
    linux tty设置详解
    tty linux 打开和设置范例
    C和C++之间库的互相调用
    Android 编译参数 LOCAL_MODULE_TAGS
    pthread_once 和 pthread_key
    Android系统root破解原理分析
  • 原文地址:https://www.cnblogs.com/duye/p/10595659.html
Copyright © 2011-2022 走看看