zoukankan      html  css  js  c++  java
  • 二元函数的梯度下降法求解

    #第一种算法:两个方向不解耦同时进行梯度迭代求解
    import numpy as np
    class g:

    def test(self,x):
    e = 2.71828182845904590
    return x[0]**3+e**x[0]+x[1]**4+x[0]+x[1]-2


    def gradient_descent_step1(self,x):
    self.alpha=0.01
    return [x[0]+self.alpha,x[1]],[x[0]-self.alpha,x[1]],[x[0],x[1]+self.alpha],[x[0],x[1]-self.alpha]

    def outpower(self,x1,x2,x3,x4):
    self.y1 = self.test(x1)
    self.y2 = self.test(x2)
    self.y3 = self.test(x3)
    self.y4 = self.test(x4)
    return self.y1,self.y2,self.y3,self.y4

    def gradient_descent_step2(self,x,y1,y2,y3,y4):

    gradient_x0 = (y1-y2)/(2*self.alpha)
    gradient_x1 = (y3-y4)/(2*self.alpha)

    self.step= 0.01
    x[0] = x[0]-np.sign(gradient_x0)*self.step
    x[1] = x[1]-np.sign(gradient_x1)*self.step

    return [x[0],x[1]]

    x=[0.0,0.0]
    g=g()

    for i in range(40):
    x1,x2,x3,x4=g.gradient_descent_step1(x)
    y1, y2, y3, y4=g.outpower(x1,x2,x3,x4)
    x = g.gradient_descent_step2(x,y1,y2,y3,y4)
    print ('x={:6f},{:6f},problem(x)={:6f}'.format(x[0],x[1],g.test(x)))

    #第二种方法和思路:单轴迭代算法循环算法(解耦算XY法)

    #实时输出测量值函数(x,y为位置坐标,返回的score为此位置处的指标测量值大小)
    import numpy as np
    import time

    #测试函数输出某一个位置(x,y)处的耦合光强
    def output(x,y):
    #time.sleep(0.3)
    return -x**2+y**2

    #整体的扫描算法函数
    def test(x0,y0,fanx,fany):
    #定义初始位置x0,y0
    point_x=x0
    point_y=y0
    #定义迭代扫描的步长大小step
    step=0.001
    #step=max((fanx[1]-fanx[0])/40,(fany[1]-fany[0])/40)
    point_history=[]
    score_history=[]
    #定义目前的点数
    i=0
    #定义最大的迭代次数
    number=1e5
    #定义单轴的截至条件
    error=0.001
    #判断继续进行迭代的条件:在约束范围里并且迭代次数小于规定次数
    while (point_x<=fanx[1] and point_x>=fanx[0]) and (point_y<=fany[1] and point_y>=fany[0]) and i<number:
    #固定y轴,x轴单轴的最大值扫描
    while point_x<fanx[1] and point_x>fanx[0] and i<number:
    last_point_x=point_x
    test1=output(last_point_x,point_y)
    gradient = (output(point_x+step, point_y)-test1) / step
    point_x=point_x+step*np.sign(gradient)
    #需要测量值
    i=i+1
    if abs(output(point_x,point_y)-test1) < error:
    # print(point_x)
    # point_history.append((point_x,point_y))
    # score_history.append(output(point_x,point_y))
    break
    #固定x轴,y轴单轴的最大值扫描
    while point_y<fany[1] and point_y>fany[0] and i<number:
    last_point_y = point_y
    test2 = output(point_x, point_y)
    gradient = (output(point_x, point_y+step) - test2) / step
    point_y = point_y + step * np.sign(gradient)
    #需要测量值
    i=i+1
    if abs(output(point_x,point_y)-test2) < error:
    # point_history.append((point_x, point_y))
    # score_history.append(output(point_x, point_y))
    # print(point_y)
    break
    #判断一次x,y轴迭代之后指标值的差距是否足够小,达到阈值之后默认已经到最高点或者接近最高点
    # if abs(output(point_x, point_y) - output(x0, y0)) < 0.0000001 or abs(point_x-x0) < 0.00001:
    # break
    #i=i+1
    print(point_x, point_y) #输出最终最高点的位置坐标x和y
    print(i)


    if __name__=="__main__":
    t1=time.time()
    test(9,7,[-5,10],[-3,8])
    print(time.time()-t1)


  • 相关阅读:
    GitHub Pages 绑定二级域名
    JS正则表达式(JavaScript regular expression)
    天猫魔盒远程安装APP
    'msbuild.exe' 不是内部或外部命令,也不是可运行的程序
    Jenkins自动更新与数据备份
    Jenkins插件无法更新、Jenkins插件不能下载问题解决
    安全测试工具wapiti的安装和使用(2)命令及参数解释
    安全测试工具wapiti的安装和使用(1)安装
    Jenkins远程构建和发布,基于IIS服务器(.netCore+vue)(三)
    Jmeter报错:java.net.ConnectException: Connection timed out: connect
  • 原文地址:https://www.cnblogs.com/Yanjy-OnlyOne/p/13567535.html
Copyright © 2011-2022 走看看