zoukankan      html  css  js  c++  java
  • Python 梯度下降法

    题目描述:
    自定义一个可微并且存在最小值的一元函数,用梯度下降法求其最小值。并绘制出学习率从0.1到0.9(步长0.1)时,达到最小值时所迭代的次数的关系曲线,根据该曲线给出简单的分析。

    代码:

    # -*- coding: utf-8 -*-
    """
    Created on Tue Jun  4 10:19:03 2019
    
    @author: Administrator
    """
    
    import numpy as np
    import matplotlib.pyplot as plt
    plot_x=np.linspace(-1,6,150)   #在-1到6之间等距的生成150个数
    plot_y=(plot_x-2.5)**2+3	   # 同时根据plot_x来生成plot_y(y=(x-2.5)²+3)
    
    plt.plot(plot_x,plot_y)
    plt.show()
    
    ###定义一个求二次函数导数的函数dJ
    def dJ(x):
        return 2*(x-2.5)
    
    ###定义一个求函数值的函数J
    def J(x):
        try:
            return (x-2.5)**2+3
        except:
            return float('inf')
    
    x=0.0							#随机选取一个起始点
    eta=0.1						    #eta是学习率,用来控制步长的大小
    epsilon=1e-8				    #用来判断是否到达二次函数的最小值点的条件
    history_x=[x]                   #用来记录使用梯度下降法走过的点的X坐标
    count=0
    min=0
    while True:
        gradient=dJ(x)				#梯度(导数)
        last_x=x
        x=x-eta*gradient
        history_x.append(x)
        count=count+1
        if (abs(J(last_x)-J(x)) <epsilon):		#用来判断是否逼近最低点
            min=x
            break
        
    plt.plot(plot_x,plot_y)     
    plt.plot(np.array(history_x),J(np.array(history_x)),color='r',marker='*')   #绘制x的轨迹
    plt.show()
    
    print'min_x =',(min)
    print'min_y =',(J(min))	        #打印到达最低点时y的值
    print'count =',(count)
    
    sum_eta=[]
    result=[]
    for i in range(1,10,1):
        x=0.0							#随机选取一个起始点
        eta=i*0.1
        sum_eta.append(eta)
        epsilon=1e-8				    #用来判断是否到达二次函数的最小值点的条件
        num=0
        min=0
        while True:
            gradient=dJ(x)				#梯度(导数)
            last_x=x
            x=x-eta*gradient
            num=num+1
            if (abs(J(last_x)-J(x)) <epsilon):		#用来判断是否逼近最低点
                min=x
                break
        
        result.append(num)#记录学习率从0.1到0.9(步长0.1)时,达到最小值时所迭代的次数
    
    plt.scatter(sum_eta,result,c='r')
    plt.plot(sum_eta,result,c='r')
    plt.title("relation")
    plt.xlabel("eta")
    plt.ylabel("count")
    plt.show
    print(result)
    

      

    运行结果:

    结果分析:
    函数y=(x-2.5)²+3从学习率和迭代次数的关系图上我们可以知道当学习率较低时迭代次数较多,随着学习率的增大,迭代次数开始逐渐减少,当学习率为0.5时迭代次数最少,之后随着学习率的增加,迭代次数开始增加,当学习率为0.9时迭代次数和0.1时相等。关于0.5成对称分布。


    原文:https://blog.csdn.net/Ferryman23333/article/details/91050219

  • 相关阅读:
    现在不知道为什么安装pip包总是失败,只能用清华源
    linux 下 svn配置;以及多仓库配置
    谷歌浏览器安装json格式化插件
    RESTful API的理解
    mysql5.6 rpm安装配置
    linux,apache,php,mysql常用的查看版本信息的方法
    mysql允许别人通过ip访问本机mysql数据
    直接取PHP二维数组里面的值
    mysql优化
    self this
  • 原文地址:https://www.cnblogs.com/qbdj/p/10998909.html
Copyright © 2011-2022 走看看