  • Adam (1)

    • Algorithm characteristics
      ①. A convex combination of gradients controls the iteration direction; ②. a convex combination of squared gradients controls the step size; ③. each optimization variable is searched adaptively.
    • Algorithm derivation
      Part Ⅰ Algorithm details
      Let the objective function be denoted by $J$; its gradient is then written as
      \begin{equation}
      g = \nabla J
      \label{eq_1}
      \end{equation}
      Following momentum gradient descent, a convex combination of gradients (the first moment) controls the iteration direction,
      \begin{equation}
      m_{k} = \beta_1 m_{k-1} + (1 - \beta_1) g_{k}
      \label{eq_2}
      \end{equation}
      where $\beta_1$ is the convex-combination coefficient, i.e. the exponential decay rate.
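      As a quick illustration (a minimal sketch; the gradient values below are made up for demonstration), the exponential moving average damps components whose sign keeps flipping while letting consistent components accumulate:

import numpy

# First-moment update on a made-up gradient sequence
beta1 = 0.9
g_history = [numpy.array([1.0, -1.0]),
             numpy.array([1.0, 1.0]),
             numpy.array([1.0, -1.0])]
m = numpy.zeros(2)
for g in g_history:
    m = beta1 * m + (1 - beta1) * g
    print(m)  # first coordinate grows steadily; second oscillates near 0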
      Following RMSProp, a convex combination of squared gradients (the second raw moment) controls the step size,
      \begin{equation}
      v_{k} = \beta_2 v_{k-1} + (1 - \beta_2) g_{k} \odot g_{k}
      \label{eq_3}
      \end{equation}
      where $\beta_2$ is the convex-combination coefficient, i.e. the exponential decay rate, and $\odot$ denotes the elementwise (Hadamard) product.
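      Because the square is taken elementwise, $v_k$ tracks a per-coordinate magnitude scale rather than a direction (again a minimal sketch with made-up values):

import numpy

# Second-raw-moment update: elementwise square of the gradient
beta2 = 0.999
g = numpy.array([10.0, 0.1])         # one large, one small gradient component
v = numpy.zeros(2)
v = beta2 * v + (1 - beta2) * g * g  # v becomes [1e-1, 1e-5]
print(v)                             # large-gradient coordinates accumulate large v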
      Since both the first moment and the second raw moment are initialized to 0, they are corrected as follows to reduce the influence of the decay rates on the early iterations,
      \begin{gather}
      \hat{m}_{k} = \frac{m_{k}}{1 - \beta_1^{k}}\label{eq_4} \\
      \hat{v}_{k} = \frac{v_{k}}{1 - \beta_2^{k}}\label{eq_5}
      \end{gather}
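      To see why this correction matters, take $k = 1$ with $\beta_1 = 0.9$: the raw first moment is $m_1 = 0.9 \cdot 0 + 0.1\, g_1 = 0.1\, g_1$, a severe underestimate, whereas the corrected value recovers the gradient exactly,
      \begin{equation*}
      \hat{m}_{1} = \frac{m_{1}}{1 - \beta_1^{1}} = \frac{0.1\, g_1}{0.1} = g_1
      \end{equation*}
      The same cancellation holds for $\hat{v}_{1}$, and both correction factors approach 1 as $k$ grows.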
      Without loss of generality, write the $k$-th iteration as
      \begin{equation}
      x_{k+1} = x_k + \alpha_k d_k
      \label{eq_6}
      \end{equation}
      where $\alpha_k$ and $d_k$ are the step size and search direction at step $k$, with
      \begin{gather}
      \alpha_k = \frac{\alpha}{\sqrt{\hat{v}_k} + \epsilon}\label{eq_7} \\
      d_k = -\hat{m}_k\label{eq_8}
      \end{gather}
      where $\alpha$ is the step-size parameter and $\epsilon$ is a sufficiently small positive number that keeps the denominator of the step size away from 0.
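      Note that $\hat{v}_k$, and hence $\alpha_k$, is a vector: each optimization variable receives its own effective step size, which is exactly feature ③ above. A minimal sketch with made-up values:

import numpy

# Per-coordinate step size: steep coordinates take smaller steps
alpha, epsilon = 0.001, 1.e-8
v_hat = numpy.array([100.0, 0.01])  # corrected second moments per coordinate
alpha_k = alpha / (numpy.sqrt(v_hat) + epsilon)
print(alpha_k)                      # approximately [1e-4, 1e-2]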
      Part Ⅱ Algorithm flow
      Initialize the step-size parameter $\alpha$, a sufficiently small positive number $\epsilon$, and the exponential decay rates $\beta_1$ and $\beta_2$.
      Initialize the convergence criterion $\zeta$ and the starting point $x_1$.
      Compute the initial gradient $g_1 = \nabla J(x_1)$, set the first moment $m_0 = 0$, the second moment $v_0 = 0$, and $k = 1$, then repeat the following steps (a compact sketch follows the list):
        step 1: if $\|g_k\| < \zeta$, the iteration has converged; stop
        step 2: update the first moment $m_k = \beta_1 m_{k-1} + (1 - \beta_1) g_{k}$
        step 3: update the second moment $v_k = \beta_2 v_{k-1} + (1 - \beta_2) g_{k} \odot g_{k}$
        step 4: compute the corrected first moment $\displaystyle \hat{m}_{k} = \frac{m_{k}}{1 - \beta_1^{k}}$
        step 5: compute the corrected second moment $\displaystyle \hat{v}_{k} = \frac{v_{k}}{1 - \beta_2^{k}}$
        step 6: compute the step size $\displaystyle \alpha_k = \frac{\alpha}{\sqrt{\hat{v}_k} + \epsilon}$
        step 7: compute the search direction $d_k = -\hat{m}_k$
        step 8: update the iterate $x_{k+1} = x_k + \alpha_k d_k$
        step 9: update the gradient $g_{k+1} = \nabla J(x_{k+1})$
        step 10: set $k = k + 1$ and go to step 1
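      A minimal, self-contained NumPy sketch of this loop (a sketch only; the full class-based implementation follows in the next section):

import numpy

def adam(grad, x, alpha=0.001, beta1=0.9, beta2=0.999,
         epsilon=1.e-8, zeta=1.e-6, maxIter=100000):
    # Steps 1-10 of the flow above, with m_0 = v_0 = 0
    m, v = numpy.zeros(x.shape), numpy.zeros(x.shape)
    g = grad(x)
    for k in range(1, maxIter + 1):
        if numpy.linalg.norm(g, ord=numpy.inf) < zeta:   # step 1
            break
        m = beta1 * m + (1 - beta1) * g                  # step 2
        v = beta2 * v + (1 - beta2) * g * g              # step 3
        m_hat = m / (1 - beta1 ** k)                     # step 4
        v_hat = v / (1 - beta2 ** k)                     # step 5
        alpha_k = alpha / (numpy.sqrt(v_hat) + epsilon)  # step 6
        d = -m_hat                                       # step 7
        x = x + alpha_k * d                              # step 8
        g = grad(x)                                      # step 9
    return x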
    • Code implementation
      The algorithm is now applied to the following unconstrained convex optimization problem,
      \begin{equation*}
      \min\quad 5x_1^2 + 2x_2^2 + 3x_1 - 10x_2 + 4
      \end{equation*}
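      Since the objective is a separable quadratic, the minimizer can be checked by hand: setting the gradient to zero gives $10x_1 + 3 = 0$ and $4x_2 - 10 = 0$, i.e.
      \begin{equation*}
      x^* = (-0.3,\ 2.5), \qquad J(x^*) = 0.45 + 12.5 - 0.9 - 25 + 4 = -8.95
      \end{equation*}
      which the numerical solution below should reproduce.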
      Adam is implemented as follows,
# Implementation of Adam

import numpy
from matplotlib import pyplot as plt


# 0th-order information of the objective (function value)
def func(X):
    funcVal = 5 * X[0, 0] ** 2 + 2 * X[1, 0] ** 2 + 3 * X[0, 0] - 10 * X[1, 0] + 4
    return funcVal


# 1st-order information of the objective (gradient)
def grad(X):
    grad_x1 = 10 * X[0, 0] + 3
    grad_x2 = 4 * X[1, 0] - 10
    gradVec = numpy.array([[grad_x1], [grad_x2]])
    return gradVec


# Generate a random starting point
def seed(n=2):
    seedVec = numpy.random.uniform(-100, 100, (n, 1))
    return seedVec


class Adam(object):

    def __init__(self, _func, _grad, _seed):
        '''
        _func: objective function to optimize
        _grad: gradient of the objective function
        _seed: starting point of the iteration
        '''
        self.__func = _func
        self.__grad = _grad
        self.__seed = _seed

        self.__xPath = list()
        self.__JPath = list()


    def get_solu(self, alpha=0.001, beta1=0.9, beta2=0.999, epsilon=1.e-8, zeta=1.e-6, maxIter=3000000):
        '''
        Compute the numerical solution.
        alpha: step-size parameter
        beta1: exponential decay rate of the first moment
        beta2: exponential decay rate of the second moment
        epsilon: sufficiently small positive number
        zeta: convergence criterion
        maxIter: maximum number of iterations
        '''
        self.__init_path()

        x = self.__init_x()
        JVal = self.__calc_JVal(x)
        self.__add_path(x, JVal)
        grad = self.__calc_grad(x)
        m, v = numpy.zeros(x.shape), numpy.zeros(x.shape)
        for k in range(1, maxIter + 1):
            # print("k: {:3d},   JVal: {}".format(k, JVal))
            if self.__converged1(grad, zeta):
                self.__print_MSG(x, JVal, k)
                return x, JVal, True

            m = beta1 * m + (1 - beta1) * grad            # first-moment update
            v = beta2 * v + (1 - beta2) * grad * grad     # second-moment update
            m_ = m / (1 - beta1 ** k)                     # corrected first moment
            v_ = v / (1 - beta2 ** k)                     # corrected second moment

            alpha_ = alpha / (numpy.sqrt(v_) + epsilon)   # per-coordinate step size
            d = -m_                                       # search direction
            xNew = x + alpha_ * d
            JNew = self.__calc_JVal(xNew)
            self.__add_path(xNew, JNew)
            if self.__converged2(xNew - x, JNew - JVal, zeta ** 2):
                self.__print_MSG(xNew, JNew, k + 1)
                return xNew, JNew, True

            gNew = self.__calc_grad(xNew)
            x, JVal, grad = xNew, JNew, gNew
        else:
            # loop exhausted without returning: one final convergence check
            if self.__converged1(grad, zeta):
                self.__print_MSG(x, JVal, maxIter)
                return x, JVal, True

        print("Adam not converged after {} steps!".format(maxIter))
        return x, JVal, False


    def get_path(self):
        return self.__xPath, self.__JPath


    def __converged1(self, grad, epsilon):
        if numpy.linalg.norm(grad, ord=numpy.inf) < epsilon:
            return True
        return False


    def __converged2(self, xDelta, JDelta, epsilon):
        val1 = numpy.linalg.norm(xDelta, ord=numpy.inf)
        val2 = numpy.abs(JDelta)
        if val1 < epsilon or val2 < epsilon:
            return True
        return False


    def __print_MSG(self, x, JVal, iterCnt):
        print("Iteration steps: {}".format(iterCnt))
        print("Solution:\n{}".format(x.flatten()))
        print("JVal: {}".format(JVal))


    def __calc_JVal(self, x):
        return self.__func(x)


    def __calc_grad(self, x):
        return self.__grad(x)


    def __init_x(self):
        return self.__seed


    def __init_path(self):
        self.__xPath.clear()
        self.__JPath.clear()


    def __add_path(self, x, JVal):
        self.__xPath.append(x)
        self.__JPath.append(JVal)


class AdamPlot(object):

    @staticmethod
    def plot_fig(adamObj):
        x, JVal, tab = adamObj.get_solu(0.1)
        xPath, JPath = adamObj.get_path()

        fig = plt.figure(figsize=(10, 4))
        ax1 = plt.subplot(1, 2, 1)
        ax2 = plt.subplot(1, 2, 2)

        # left panel: objective value versus iteration count
        ax1.plot(numpy.arange(len(JPath)), JPath, "k.", markersize=1)
        ax1.plot(0, JPath[0], "go", label="starting point")
        ax1.plot(len(JPath)-1, JPath[-1], "r*", label="solution")

        ax1.legend()
        ax1.set(xlabel="$iterCnt$", ylabel="$JVal$")

        # right panel: contours of the objective with the iterate path
        x1 = numpy.linspace(-100, 100, 300)
        x2 = numpy.linspace(-100, 100, 300)
        x1, x2 = numpy.meshgrid(x1, x2)
        f = numpy.zeros(x1.shape)
        for i in range(x1.shape[0]):
            for j in range(x1.shape[1]):
                f[i, j] = func(numpy.array([[x1[i, j]], [x2[i, j]]]))
        ax2.contour(x1, x2, f, levels=36)
        x1Path = list(item[0] for item in xPath)
        x2Path = list(item[1] for item in xPath)
        ax2.plot(x1Path, x2Path, "k--", lw=2)
        ax2.plot(x1Path[0], x2Path[0], "go", label="starting point")
        ax2.plot(x1Path[-1], x2Path[-1], "r*", label="solution")
        ax2.set(xlabel="$x_1$", ylabel="$x_2$")
        ax2.legend()

        fig.tight_layout()
        # plt.show()
        fig.savefig("plot_fig.png")


if __name__ == "__main__":
    adamObj = Adam(func, grad, seed())

    AdamPlot.plot_fig(adamObj)
    • Results
      (Figure omitted in the source. As produced by plot_fig above, the left panel plots $JVal$ against the iteration count and the right panel shows the contours of the objective together with the iterate path from the starting point to the solution; a converged run terminates near the analytic minimizer $x^* = (-0.3,\ 2.5)$ with $J^* = -8.95$.)
    • Usage suggestions
      ①. The local sum of second moments reflects local curvature information to some extent, so it is reasonable to use it to approximate and replace the Hessian matrix;
      ②. The reference recommends the default parameters $\alpha=0.001$, $\beta_1=0.9$, $\beta_2=0.999$, $\epsilon=10^{-8}$; in practice, tune the step-size parameter $\alpha$ first, as sketched below.
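      For instance, a larger step size can be tried through the existing get_solu interface (a usage sketch; plot_fig above already passes $\alpha = 0.1$, and convergence behavior depends on the random starting point):

# Usage sketch: adjust the step-size parameter alpha first
adamObj = Adam(func, grad, seed())
x, JVal, converged = adamObj.get_solu(alpha=0.1)  # 100x the recommended default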
    • References
      Kingma, D. P. and Ba, J. Adam: A Method for Stochastic Optimization. arXiv preprint arXiv:1412.6980, 2014.
  • Original article: https://www.cnblogs.com/xxhbdk/p/15063793.html