zoukankan      html  css  js  c++  java
  • 不依赖Python第三方库实现梯度下降

    认识

    梯度的本意是一个向量(矢量),表示某一函数在该点处的方向导数沿着该方向取得最大值,即函数在该点处沿着该方向(此梯度的方向)变化最快,变化率最大(为该梯度的模), 我感觉, 其实就是偏导数向量方向呗, 沿着这个向量方向可以找到局部的极值.

    from random import random
    
    def gradient_down(func, part_df_func, var_num, rate=0.1, max_iter=10000, tolerance=1e-10):
        """
        不依赖第三库实现梯度下降
        :param func: 损失(误差)函数
        :param part_df_func: 损失函数的偏导数向量
        :param var_num: 变量个数
        :param rate: 学习率(参数的每次变化的幅度)
        :param max_iter: 最大计算次数
        :param tolerance: 误差的精度
        :return: theta, y_current:  权重参数值列表, 损失函数最小值
        """
    
        theta = [random() for _ in range(var_num)]  # 随机给定参数的初始值
        y_current = func(*theta)  # 参数解包
    
        for i in range(max_iter):
            # 计算当前参数的梯度(偏导数导数向量值)
            gradient = [f(*theta) for f in part_df_func]
            # 根据梯度更新参数 theta
            for j in range(var_num):
    
                theta[j] -= gradient[j] * rate  # [0.3, 0.6, 0.7] ==> [0.3-0.3*lr, 0.6-0.6*lr, 0.7-0.7*lr]
    
                y_current,  y_predict = func(*theta), y_current
                print(f"正在进行第{i}次迭代, 误差精度为{abs(y_predict - y_current)}")
    
                if abs(y_predict - y_current) < tolerance:   # 判断是否收敛, (误差值的精度)
    
                    print(); print(f"ok, 在第{i}次迭代, 收敛到可以了哦!")
    
                    return theta, y_current
    
    
    def f(x, y):
        """原函数"""
        return (x + y - 3) ** 2 + (x + 2 * y - 5) ** 2 + 2
    
    
    def df_dx(x, y):
        """对x求偏导数"""
        return 2 * (x + y - 3) + 2 * (x + 2 * y - 5)
    
    
    def df_dy(x, y):
        """对y求偏导数, 注意求导的链式法则哦"""
        return 2 * (x + y - 3) + 2 * (x + 2 * y - 5) * 2
    
    
    def main():
        """主函数"""
        print("用梯度下降的方式求解函数的最小值哦:")
        theta, f_theta = gradient_down(f, [df_dx, df_dy], var_num=2)
    
        theta, f_theta = [round(i, 3) for i in theta], round(f_theta, 3)  # 保留3位小数
    
        print("该函数最优解是: 当theta取:{}时,f(theta)取到最小值:{}".format(theta, f_theta))
    
    
    if __name__ == '__main__':
        main()
    
    
    ...
    ...
    正在进行第248次迭代, 误差精度为1.6640999689343516e-10
    正在进行第249次迭代, 误差精度为1.5684031851037616e-10
    正在进行第250次迭代, 误差精度为1.478208666583214e-10
    正在进行第251次迭代, 误差精度为1.3931966691416164e-10
    正在进行第252次迭代, 误差精度为1.3130829756846651e-10
    正在进行第253次迭代, 误差精度为1.2375700464417605e-10
    正在进行第254次迭代, 误差精度为1.166395868779091e-10
    正在进行第255次迭代, 误差精度为1.0993206345233375e-10
    正在进行第256次迭代, 误差精度为1.0361000946090826e-10
    正在进行第257次迭代, 误差精度为9.765166453234997e-11
    
    ok, 在第257次迭代, 收敛到可以了哦!
    该函数最优解是: 当theta取:[1.0, 2.0]时,f(theta)取到最小值:2.0
    [Finished in 0.0s]
    
  • 相关阅读:
    Zookeeper实战
    Zookeeper的结构和命令
    Zookeeper中的选举机制
    du 命令,对文件和目录磁盘使用的空间的查看
    rm命令
    linux之cp/scp命令+scp命令详解
    android 为应用程序创建桌面快捷方式技巧分享
    对自己的文件使用keystore签名
    Android 打包签名 从生成keystore到完成签名 -- 转
    Android App启动错误的问题(connection to the server was unsuccessful)
  • 原文地址:https://www.cnblogs.com/chenjieyouge/p/11667959.html
Copyright © 2011-2022 走看看