zoukankan      html  css  js  c++  java
  • 两个变量(可支持多自变量)的简单梯度下降

    import numpy as np
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d.axes3d import Axes3D
    
    #   公式 f(x,y) = 2x^2+6y^2+6xy+x+4y+8
    def targetFunc(x,y):
        return 2*(x**2)+6*y**2+6*x*y+x+4*y+8
    
    #   偏导
    #   f'x(x,y)=4x+6y+1
    #   f'y(x,y)=12y+6x+4
    def derivativeFunc(x,y):
        rx = 4*x+6*y+1
        ry = 12*y+6*x+4
        return (rx,ry)
    
    pointList = []
    
    def linerFunc(initPoint:tuple,targetFunc,derivativeFunc,step = 0.01,limitValue = 0.00000001,timeout=1000000,ax:Axes3D = None):
        count = 1
        initPoint = np.array(initPoint)
        ro,do = targetFunc(*initPoint),np.array(derivativeFunc(*initPoint))
        pointList.append((*initPoint, ro))
    
        newPoint = initPoint-do*step
        rn,dn = targetFunc(*newPoint),np.array(derivativeFunc(*newPoint))
    
        diff = np.abs(np.array(do-dn))
    
        while (diff > limitValue).any() and count < timeout:
            # print(initPoint)
            initPoint = newPoint
            ro, do = targetFunc(*initPoint), np.array(derivativeFunc(*initPoint))
    
            newPoint = np.where(np.abs(do*step) >= limitValue,initPoint-do*step,initPoint)
            rn, dn = targetFunc(*newPoint), np.array(derivativeFunc(*newPoint))
            diff = np.abs(np.array(do - dn))
    
            pointList.append((*initPoint, ro))
            count+=1
            pass
        print("最终运算次数为 : {0}".format(count))
        return rn,newPoint
        pass
    
    
    if __name__=="__main__":
        x,y = np.linspace(-2,23,100),np.linspace(-2,23,100)
        x,y = np.meshgrid(x,y)
        fxy=targetFunc(x,y)
    
        fig = plt.figure()
        ax = Axes3D(fig)
    
        ax.plot_surface(x, y, fxy)
        limitValue,limitPoint = linerFunc((20,20),targetFunc,derivativeFunc,ax=ax)
        ax.scatter(*(np.array(pointList).T),c='r',s=20)
        print("该函数在({0},{1})处有驻点,值为{2}".format(limitPoint[0],limitPoint[1],limitValue))
        ax.legend()
        plt.show()
        pass

  • 相关阅读:
    对JAVA集合进行遍历删除时务必要用迭代器
    【javascript的那些事】等待加载完js后执行方法
    【微信H5】 Redirect_uri参数错误解决方法
    关于Java优质代码的那些事
    H5页面微信分享和手Q分享设置
    LVS+keepalived-DR模式
    Linux文件误删恢复
    MySQL常用语句
    sudo权限配置
    Rsync同步部署web服务端配置
  • 原文地址:https://www.cnblogs.com/dofstar/p/11462941.html
Copyright © 2011-2022 走看看