Gradient Descent

    From: https://ml-cheatsheet.readthedocs.io/en/latest/gradient_descent.html#step-by-step

    Gradient Descent

    Gradient descent is an optimization algorithm used to minimize some function by iteratively moving in the direction of steepest descent as defined by the negative of the gradient. In machine learning, we use gradient descent to update the parameters of our model. Parameters refer to coefficients in Linear Regression and weights in neural networks.
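    To make the update rule concrete, here is a minimal sketch that minimizes a one-variable function by repeatedly stepping against its derivative. The helper `gradient_descent`, the example cost f(x) = (x - 3)^2 and the parameter values are illustrative assumptions, not part of the original article.

```python
from typing import Callable


def gradient_descent(grad: Callable[[float], float], x0: float,
                     learning_rate: float, n_iters: int) -> float:
    """Minimize a 1-D function given its derivative `grad` (hypothetical helper)."""
    x = x0
    for _ in range(n_iters):
        # Move against the derivative: the negative gradient is the downhill direction.
        x -= learning_rate * grad(x)
    return x


# Example cost f(x) = (x - 3)**2 has derivative f'(x) = 2 * (x - 3) and its minimum at x = 3.
x_min = gradient_descent(lambda x: 2 * (x - 3), x0=0.0, learning_rate=0.1, n_iters=100)
print(round(x_min, 4))  # 3.0
```

    Each step shrinks the distance to the minimum by a factor of (1 - 2 * 0.1) = 0.8 here, so 100 steps leave only a negligible error.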

    Introduction

    Consider the 3-dimensional graph below in the context of a cost function. Our goal is to move from the mountain in the top right corner (high cost) to the dark blue sea in the bottom left (low cost). The arrows represent the direction of steepest descent (negative gradient) from any given point: the direction that decreases the cost function as quickly as possible.

    [Figure: 3-D cost surface; arrows show the direction of steepest descent from each point]

    Starting at the top of the mountain, we take our first step downhill in the direction specified by the negative gradient. Next we recalculate the negative gradient (passing in the coordinates of our new point) and take another step in the direction it specifies. We continue this process iteratively until we reach the bottom of the graph, or a point where we can no longer move downhill: a local minimum.
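    The iterative loop described above (recompute the gradient at each new point, step downhill, stop once the slope is nearly flat) can be sketched for a two-parameter surface. The bowl-shaped cost f(x, y) = x^2 + 2y^2, the tolerance `tol`, and the helper names are assumptions for illustration; a real cost surface may have several local minima, and this loop simply stops at whichever one it reaches.

```python
import math


def grad_f(x: float, y: float) -> tuple[float, float]:
    """Gradient of the illustrative cost f(x, y) = x**2 + 2*y**2."""
    return 2 * x, 4 * y


def descend(x: float, y: float, learning_rate: float = 0.1,
            tol: float = 1e-8, max_iters: int = 10_000) -> tuple[float, float, int]:
    """Step downhill until the gradient norm drops below `tol` (or max_iters is hit)."""
    for step in range(max_iters):
        gx, gy = grad_f(x, y)          # recompute the gradient at the current point
        if math.hypot(gx, gy) < tol:   # (almost) flat: we cannot move further downhill
            return x, y, step
        x -= learning_rate * gx        # step in the direction of the negative gradient
        y -= learning_rate * gy
    return x, y, max_iters


x, y, steps = descend(4.0, -3.0)
print(f"stopped near ({x:.6f}, {y:.6f}) after {steps} steps")
```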

    [Figure: gradient descent demystified]

    Learning rate

    The size of these steps is determined by the learning rate. With a high learning rate we cover more ground with each step, but we risk overshooting the lowest point, since the slope of the hill is constantly changing. With a very low learning rate we can move confidently in the direction of the negative gradient, because we recalculate it so frequently. A low learning rate is more precise, but each gradient calculation takes time, so it will take us a very long time to reach the bottom.
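    The trade-off can be seen on the simple cost f(x) = x^2, whose gradient is 2x. Each step multiplies x by (1 - 2 * learning_rate), so a small rate shrinks x slowly, a moderate rate converges quickly, and a rate above 1.0 makes |1 - 2 * learning_rate| > 1, so every step overshoots the minimum by more than the last. The function `run` and the specific rates below are illustrative choices.

```python
def run(learning_rate: float, x0: float = 1.0, n_iters: int = 50) -> float:
    """Run gradient descent on f(x) = x**2 and return the final x."""
    x = x0
    for _ in range(n_iters):
        x -= learning_rate * 2 * x  # gradient of x**2 is 2x
    return x


for lr in (0.01, 0.1, 1.1):
    print(f"learning_rate={lr}: final x = {run(lr):.3e}")
# 0.01 creeps toward 0, 0.1 gets very close, 1.1 oscillates with growing magnitude.
```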

    Cost function

    A loss function tells us “how good” our model is at making predictions for a given set of parameters. The cost function has its own curve and its own gradients. The slope of this curve tells us how to update our parameters to make the model more accurate.

    Step-by-step

    Now let’s run gradient descent using our new cost function. There are two parameters in our cost function we can control: m (weight) and b (bias). Since we need to consider the impact each one has on the final prediction, we need to use partial derivatives. We calculate the partial derivatives of the cost function with respect to each parameter and store the results in a gradient.

    Math

    Given the cost function:

    $$f(m,b) = \frac{1}{N} \sum_{i=1}^{N} \left(y_i - (m x_i + b)\right)^2$$
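    The cost function above is the mean squared error of the line y = m x + b over the N data points. A direct translation into Python might look like this; the name `cost_function` and the sample data are illustrative.

```python
def cost_function(m: float, b: float, X: list[float], Y: list[float]) -> float:
    """Mean squared error of the predictions m*x + b against the targets Y."""
    N = len(X)
    total_error = 0.0
    for x, y in zip(X, Y):
        total_error += (y - (m * x + b)) ** 2
    return total_error / N


X = [1.0, 2.0, 3.0]
Y = [2.0, 4.0, 6.0]
print(cost_function(2.0, 0.0, X, Y))  # 0.0 (perfect fit)
print(cost_function(0.0, 0.0, X, Y))  # a poor fit gives a large cost
```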

    The gradient can be calculated as:

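    $$f'(m,b) = \begin{bmatrix} \dfrac{\partial f}{\partial m} \\[4pt] \dfrac{\partial f}{\partial b} \end{bmatrix} = \begin{bmatrix} \dfrac{1}{N} \sum_{i=1}^{N} -2 x_i \left(y_i - (m x_i + b)\right) \\[4pt] \dfrac{1}{N} \sum_{i=1}^{N} -2 \left(y_i - (m x_i + b)\right) \end{bmatrix}$$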
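    One way to gain confidence in the gradient formula is to compare it with a numerical approximation. The sketch below implements the two partial derivatives as written above and checks them against central finite differences of the cost; `mse`, `mse_gradient`, `numerical_gradient`, the step size `h`, and the sample data are assumptions for illustration.

```python
def mse(m: float, b: float, X: list[float], Y: list[float]) -> float:
    """Cost f(m, b): mean squared error of m*x + b."""
    return sum((y - (m * x + b)) ** 2 for x, y in zip(X, Y)) / len(X)


def mse_gradient(m: float, b: float, X: list[float],
                 Y: list[float]) -> tuple[float, float]:
    """Analytic gradient (df/dm, df/db) from the formula above."""
    N = len(X)
    dm = sum(-2 * x * (y - (m * x + b)) for x, y in zip(X, Y)) / N
    db = sum(-2 * (y - (m * x + b)) for x, y in zip(X, Y)) / N
    return dm, db


def numerical_gradient(m: float, b: float, X: list[float], Y: list[float],
                       h: float = 1e-6) -> tuple[float, float]:
    """Central-difference approximation of the same partial derivatives."""
    dm = (mse(m + h, b, X, Y) - mse(m - h, b, X, Y)) / (2 * h)
    db = (mse(m, b + h, X, Y) - mse(m, b - h, X, Y)) / (2 * h)
    return dm, db


X = [0.5, 1.5, 2.0, 3.0]
Y = [1.0, 2.5, 3.5, 5.5]
analytic = mse_gradient(0.3, -0.2, X, Y)
numeric = numerical_gradient(0.3, -0.2, X, Y)
print(analytic, numeric)  # the two pairs should agree to roughly 6 decimal places
```

    Because this cost is quadratic in m and b, central differences are exact up to floating-point rounding, which makes the comparison a fairly sharp check.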

    To solve for the gradient, we iterate through our data points using our new m and b values and compute the partial derivatives. This new gradient tells us the slope of our cost function at our current position (current parameter values) and the direction we should move to update our parameters. The size of our update is controlled by the learning rate.
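    A small worked example may help. Assume, purely for illustration, two data points (x_1, y_1) = (1, 2) and (x_2, y_2) = (2, 4), starting values m = b = 0, and a learning rate of 0.1. With N = 2 the residuals y_i - (m x_i + b) are 2 and 4, so one update gives:

    $$\frac{\partial f}{\partial m} = \frac{1}{2}\left[-2(1)(2) - 2(2)(4)\right] = \frac{-4 - 16}{2} = -10, \qquad \frac{\partial f}{\partial b} = \frac{1}{2}\left[-2(2) - 2(4)\right] = \frac{-4 - 8}{2} = -6$$

    $$m \leftarrow 0 - 0.1 \cdot (-10) = 1.0, \qquad b \leftarrow 0 - 0.1 \cdot (-6) = 0.6$$

    Both parameters increase because both gradient components are negative, and the cost drops from (2^2 + 4^2)/2 = 10 to ((2 - 1.6)^2 + (4 - 2.6)^2)/2 = (0.16 + 1.96)/2 = 1.06.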

    Code

    def update_weights(m, b, X, Y, learning_rate):
        """Perform one gradient descent step for y = m*x + b under the MSE cost."""
        m_deriv = 0
        b_deriv = 0
        N = len(X)
        for i in range(N):
            # Calculate partial derivatives
            # -2x(y - (mx + b))
            m_deriv += -2*X[i] * (Y[i] - (m*X[i] + b))

            # -2(y - (mx + b))
            b_deriv += -2*(Y[i] - (m*X[i] + b))

        # We subtract because the gradient points in the direction of steepest ascent
        m -= (m_deriv / float(N)) * learning_rate
        b -= (b_deriv / float(N)) * learning_rate

        return m, b
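    `update_weights` performs a single step, so in practice it is called repeatedly. The sketch below is a compact rewrite of the same update (both partial derivatives are computed from the current m and b before either is changed) wrapped in a training loop; the names `step` and `train`, the synthetic data generated from y = 2x + 1, the learning rate of 0.05 and the 2000 iterations are assumptions chosen for illustration.

```python
def step(m: float, b: float, X: list[float], Y: list[float],
         learning_rate: float) -> tuple[float, float]:
    """One gradient descent update for y = m*x + b under the MSE cost."""
    N = len(X)
    residuals = [y - (m * x + b) for x, y in zip(X, Y)]
    m_deriv = sum(-2 * x * r for x, r in zip(X, residuals)) / N
    b_deriv = sum(-2 * r for r in residuals) / N
    # Subtract: the gradient points uphill, so we move the opposite way.
    return m - learning_rate * m_deriv, b - learning_rate * b_deriv


def train(X: list[float], Y: list[float], learning_rate: float,
          iters: int) -> tuple[float, float]:
    """Start from m = b = 0 and apply `iters` gradient descent updates."""
    m, b = 0.0, 0.0
    for i in range(iters):
        m, b = step(m, b, X, Y, learning_rate)
        if i % 500 == 0:
            cost = sum((y - (m * x + b)) ** 2 for x, y in zip(X, Y)) / len(X)
            print(f"iter {i:4d}: m={m:.4f}, b={b:.4f}, cost={cost:.6f}")
    return m, b


X = [0.0, 1.0, 2.0, 3.0, 4.0]
Y = [1.0, 3.0, 5.0, 7.0, 9.0]  # generated from y = 2x + 1
m, b = train(X, Y, learning_rate=0.05, iters=2000)
print(f"m={m:.4f}, b={b:.4f}")  # should approach m=2, b=1
```

    If the learning rate is raised too far (for this data, somewhere above roughly 0.15), the updates overshoot and the parameters diverge instead of settling.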
    


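    Source: http://www.cnblogs.com/lightsong/ The copyright of this article is shared by the author and cnblogs. Reposting is welcome, but unless the author agrees otherwise, this notice must be retained and a link to the original article must be given in a prominent place on the page.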
    Original article: https://www.cnblogs.com/lightsong/p/14643140.html