zoukankan      html  css  js  c++  java
  • 基于pytorch对函数进行极值求解

     1 import numpy as np
     2 from mpl_toolkits.mplot3d import Axes3D
     3 import matplotlib.pyplot as plt
     4 from matplotlib.colors import LinearSegmentedColormap
     5 
     6 # 待求极值的函数
     7 def himmelblau(t):# t[0]-->X; t[1]-->Y.
     8     return (t[0] ** 2 + t[1] - 11) ** 2 + (t[0] + t[1] ** 2 - 7) ** 2
     9 
    10 x = np.arange(-6, 6, 0.1)
    11 y = np.arange(-6, 6, 0.1)
    12 X, Y = np.meshgrid(x, y)
    13 Z = himmelblau([X, Y])
    14 fig = plt.figure()
    15 ax = fig.add_subplot(projection='3d')# ax = fig.gca(projection='3d') # ---> was deprecated in Matplotlib 3.4
    16 ax.plot_surface(X, Y, Z)
    17 ax.view_init(60, -30)
    18 ax.set_xlabel('x')
    19 ax.set_ylabel('y')
    20 fig.show()
    21 plt.show()
    22 
    23 # function test
    24 def jeshy(t):
    25     return t*3+10
    26 
    27 import torch
    28 x = torch.tensor([0., 0.], requires_grad=True)
    29 optimizer = torch.optim.Adam([x, ])# optim.Adam([var1, var2], lr=0.0001)# 优化器设置 ,并传入模型参数和相应的学习率
    30 for step in range(20001):
    31     f = himmelblau(x)# 前向传播
    32     if step > 0:
    33         optimizer.zero_grad()# 反向传播与优化# 清空上一步的残余更新参数值
    34         f.backward(retain_graph=True)# 反向传播与优化# 反向传播
    35         optimizer.step()# 反向传播与优化# 将参数更新值施加到函数f的parameters上
    36     # f = jeshy(f)
    37     if step % 1000 == 0:# 每迭代一定步骤,打印结果值
    38         print('step:{}, x = {}, value = {}'.format(step, x.tolist(), f))

     输出:

    step:0, x = [0.0, 0.0], value = 170.0
    step:1000, x = [1.270142912864685, 1.1183991432189941], value = 88.53223419189453
    step:2000, x = [2.332378387451172, 1.9535712003707886], value = 13.766233444213867
    step:3000, x = [2.8519949913024902, 2.114161968231201], value = 0.6711398363113403
    step:4000, x = [2.981964111328125, 2.0271568298339844], value = 0.014927156269550323
    step:5000, x = [2.9991261959075928, 2.0014777183532715], value = 3.9870232285466045e-05
    step:6000, x = [2.999983549118042, 2.0000221729278564], value = 1.1074007488787174e-08
    step:7000, x = [2.9999899864196777, 2.000013589859009], value = 4.150251697865315e-09
    step:8000, x = [2.9999938011169434, 2.0000083446502686], value = 1.5572823031106964e-09
    step:9000, x = [2.9999964237213135, 2.000005006790161], value = 5.256879376247525e-10
    step:10000, x = [2.999997854232788, 2.000002861022949], value = 1.8189894035458565e-10
    step:11000, x = [2.9999988079071045, 2.0000014305114746], value = 5.547917680814862e-11
    step:12000, x = [2.9999992847442627, 2.0000009536743164], value = 1.6370904631912708e-11
    step:13000, x = [2.999999523162842, 2.000000476837158], value = 5.6843418860808015e-12
    step:14000, x = [2.999999761581421, 2.000000238418579], value = 1.8189894035458565e-12
    step:15000, x = [3.0, 2.0], value = 0.0
    step:16000, x = [3.0, 2.0], value = 0.0
    step:17000, x = [3.0, 2.0], value = 0.0
    step:18000, x = [3.0, 2.0], value = 0.0
    step:19000, x = [3.0, 2.0], value = 0.0
    step:20000, x = [3.0, 2.0], value = 0.0

    个人学习记录
  • 相关阅读:
    我工作三年了,该懂并发了!
    代理,一文入魂
    非典型算法题,用程序和电脑玩一个游戏
    详解command设计模式,解耦操作和回滚
    matplotlib画图教程,设置坐标轴标签和间距
    详解工程师不可不会的LRU缓存淘汰算法
    详解深度学习感知机原理
    详解gitignore的使用方法,让你尽情使用git add .
    算法题 | 你追我,如果你追到我……那就算你赢了
    险些翻车,差一点没做出来的基础算法题
  • 原文地址:https://www.cnblogs.com/jeshy/p/15464766.html
Copyright © 2011-2022 走看看