  • A review of gradient descent optimization methods

    Suppose we want to optimize a parameterized function \(J(\theta)\), where \(\theta \in \mathbb{R}^d\); for example, \(\theta\) could be the weights of a neural net.

    More specifically, we want to \(\mbox{minimize } J(\theta; \mathcal{D})\) over a dataset \(\mathcal{D}\), where each point in \(\mathcal{D}\) is a pair \((x_i, y_i)\).

    There are several ways to apply gradient descent, which differ in how much of \(\mathcal{D}\) is used to compute each update.

    Let \(\eta\) be the learning rate.

    1. Vanilla batch update
      \(\theta \gets \theta - \eta \nabla J(\theta; \mathcal{D})\)
      Note that \(\nabla J(\theta; \mathcal{D})\) is the gradient computed over the whole dataset \(\mathcal{D}\).
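        for epoch in range(n_epochs):
            # compute_gradient is a placeholder for whatever evaluates nabla J(theta; D)
            gradient = compute_gradient(J, theta, D)  # gradient over the whole dataset
            theta = theta - eta * gradient
            eta = eta * 0.95  # decay the learning rate once per epoch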
    

    When \(\mathcal{D}\) is large, this approach becomes impractical: every single update requires a full pass over the data, which may not even fit in memory.
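As a minimal runnable sketch of the batch update, assume (purely for illustration) a one-parameter model \(\hat{y} = \theta x\) with mean squared error \(J(\theta; \mathcal{D}) = \frac{1}{|\mathcal{D}|}\sum_i (\theta x_i - y_i)^2\). The functions `batch_gradient` and `batch_gradient_descent` are hypothetical names, with `batch_gradient` standing in for `compute_gradient`; the noise-free synthetic data come from \(y = 3x\), so the minimizer is known to be \(\theta = 3\).

```python
# Batch gradient descent for an illustrative one-parameter model y ≈ theta * x.
# The model, the loss, the data, and all names here are assumptions for the example.


def batch_gradient(theta, data):
    """Gradient of J(theta; D) = mean((theta * x - y)^2), taken over the WHOLE dataset."""
    n = len(data)
    return sum(2.0 * (theta * x - y) * x for x, y in data) / n


def batch_gradient_descent(data, theta=0.0, eta=0.1, n_epochs=100, decay=0.95):
    for _ in range(n_epochs):
        gradient = batch_gradient(theta, data)  # one full pass over D per update
        theta = theta - eta * gradient
        eta = eta * decay  # geometric learning-rate decay, mirroring the pseudocode
    return theta


# Synthetic, noise-free data generated by y = 3x, so the minimizer is theta = 3.
D = [(x, 3.0 * x) for x in (0.0, 0.5, 1.0, 1.5, 2.0)]
print(batch_gradient_descent(D))  # prints a value within 0.01 of 3.0
```

Each iteration touches every example in `D`, which is exactly the per-update cost that becomes prohibitive for large datasets.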

    2. Stochastic Gradient Descent
      Stochastic gradient descent (SGD), on the other hand, updates the parameters example by example:
      \(\theta \gets \theta - \eta \nabla J(\theta; x_i, y_i)\), where \((x_i, y_i) \in \mathcal{D}\).
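        import random

        for epoch in range(n_epochs):
            random.shuffle(D)  # visit the examples in a new random order each epoch (D: list of pairs)
            for x_i, y_i in D:
                # gradient of the loss on the single example (x_i, y_i)
                gradient = compute_gradient(J, theta, x_i, y_i)
                theta = theta - eta * gradient
            eta = eta * 0.95  # decay the learning rate once per epoch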
    
    3. Mini-batch Stochastic Gradient Descent
      Updating \(\theta\) example by example can make the updates high-variance. The alternative is to update \(\theta\) on mini-batches \(M\), where \(|M| \ll |\mathcal{D}|\).
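        import random

        for epoch in range(n_epochs):
            random.shuffle(D)  # reshuffle so the mini-batches differ between epochs
            for start in range(0, len(D), batch_size):  # batch_size plays the role of |M|
                M = D[start:start + batch_size]
                gradient = compute_gradient(J, theta, M)  # gradient on the mini-batch only
                theta = theta - eta * gradient
            eta = eta * 0.95  # decay the learning rate once per epoch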
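Stochastic and mini-batch gradient descent differ only in how many examples feed each update, so a single sketch can cover both. It assumes a toy problem chosen for illustration: model \(\hat{y} = \theta x\), squared-error loss, and synthetic noise-free data from \(y = 3x\). The names `minibatch_gradient`, `minibatch_sgd`, `batch_size`, and `seed` are hypothetical. Setting `batch_size=1` gives plain SGD, and `batch_size=len(D)` recovers the batch update.

```python
import random

# Mini-batch SGD for an illustrative model y ≈ theta * x with squared error.
# batch_size = 1 is plain SGD; batch_size = len(data) is batch gradient descent.


def minibatch_gradient(theta, batch):
    """Gradient of the mean squared error computed on one mini-batch M only."""
    return sum(2.0 * (theta * x - y) * x for x, y in batch) / len(batch)


def minibatch_sgd(data, batch_size, theta=0.0, eta=0.1, n_epochs=100, decay=0.95, seed=0):
    rng = random.Random(seed)  # seeded so runs are reproducible
    data = list(data)  # copy, so shuffling does not reorder the caller's list
    for _ in range(n_epochs):
        rng.shuffle(data)  # new example order, hence new mini-batches, every epoch
        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]  # the last batch may be smaller
            theta = theta - eta * minibatch_gradient(theta, batch)
        eta = eta * decay  # decay once per epoch, as in the pseudocode
    return theta


D = [(x, 3.0 * x) for x in (0.0, 0.5, 1.0, 1.5, 2.0)]
for size in (1, 2, len(D)):
    print(size, minibatch_sgd(D, batch_size=size))
```

Smaller batches perform more updates per epoch, so on this toy problem they get closer to \(\theta = 3\) in the same number of epochs; with noisy data, each small-batch step would also be noisier.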
    

    Question: why does decaying the learning rate lead to convergence?
    Why are \(\sum_{i=1}^{\infty} \eta_i = \infty\) and \(\sum_{i=1}^{\infty} \eta_i^2 < \infty\) the conditions for convergence, and what assumptions on \(J(\theta)\) do they rely on?
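These two requirements are commonly known as the Robbins-Monro conditions. A quick numerical check (a sketch, not a proof) compares the geometric decay used in the pseudocode above, \(\eta_i = \eta_0 \cdot 0.95^{\,i-1}\), with the harmonic schedule \(\eta_i = \eta_0 / i\); \(\eta_0 = 0.1\) is an arbitrary choice, and `partial_sums`, `geometric`, and `harmonic` are hypothetical helper names. The geometric series converges to \(\eta_0 / (1 - 0.95) = 20\,\eta_0\), so that schedule does not satisfy \(\sum_i \eta_i = \infty\); the harmonic schedule satisfies both conditions, because \(\sum_i 1/i\) diverges while \(\sum_i 1/i^2 = \pi^2/6\).

```python
import math

# Partial sums of two learning-rate schedules, to check the conditions
#   sum(eta_i) = infinity   and   sum(eta_i^2) < infinity.
# eta0 = 0.1 is an illustrative choice.

eta0 = 0.1


def geometric(i):
    """Schedule used in the pseudocode: multiply by 0.95 after every epoch."""
    return eta0 * 0.95 ** (i - 1)


def harmonic(i):
    """Classic 1/i schedule."""
    return eta0 / i


def partial_sums(schedule, n_terms):
    """Return (sum of eta_i, sum of eta_i^2) for i = 1..n_terms."""
    s1 = s2 = 0.0
    for i in range(1, n_terms + 1):
        eta = schedule(i)
        s1 += eta
        s2 += eta * eta
    return s1, s2


for n in (10, 1_000, 100_000):
    g1, g2 = partial_sums(geometric, n)
    h1, h2 = partial_sums(harmonic, n)
    print(f"n={n:>7}  geometric: {g1:.4f} {g2:.5f}   harmonic: {h1:.4f} {h2:.5f}")

# Closed-form limits for comparison.
print("geometric sum limit:", eta0 / (1 - 0.95))  # finite, so sum(eta_i) = inf fails
print("harmonic squares limit:", eta0 ** 2 * math.pi ** 2 / 6)  # finite, as required
```

The geometric column levels off near 2.0, while the harmonic sum keeps growing (roughly like \(\eta_0 \ln n\)); both squared sums stay bounded.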

  • Original article: https://www.cnblogs.com/gaoqichao/p/9153675.html