zoukankan      html  css  js  c++  java
  • Python 梯度下降法

    题目描述:
    自定义一个可微并且存在最小值的一元函数,用梯度下降法求其最小值。并绘制出学习率从0.1到0.9(步长0.1)时,达到最小值时所迭代的次数的关系曲线,根据该曲线给出简单的分析。

    代码:

    # -*- coding: utf-8 -*-
    """
    Created on Tue Jun  4 10:19:03 2019
    
    @author: Administrator
    """
    
    import numpy as np
    import matplotlib.pyplot as plt
    plot_x=np.linspace(-1,6,150)   #在-1到6之间等距的生成150个数
    plot_y=(plot_x-2.5)**2+3	   # 同时根据plot_x来生成plot_y(y=(x-2.5)²+3)
    
    plt.plot(plot_x,plot_y)
    plt.show()
    
    ###定义一个求二次函数导数的函数dJ
    def dJ(x):
        return 2*(x-2.5)
    
    ###定义一个求函数值的函数J
    def J(x):
        try:
            return (x-2.5)**2+3
        except:
            return float('inf')
    
    x=0.0							#随机选取一个起始点
    eta=0.1						    #eta是学习率,用来控制步长的大小
    epsilon=1e-8				    #用来判断是否到达二次函数的最小值点的条件
    history_x=[x]                   #用来记录使用梯度下降法走过的点的X坐标
    count=0
    min=0
    while True:
        gradient=dJ(x)				#梯度(导数)
        last_x=x
        x=x-eta*gradient
        history_x.append(x)
        count=count+1
        if (abs(J(last_x)-J(x)) <epsilon):		#用来判断是否逼近最低点
            min=x
            break
        
    plt.plot(plot_x,plot_y)     
    plt.plot(np.array(history_x),J(np.array(history_x)),color='r',marker='*')   #绘制x的轨迹
    plt.show()
    
    print'min_x =',(min)
    print'min_y =',(J(min))	        #打印到达最低点时y的值
    print'count =',(count)
    
    sum_eta=[]
    result=[]
    for i in range(1,10,1):
        x=0.0							#随机选取一个起始点
        eta=i*0.1
        sum_eta.append(eta)
        epsilon=1e-8				    #用来判断是否到达二次函数的最小值点的条件
        num=0
        min=0
        while True:
            gradient=dJ(x)				#梯度(导数)
            last_x=x
            x=x-eta*gradient
            num=num+1
            if (abs(J(last_x)-J(x)) <epsilon):		#用来判断是否逼近最低点
                min=x
                break
        
        result.append(num)#记录学习率从0.1到0.9(步长0.1)时,达到最小值时所迭代的次数
    
    plt.scatter(sum_eta,result,c='r')
    plt.plot(sum_eta,result,c='r')
    plt.title("relation")
    plt.xlabel("eta")
    plt.ylabel("count")
    plt.show
    print(result)
    

      

    运行结果:

    结果分析:
    函数y=(x-2.5)²+3从学习率和迭代次数的关系图上我们可以知道当学习率较低时迭代次数较多,随着学习率的增大,迭代次数开始逐渐减少,当学习率为0.5时迭代次数最少,之后随着学习率的增加,迭代次数开始增加,当学习率为0.9时迭代次数和0.1时相等。关于0.5成对称分布。


    原文:https://blog.csdn.net/Ferryman23333/article/details/91050219

  • 相关阅读:
    Menu-actionBarMenu字体颜色修改
    actionBarTab-actionBarTab自定义 布局没法改变其中字体相对中间的位置
    Funui-overlay 如何添加theme 的 overlay
    java进阶——反射(Reflect)
    java工具类学习整理——集合
    Java实例练习——java实现自动生成长度为10以内的随机字符串(可用于生成随机密码)
    打通Java与MySQL的桥梁——jdbc
    SQL数据库操作整理
    PhpStorm 4.0 & 5.0 部署本地Web应用
    PhpStorm 4.0 & 5.0 部署本地Web应用
  • 原文地址:https://www.cnblogs.com/qbdj/p/10998909.html
Copyright © 2011-2022 走看看