zoukankan      html  css  js  c++  java
  • 梯度下降算法

    给定一个二次函数y=x*(x-3),问当x等于几的时候,y取得-0.5
    这个问题其实就是一个解一元非齐次方程,可以用梯度下降算法。

    梯度下降算法第一步就是要定义好一个目标:
    g(x)=1/2(target-f(x))^2
    它的导数为g'(x)=(target-f(x))
    f'(x)
    每次调整x时:x=x-alpha*g'(x),意思是梯度越大,需要调整的x越大;梯度越小,需要调整的x越小。

    定义目标函数时,往往采用二范数的形式,因为这种形式求导简便,并且保证导数平滑连续。

    梯度下降算法上可以利用的技巧非常多,比如带动量、学习率递减等。

    def f(x):  # 原函数
        return x * (x - 3)
    
    
    def ff(x):  # 原函数的导数
        return 2 * x - 3
    
    
    x = 0  # 初始值
    eps = 0.001
    eta = 0.2  # 学习率
    target = -0.5  # 学习目标
    cnt = 0
    while 1:
        print('第', cnt, '轮', x)
        cnt += 1
        xx = x + eta * ff(x) * (target - f(x))  # x2=x1+eta*f'(x1)*dy
        if abs(x - xx) < eps:
            break
        else:
            x = xx
    print('最终结果', x, f(x))
    
    

    输出结果

    第 0 轮 0
    第 1 轮 0.30000000000000004
    第 2 轮 0.15119999999999997
    第 3 轮 0.18856793210880002
    第 4 轮 0.17275419569602604
    第 5 轮 0.17890275481423512
    第 6 轮 0.17641799877866612
    最终结果 0.17641799877866612 -0.4981306860429289
    

    如果学习率不是递减的,而eps又很小,最后会永远振荡下去。所以实践中,学习率往往是递减的。

    带动量的梯度下降算法

    只需要改动一句话,添加一个mementum变量

    def f(x):  # 原函数
        return x * (x - 3)
    
    
    def ff(x):  # 原函数的导数
        return 2 * x - 3
    
    
    x = 0  # 初始值
    eps = 0.001
    eta = 0.2  # 学习率
    target = -0.5  # 学习目标
    momentum = 0.2
    cnt = 0
    while 1:
        print('第', cnt, '轮', x)
        cnt += 1
        xx = x + eta * ff(x) * (target - f(x))  # x2=x1+eta*f'(x1)*dy
        if abs(x - xx) < eps:
            break
        else:
            x = xx * (1 - momentum) + x * momentum
    print('最终结果', x, f(x))
    

    输出结果

    第 0 轮 0
    第 1 轮 0.24000000000000005
    第 2 轮 0.17452032
    第 3 轮 0.17744544458549058
    最终结果 0.17744544458549058 -0.5008494479523293
    
  • 相关阅读:
    Android四大基本组件介绍与生命周期
    TRIZ系列-创新原理-23-反馈原理
    hibernate之6.one2many单向
    软件评測师真题考试分析-5
    WAS集群系列(3):集群搭建:步骤1:准备文件
    Android Developer:合并清单文件
    移动均值滤波与中值滤波
    使用React的static方法实现同构以及同构的常见问题
    mysql合并同一列的值
    iOS开发
  • 原文地址:https://www.cnblogs.com/weiyinfu/p/7705248.html
Copyright © 2011-2022 走看看