zoukankan      html  css  js  c++  java
  • PyTorch——(6)2D函数优化实例

    在这里插入图片描述最小值点有4个
    在这里插入图片描述

    import numpy as np
    from mpl_toolkits.mplot3d import Axes3D
    from matplotlib import pyplot as plt
    import torch
    
    
    
    def himmelblau(x):
        return (x[0] ** 2 + x[1] - 11) ** 2 + (x[0] + x[1] ** 2 - 7) ** 2
    
    
    # x = np.arange(-6, 6, 0.1)
    # y = np.arange(-6, 6, 0.1)
    # print('x,y range:', x.shape, y.shape)
    # X, Y = np.meshgrid(x, y)
    # print('X,Y maps:', X.shape, Y.shape)
    # Z = himmelblau([X, Y])
    #
    # fig = plt.figure('himmelblau')
    # ax = fig.gca(projection='3d')
    # ax.plot_surface(X, Y, Z)
    # ax.view_init(60, -30)
    # ax.set_xlabel('x')
    # ax.set_ylabel('y')
    # plt.show()
    
    
    # [1., 0.], [-4, 0.], [4, 0.]
    x = torch.tensor([-4., 0.], requires_grad=True)
    # 设置优化器 (待优化变量,lr=学习率)
    optimizer = torch.optim.Adam([x], lr=1e-3)
    for step in range(20000):
    
        pred = himmelblau(x)
        # 把所有参数梯度缓存器置零
        optimizer.zero_grad()
        # 计算反向传播梯度
        pred.backward()
        optimizer.step()
    
        if step % 2000 == 0:
            print ('step {}: x = {}, f(x) = {}'
                   .format(step, x.tolist(), pred.item()))
    
  • 相关阅读:
    实验10 使用PBR实现策略路由
    实验9 使用route-policy控制路由
    实验8 filter-policy过滤路由
    实验7 ISIS多区域配置
    实验6 IS-IS基本配置
    MySQL复制表
    mysql数据备份
    mysql 创建用户,授权
    数据库
    mysql 修改文件记录:
  • 原文地址:https://www.cnblogs.com/long5683/p/14702386.html
Copyright © 2011-2022 走看看