zoukankan      html  css  js  c++  java
  • 常见优化算法统一框架下的实现:最速下降法,partan加速的最速下降法,共轭梯度法,牛顿法,拟牛顿法,黄金分割法,二次插值法

    常见优化算法实现

    这里实现的主要算法有:

    一维搜索方法:

    黄金分割法

    二次差值法

    多维搜索算法

    最速下降法

    partan加速的最速下降法

    共轭梯度法

    牛顿法

    拟牛顿法

    使用函数表示一个用于优化的目标,包括其梯度函数和hessian矩阵函数

    import numpy as np
    import math
    
    #用于测试的一个多元函数的例子
    def f(x):
        return (x[0]-1)**2+5*(x[1]-5)**2+(x[2]-1)**2+5*(x[3]-5)**2
    
    #f(x)函数的gradient向量计算函数
    def g(x):
        return np.array([2*(x[0]-1),10*(x[1]-5),2*(x[2]-1),10*(x[3]-5)])
    
    #f(x)函数的hessian矩阵的逆矩阵计算函数
    def hi(x=None):
        h=[1/2,1/10,1/2,1/10]
        return np.diag(h)
    

    拟牛顿法

    def quasi_newton(f=f,x0=np.zeros(4),gradient=g,acc=0.001):
        k=0
        x=x0
        xp=None
        hpk=None
        gpk=None
        
        while True:
            gk=gradient(x)
            #print(gk)
            if np.sum(gk**2)<=acc:
                #print("迭代 %d 次"%(k+1))
                return x,np.round(f(x),5)
            if k==0:
                hik=np.eye(x0.shape[0])
            else:
                dx=x-xp
                dg=gk-gpk
                
                temp = (dx-np.dot(hpk,dg)).reshape((-1,1))
                hik=hpk + np.dot(temp,temp.transpose())/(np.dot(temp.transpose(),dg.reshape((-1,1))))
                #print(hik)
                
            pk=-1*np.dot(hik,gk)
            alpha,y=quadraticInterploation(lambda alpha:(f(alpha*pk+x)),0,10,0.001)
            #更新变量
            x=alpha*pk+x
            hpk=hik
            xp=x
            gpk=gk
            k+=1
    

    共轭方向法

    def conjugate_direction(f=f,x0=np.zeros(4),gradient=g,acc=0.001):
        k=0
        x=x0
        #设置初值
        gpk=x0
        ppk=x0
        
        while True:
            gk=gradient(x)
            #print(gk)
            if np.sum(gk**2)<=acc:
                #print("迭代 %d 次"%(k+1))
                return x,np.round(f(x),5)
            if k==0:
                pk=-1*gk
            else:
                betak=np.sum(gk*gk)/np.sum(gpk*gpk)
                pk=-1*gk+betak*ppk
            #lambda表达式可以使用上层函数中的变量,这样对于不同的上下文,就是不同的函数
            alpha,y=quadraticInterploation(lambda alpha:(f(alpha*pk+x)),0,10,0.001)
            x=alpha*pk+x
            ppk=pk
            gpk=gk
            k+=1
    

    最速下降法

    #最速下降法
    def steepestDescent(f=f,x0=np.zeros(4),gradient=g,acc=0.001):
        k=0
        x=x0
        while True:
            gk=gradient(x)
            pk=-1*gk
            if np.sum(gk**2)<=acc:
                #print("迭代 %d 次"%(k+1))
                return x,f(x)
            #lambda表达式可以使用上层函数中的变量,这样对于不同的上下文,就是不同的函数
            alpha,y=quadraticInterploation(lambda alpha:(f(alpha*pk+x)),0,10,0.001)
            x=alpha*pk+x
            k+=1
    

    牛顿法

    def newton(f=f,x0=np.zeros(4),gradient=g,hessian=hi,acc=0.001):
        k=0
        x=x0
        while True:
            gk=gradient(x)
            hik=hessian(x)
            pk=-1*np.dot(gk,hik)
            if np.sum(gk**2)<=acc:
                #print("迭代 %d 次"%(k+1))
                return x,f(x)
            #lambda表达式可以使用上层函数中的变量,这样对于不同的上下文,就是不同的函数
            alpha,y=quadraticInterploation(lambda alpha:(f(alpha*pk+x)),0,10,0.001)
            x=alpha*pk+x
            k+=1
    

    使用partan加速的最速下降法

    def partan(f=f,x0=np.zeros(4),gradient=g,acc=0.001,N=3):
        k=0
        x=x0
        xp1=x0
        xp2=x0
        while True:
            if k>=N and k%3==0:
                pk=x-xp2
            else:
                gk=gradient(x)
                pk=-1*gk
            if np.sum(pk**2)<=acc:
                #print("迭代 %d 次"%(k+1))
                return x,f(x)
            #lambda表达式可以使用上层函数中的变量,这样对于不同的上下文,就是不同的函数
            alpha,y=quadraticInterploation(lambda alpha:(f(alpha*pk+x)),0,10,0.001)
            xp2=xp1
            xp1=x
            x=alpha*pk+x
            k+=1
    

    一维搜索的黄金分割方法

    def goldenSegmantation(f,a,b,acc):
        x1=a+0.382*(b-a)
        x2=b-(x1-a)
        R=f(x1);G=f(x2)
        #因为浮点数的舍入误差,可能导致a,b的大小逆转
        while abs(b-a)>acc and a<=x1<x2<=b:
           #print(abs(b-a))
            if R>G:
                a=x1
                x1=x2
                R=G
                x2=b-(x1-a)
                G=f(x2)
            else:
                b=x2
                x2=x1
                G=R
                x1=a+(b-x2)
                R=f(x1)
        return (a+b)/2.0,f(((a+b)/2.0))
    

    一维搜索的二次差值方法

    def quadraticInterploation(f,a,b,acc):
        assert(a<b)
        x1=a;x2=(a+b)/2;x3=b
        f1=f(x1);f2=f(x2);f3=f(x3)
        while True:
            c1=(f3-f1)/(x3-x1);c2=((f2-f1)/(x2-x1)-c1)/(x2-x3)
            xp=0.5*(x1+x3-c1/c2)
            fp=f(xp)
            if abs(xp-x2)<acc or not a<=x1<x2<x3<=b:
                if fp<f2:
                    return xp,fp
                else:
                    return x2,f2
            if x2<xp:
                if f2<fp:
                    x3=xp;f3=fp
                else:
                    x1=x2;f1=f2
                    x2=xp;f2=fp
            else:
                if f2<fp:
                    x1=xp;f1=fp
                else:
                    x3=x2;f3=f2
                    x2=xp;f2=fp
    

    测试一维搜索方法

    %timeit(goldenSegmantation(lambda x:(x**4-5),-1,1,0.0001))
    %timeit(quadraticInterploation(lambda x:(x**4-5),-1,1,0.00001))
    %timeit(goldenSegmantation(lambda x:(x**2-5*x+6),-10,10,0.00000005))
    %timeit(quadraticInterploation(lambda x:(x**2-5*x+6),-10,10,0.000001))
    %timeit(goldenSegmantation(math.sin,-1*math.pi,0,0.000001))
    %timeit(quadraticInterploation(math.sin,-1*math.pi,0,0.0000001))                
    
    11.3 µs ± 58.8 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
    3.09 µs ± 18.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
    12.5 µs ± 47.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
    5.44 µs ± 27.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
    7.97 µs ± 33.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
    2.05 µs ± 19.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
    

    结果分析

    对于不同的目标函数,二次插值的速度均大于黄金分割方法

    测试高维搜索方法

    %timeit steepestDescent()
    %timeit partan()
    %timeit conjugate_direction()
    %timeit newton()
    %timeit quasi_newton()
    
    236 µs ± 2.39 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
    297 µs ± 2.39 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
    197 µs ± 1.49 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
    125 µs ± 276 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
    224 µs ± 1.28 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
    

    结果分析

    从结果看出来,partan加速方法相比最速下降方法并没有什么优势,主要原因是目标函数太简单,迭代次数太少

    拟牛顿法相比最速下降法也没有什么优势,我想也是基于同样的原因

  • 相关阅读:
    MyBatis3.2从入门到精通第一章
    (转)浅析Java中的访问权限控制
    mysql添加索引命令
    (转)浅谈Java中的对象和对象引用
    (转)String、StringBuffer与StringBuilder之间区别
    (转)浅谈Java中的equals和==
    Java并发编程:Lock
    Java并发编程:synchronized
    安装MySQL
    Excel常用函数
  • 原文地址:https://www.cnblogs.com/wbwang/p/7841447.html
Copyright © 2011-2022 走看看