zoukankan      html  css  js  c++  java
  • 两个变量(可支持多自变量)的简单梯度下降

    import numpy as np
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d.axes3d import Axes3D
    
    def targetFunc(x,y):
        """Objective f(x, y) = 2x^2 + 6y^2 + 6xy + x + 4y + 8.

        Uses only arithmetic operators, so it works on scalars and
        element-wise on NumPy arrays (e.g. a meshgrid).
        """
        quadratic = 2*(x**2) + 6*y**2 + 6*x*y
        return quadratic + x + 4*y + 8
    
    def derivativeFunc(x,y):
        """Gradient of targetFunc, returned as the tuple (df/dx, df/dy).

        df/dx = 4x + 6y + 1
        df/dy = 6x + 12y + 4
        """
        return (4*x+6*y+1, 12*y+6*x+4)
    
    # Trajectory recorded by linerFunc as (*point, f(point)) tuples; the
    # __main__ block scatters these. Module-level, so it accumulates across calls.
    pointList = []
    
    def linerFunc(initPoint:tuple,targetFunc,derivativeFunc,step = 0.01,limitValue = 0.00000001,timeout=1000000,ax:"Axes3D" = None):
        """Run fixed-step gradient descent on ``targetFunc`` from ``initPoint``.

        Every visited point, including the returned final point, is appended
        to the module-level ``pointList`` as ``(*point, value)``.

        Args:
            initPoint: starting coordinates, one entry per argument of
                ``targetFunc``.
            targetFunc: objective, called as ``targetFunc(*point)``.
            derivativeFunc: gradient, called as ``derivativeFunc(*point)``;
                must return one partial derivative per coordinate.
            step: learning rate multiplied with the gradient on each step.
            limitValue: tolerance. Iteration stops once no gradient component
                changes by more than this between two successive points.
                After the first step, a coordinate whose update
                ``|gradient * step|`` is below it is held fixed.
            timeout: maximum number of steps.
            ax: unused; accepted so existing callers keep working.

        Returns:
            ``(value, point)``: ``targetFunc`` at the last point, and that
            point as a NumPy array.

        Raises:
            FloatingPointError: if the iterate or its gradient becomes
                non-finite, i.e. the descent diverged (usually ``step`` is
                too large).
        """
        count = 1
        current = np.array(initPoint)
        value, grad = targetFunc(*current), np.array(derivativeFunc(*current))
        pointList.append((*current, value))

        # The first step moves every coordinate; later steps may freeze some.
        newPoint = current - grad * step
        while True:
            newValue, newGrad = targetFunc(*newPoint), np.array(derivativeFunc(*newPoint))
            # Without this check an overflow turns into NaN, the convergence
            # test below (NaN > limit is False) ends the loop and NaN would be
            # reported as a stationary point.
            if not (np.isfinite(newPoint).all() and np.isfinite(newGrad).all()):
                raise FloatingPointError(
                    "gradient descent diverged after {0} steps; try a smaller step".format(count)
                )
            # Stop once no gradient component changes by more than limitValue
            # between successive points, or after `timeout` steps.
            if not (np.abs(grad - newGrad) > limitValue).any() or count >= timeout:
                break
            # Accept the new point, reusing its already computed value and
            # gradient instead of evaluating the functions there a second time.
            current, value, grad = newPoint, newValue, newGrad
            pointList.append((*current, value))
            delta = grad * step
            # Hold coordinates whose update would be smaller than limitValue.
            newPoint = np.where(np.abs(delta) >= limitValue, current - delta, current)
            count += 1

        # Record the returned point too, so the plotted path ends where the
        # descent actually stopped.
        pointList.append((*newPoint, newValue))
        print("最终运算次数为 : {0}".format(count))
        return newValue, newPoint
    
    
    if __name__=="__main__":
        # Sample the surface on a 100x100 grid over [-2, 23] in both axes.
        grid = np.linspace(-2,23,100)
        x,y = np.meshgrid(grid,grid)
        fxy=targetFunc(x,y)

        fig = plt.figure()
        # Axes3D(fig) no longer attaches itself to the figure on newer
        # Matplotlib (blank window); add_subplot registers the 3-D axes.
        ax = fig.add_subplot(projection='3d')

        ax.plot_surface(x, y, fxy)
        limitValue,limitPoint = linerFunc((20,20),targetFunc,derivativeFunc,ax=ax)
        # pointList holds (x, y, f) tuples; transpose to x, y, z columns.
        ax.scatter(*(np.array(pointList).T),c='r',s=20,label='gradient descent path')
        print("该函数在({0},{1})处有驻点,值为{2}".format(limitPoint[0],limitPoint[1],limitValue))
        # The scatter label gives legend() something to show; unlabelled
        # artists would only produce a "No artists with labels" warning.
        ax.legend()
        plt.show()

  • 相关阅读:
    Codeforces Round #281 (Div. 2) A. Vasya and Football(模拟)
    自动生成代码工具
    导入导出维护计划
    收集错误日志方法
    C#常用控件和属性
    人民币转换
    身份证验证
    设置下拉列表项的默认值
    清除维护任务
    清除MSSQL历史记录
  • 原文地址:https://www.cnblogs.com/dofstar/p/11462941.html
Copyright © 2011-2022 走看看