zoukankan      html  css  js  c++  java
  • 机器学习:Python实现lms中的学习率的退火算法

    '''
       算法:lms学习率的退火算法
       解决的问题:学习率不变化,收敛速度较慢的情况
       思路:由初始解和控制参数初值开始,对当前解重复进行"产生新解-->计算目标函数差-->
       接受或舍弃"的迭代,并逐步衰减控制参数,算法终结时的当前解即为所得近似最优解
    '''
    
    '''
       变量约定:大写表示矩阵或数组,小写表示数字
       X:表示数组或者矩阵
       x:表示对应数组或矩阵的某个值
    '''
    
    import numpy as np
    import math
    a0=0.1  ##学习率初始值 0<a<1
    a=0.0   ##学习率变量
    r=0.2   ##可调参数,改善退火曲线的形态
    X=np.array([[1,1,6],[1,2,12],[1,3,9],[1,8,24]]) ##输入矩阵
    D=np.array([1,1,-1,-1])  ##期望输出结果矩阵
    W=np.array([1,0,0])   ##权重向量
    expect_e=0.005 ##期望误差
    maxtrycount=20 ##最大尝试次数
    cnt=0    ##当前循环次数
    
    ##硬限幅函数(即标准,这个比较简单:输入v大于0,返回1.小于等于0返回-1)
    
    def sgn(v):
        if v>0:
            return 1
        else:
            return -1 
        
    ##读取实际输出   
    '''
        这里是两个向量相乘,对应的数学公式:
        a(m,n)*b(p,q)=m*p+n*q
        在下面的函数中,当循环中xn=1时(此时W=([0.1,0.1])):
        np.dot(W.T,x)=(1,1)*(0.1,0.1)=1*0.1+1*0.1=0.2>0 ==>sgn 返回1
    '''
    def get_v(W,x):
        return sgn(np.dot(W.T,x))##dot表示两个矩阵相乘
    
    
    ##读取误差值
    def get_e(W,x,d):
        return d-get_v(W,x)
    
    ##权重计算函数(批量修正)
    '''
      对应数学公式: w(n+1)=w(n)+a*x(n)*e
      对应下列变量的解释:
      w(n+1) <= neww 的返回值
      w(n)   <=oldw(旧的权重向量)
      a      <= a(学习率,范围:0<a<1)
      x(n)   <= x(输入值)
      e      <= 误差值或者误差信号
    '''
    
    '''
        核心:学习率的计算(有多种退火方式,这里选取其中一个)
        对应的数学公式:(非标准数学符号)
        m(n)=m0/(1+n/t)
        参数解释:
        m(n)  <= 当前学习率(或者初始解)
        m0    <= 学习率初始值(也叫控制参数)
        n     <= 循环次数
        t     <= 可调参数,改善退火曲线的形态 
        
    '''
    def neww(oldW,d,x):
        e=get_e(oldW,x,d)
        #a=a0/(1+float(cnt)/r)
        a=a0/(1+float(cnt)*r)
        w=oldW+a*x*e
        return (w,e)
    
    
    ##修正权值
    '''
        此循环的原理:
        权值修正原理(批量修正)==>神经网络每次读入一个样本,进行修正,
            达到预期误差值或者最大尝试次数结束,修正过程结束   
    '''
    
    while True:
        err=0
        i=0
        for xn in X:        
            W,e=neww(W,D[i],xn)
            i+=1
            err+=pow(e,2)  ##lms算法的核心步骤,即:MES
        err=math.sqrt(err)  ##与lms算法有区别的地方,求开方最小
        cnt+=1
        print(u"第 %d 次调整后的权值:"%cnt)
        print(W)
        print(u"误差:%f"%err)
        if err<expect_e or cnt>=maxtrycount:
            break
    
    print("最后的权值:",W.T)
    
    ##输出结果
    print("开始验证结果...")
    for xn in X:
        print("D%s and W%s =>%d"%(xn,W.T,get_v(W,xn)))
    
    
    ##测试准确性:
    
    
    print("开始测试...")
    test=np.array([1,9,27])
    print("D%s and W%s =>%d"%(test,W.T,get_v(W,test)))
    test=np.array([1,11,66])
    print("D%s and W%s =>%d"%(test,W.T,get_v(W,test)))

      输出结果:

    第 1 次调整后的权值:
    [ 0.8 -0.6 -1.8]
    误差:2.0000002 次调整后的权值:
    [ 0.96666667 -0.6        -0.3       ]
    误差:3.4641023 次调整后的权值:
    [ 0.96666667 -0.88571429 -0.72857143]
    误差:2.8284274 次调整后的权值:
    [ 0.96666667 -1.88571429 -2.60357143]
    误差:4.0000005 次调整后的权值:
    [ 1.18888889 -1.55238095 -0.60357143]
    误差:2.8284276 次调整后的权值:
    [ 1.28888889 -1.55238095  0.29642857]
    误差:3.4641027 次调整后的权值:
    [ 1.28888889 -1.55238095  0.29642857]
    误差:0.000000
    最后的权值: [ 1.28888889 -1.55238095  0.29642857]
    开始验证结果...

        这次调整r值的结果是7次训练得出最优解,同样的数据用固定学习率的lms算法要8次训练。在调整r值的实验中,最快的一次是2次得出最优解。

  • 相关阅读:
    asp+access win2008php+mysql /dedecms 配置总结
    js获取页面元素位置函数(跨浏览器)
    Extjs 4 小记
    小总结
    新浪微博 page应用 自适应高度设定 终于找到解决方法
    常用的三层架构设计(转载)
    http://www.jeasyui.com/
    http://j-ui.com/
    日期编辑器MooTools DatePicker
    android布局
  • 原文地址:https://www.cnblogs.com/lc1217/p/6543775.html
Copyright © 2011-2022 走看看