zoukankan      html  css  js  c++  java
  • BGD SGD MBGD 与梯度下降算法 的理解

    梯度下降算法,用于求函数 极小值,以及极小值点对应的 变量值。可以是线性,非线性等。

    def fun(x,y):

        return 3*x*x+2*y*y


    def step(x,y):

    # 定义学习率a

        a=0.02

        x=x-a*(6*x)

        y=y-a*(4*y)

        return x,y


    不断调用step就可以求出函数的极小值以及极小值x,y的值。


    BGD SGD MBGD 梯度下降算法针对线性回归,其损失函数极小值就是最小值。

    J(θ1,θ2,θ3) 是包含θ1,θ2,θ3,以及x,y(数据点)组成的函数。已知数据点xy,求合适的θ值,使函数J最小。

    用θ的梯度去更新θ的值。

    BGD就是传统的梯度下降,损失函数J就是使用了全部的xy数据(函数J里的Σ包含所有的xy数据),SGD去掉了Σ(单个数据),MBGD的Σ比BGD的Σ少(使用部分xy数据)。



    BGD SGD MBGD 梯度下降算法 线性回归 python 实现


    import matplotlib.pyplot as plt
    import numpy as np
    
    x=np.random.random(100)
    y=5+3*x+np.random.normal(0,0.2,100)
    
    
    def fun(theta1,theta2,x):
        return theta1+theta2*x
    
    
    # BGD
    
    theta1=0
    theta2=0
    a=0.2
    
    #画图用
    BGD_path1=[]
    BGD_path2=[]
    
    for i in range(1000):
        temp1=theta1-a*(sum((fun(theta1,theta2,x)-y))/100)
        temp2=theta2-a*(sum(x*(fun(theta1,theta2,x)-y))/100)
        theta1=temp1
        theta2=temp2
        print(theta1,theta2)
        BGD_path1.append(theta1)
        BGD_path2.append(theta2)
    
    # plt.scatter(x,y)
    # plt.plot(x,fun(theta1,theta2,x))
    # plt.show()
    
    # SGD
    
    theta1=0
    theta2=0
    a=0.2
    
    #画图用
    SGD_path1=[]
    SGD_path2=[]
    
    for i in range(1):
        for j in range(100):
            print(j)
            temp1=theta1-a*(fun(theta1,theta2,x[j])-y[j])
            temp2=theta2-a*(x[j]*(fun(theta1,theta2,x[j])-y[j]))
            theta1=temp1
            theta2=temp2
            print(theta1,theta2)
            SGD_path1.append(theta1)
            SGD_path2.append(theta2)
    
    # plt.scatter(x,y)
    # plt.plot(x,fun(theta1,theta2,x))
    # plt.show()
    
    # MBGD 
    # mini-batch = 10 
    
    theta1=0
    theta2=0
    a=0.2
    
    #画图用
    MBGD_path1=[]
    MBGD_path2=[]
    
    for i in range(100):
        for j in range(10):
            temp1=theta1-a*(sum((fun(theta1,theta2,x[j*10:j*10+10])-y[j*10:j*10+10]))/100)
            temp2=theta2-a*(sum(x[j*10:j*10+10]*(fun(theta1,theta2,x[j*10:j*10+10])-y[j*10:j*10+10]))/100)
            theta1=temp1
            theta2=temp2
            print(theta1,theta2)
            MBGD_path1.append(theta1)
            MBGD_path2.append(theta2)
    
    # plt.scatter(x,y)
    # plt.plot(x,fun(theta1,theta2,x))
    # plt.show()
    
    from mpl_toolkits.mplot3d import Axes3D
    fig = plt.figure()
    ax = Axes3D(fig)
    ax.plot(BGD_path1,BGD_path2,1,label='BGD')
    ax.plot(SGD_path1,SGD_path2,2,label='SGD')
    ax.plot(MBGD_path1,MBGD_path2,3,label='MBGD')
    plt.legend()
    plt.show()
  • 相关阅读:
    [六省联考2017]相逢是问候
    [CQOI2017]老C的键盘
    [CQOI2017]老C的任务
    [CQOI2017]小Q的棋盘
    <sdoi2017>树点涂色
    三分法
    最长回文子串
    hdu3261
    spoj694
    poj1743
  • 原文地址:https://www.cnblogs.com/XUEYEYU/p/12905021.html
Copyright © 2011-2022 走看看