zoukankan      html  css  js  c++  java
  • BGD SGD MBGD 与梯度下降算法 的理解

    梯度下降算法,用于求函数 极小值,以及极小值点对应的 变量值。可以是线性,非线性等。

    def fun(x,y):

        return 3*x*x+2*y*y


    def step(x,y):

    # 定义学习率a

        a=0.02

        x=x-a*(6*x)

        y=y-a*(4*y)

        return x,y


    不断调用step就可以求出函数的极小值以及极小值x,y的值。


    BGD SGD MBGD 梯度下降算法针对线性回归,其损失函数极小值就是最小值。

    J(θ1,θ2,θ3) 是包含θ1,θ2,θ3,以及x,y(数据点)组成的函数。已知数据点xy,求合适的θ值,使函数J最小。

    用θ的梯度去更新θ的值。

    BGD就是传统的梯度下降,损失函数J就是使用了全部的xy数据(函数J里的Σ包含所有的xy数据),SGD去掉了Σ(单个数据),MBGD的Σ比BGD的Σ少(使用部分xy数据)。



    BGD SGD MBGD 梯度下降算法 线性回归 python 实现


    import matplotlib.pyplot as plt
    import numpy as np
    
    x=np.random.random(100)
    y=5+3*x+np.random.normal(0,0.2,100)
    
    
    def fun(theta1,theta2,x):
        return theta1+theta2*x
    
    
    # BGD
    
    theta1=0
    theta2=0
    a=0.2
    
    #画图用
    BGD_path1=[]
    BGD_path2=[]
    
    for i in range(1000):
        temp1=theta1-a*(sum((fun(theta1,theta2,x)-y))/100)
        temp2=theta2-a*(sum(x*(fun(theta1,theta2,x)-y))/100)
        theta1=temp1
        theta2=temp2
        print(theta1,theta2)
        BGD_path1.append(theta1)
        BGD_path2.append(theta2)
    
    # plt.scatter(x,y)
    # plt.plot(x,fun(theta1,theta2,x))
    # plt.show()
    
    # SGD
    
    theta1=0
    theta2=0
    a=0.2
    
    #画图用
    SGD_path1=[]
    SGD_path2=[]
    
    for i in range(1):
        for j in range(100):
            print(j)
            temp1=theta1-a*(fun(theta1,theta2,x[j])-y[j])
            temp2=theta2-a*(x[j]*(fun(theta1,theta2,x[j])-y[j]))
            theta1=temp1
            theta2=temp2
            print(theta1,theta2)
            SGD_path1.append(theta1)
            SGD_path2.append(theta2)
    
    # plt.scatter(x,y)
    # plt.plot(x,fun(theta1,theta2,x))
    # plt.show()
    
    # MBGD 
    # mini-batch = 10 
    
    theta1=0
    theta2=0
    a=0.2
    
    #画图用
    MBGD_path1=[]
    MBGD_path2=[]
    
    for i in range(100):
        for j in range(10):
            temp1=theta1-a*(sum((fun(theta1,theta2,x[j*10:j*10+10])-y[j*10:j*10+10]))/100)
            temp2=theta2-a*(sum(x[j*10:j*10+10]*(fun(theta1,theta2,x[j*10:j*10+10])-y[j*10:j*10+10]))/100)
            theta1=temp1
            theta2=temp2
            print(theta1,theta2)
            MBGD_path1.append(theta1)
            MBGD_path2.append(theta2)
    
    # plt.scatter(x,y)
    # plt.plot(x,fun(theta1,theta2,x))
    # plt.show()
    
    from mpl_toolkits.mplot3d import Axes3D
    fig = plt.figure()
    ax = Axes3D(fig)
    ax.plot(BGD_path1,BGD_path2,1,label='BGD')
    ax.plot(SGD_path1,SGD_path2,2,label='SGD')
    ax.plot(MBGD_path1,MBGD_path2,3,label='MBGD')
    plt.legend()
    plt.show()
  • 相关阅读:
    【转】算命称骨
    为WinForms程序添加Form级快捷键的最简单方式
    Flex4中Datagrid垂直滚动使combobox&dropdownlist数据消失(已解决)
    Flex & form小技巧
    [转]Android开发之旅:环境搭建及HelloWorld
    [SQL2005]安装Ms SQL Server 2005 开发版时出现性能计数器要求安装错误的解决办法
    [转]Flex 中的皮肤
    天哪,Flex有没有阻塞方式啊?
    [资料库]Flash特效站点(带源码)
    开始学习JAVA
  • 原文地址:https://www.cnblogs.com/XUEYEYU/p/12905021.html
Copyright © 2011-2022 走看看