zoukankan      html  css  js  c++  java
  • (PSO-BP)结合粒子群的神经网络算法以及matlab实现

    原理:
               PSO(粒子群群算法):可以在全局范围内进行大致搜索,得到一个初始解,以便BP接力
               BP(神经网络):梯度搜素,细化能力强,可以进行更仔细的搜索。
    数据:对该函数((2.1*(1-x+2*x.^2).*exp(-x.^2/2))+sin(x)+x','x')[-5,5]进行采样,得到30组训练数据,拟合该网络。

         神经网络结构设置:   该网络结构为,1-7-1结构,即输入1个神经元,中间神经元7个,输出1个神经元

           程序步骤:

                第一步:先采用抽取30组数据,包括输入和输出

               第一步:运行粒子群算法,进行随机搜索,选择一个最优的解,该解的维数为22维。

               第二步:在;粒子群的解基础上进行细化搜索

    程序代码:

               

    clc                         
    clear                
      tic              
    SamNum=30;                  
                    
    HiddenNum=7;          
    InDim=1;                
    OutDim=1;      
    
    load train_x
    load train_f
    
    a=train_x';
    d=train_f';
    
    p=[a];  
    t=[d];      
    [SamIn,minp,maxp,tn,mint,maxt]=premnmx(p,t); 
             
    NoiseVar=0.01;                 
    Noise=NoiseVar*randn(1,SamNum);   
    SamOut=tn + Noise;                 
        
       SamIn=SamIn';
       SamOut=SamOut';
    
    MaxEpochs=60000;                       
    lr=0.025;                                      
    E0=0.65*10^(-6);                              
    
    %%
    %the begin of PSO
       
    E0=0.001;
    Max_num=500;
    particlesize=200;
    c1=1;
    c2=1;
    w=2;
    vc=2;
    vmax=5;
    dims=InDim*HiddenNum+HiddenNum+HiddenNum*OutDim+OutDim;
    x=-4+7*rand(particlesize,dims);
    v=-4+5*rand(particlesize,dims);
    f=zeros(particlesize,1);
     %%
     for jjj=1:particlesize
         trans_x=x(jjj,:);
        W1=zeros(InDim,HiddenNum);   
        B1=zeros(HiddenNum,1);    
        W2=zeros(HiddenNum,OutDim);           
        B2=zeros(OutDim,1);
    
        W1=trans_x(1,1:HiddenNum);
        B1=trans_x(1,HiddenNum+1:2*HiddenNum)'; 
        W2=trans_x(1,2*HiddenNum+1:3*HiddenNum)';           
        B2=trans_x(1,3*HiddenNum+1); 
        Hiddenout=logsig(SamIn*W1+repmat(B1',SamNum,1));
        Networkout=Hiddenout*W2+repmat(B2',SamNum,1);
        Error=Networkout-SamOut;                       
        SSE=sumsqr(Error)  
        
         f(jjj)=SSE;
     end
    personalbest_x=x;
    personalbest_f=f;
    [groupbest_f i]=min(personalbest_f);
    groupbest_x=x(i,:);
    for j_Num=1:Max_num
          vc=(5/3*Max_num-j_Num)/Max_num;
        %%     
           v=w*v+c1*rand*(personalbest_x-x)+c2*rand*(repmat(groupbest_x,particlesize,1)-x);
            for kk=1:particlesize
                  for  kk0=1:dims
                  if v(kk,kk0)>vmax
                         v(kk,kk0)=vmax;
                  else if v(kk,kk0)<-vmax
                          v(kk,kk0)=-vmax;
                      end
                  end
                  end
            end
            x=x+vc*v;
            %%
        for jjj=1:particlesize
                        trans_x=x(jjj,:);
                        W1=zeros(InDim,HiddenNum);   
                        B1=zeros(HiddenNum,1);    
                        W2=zeros(HiddenNum,OutDim);           
                        B2=zeros(OutDim,1);
    
                        W1=trans_x(1,1:HiddenNum);
                        B1=trans_x(1,HiddenNum+1:2*HiddenNum)'; 
                        W2=trans_x(1,2*HiddenNum+1:3*HiddenNum)';           
                        B2=trans_x(1,3*HiddenNum+1); 
                        Hiddenout=logsig(SamIn*W1+repmat(B1',SamNum,1));
                        Networkout=Hiddenout*W2+repmat(B2',SamNum,1);
                        Error=Networkout-SamOut;                       
                        SSE=sumsqr(Error);  
        
                       f(jjj)=SSE;
         
     end    
     %%
         for kk=1:particlesize
             if f(kk)<personalbest_f(kk)
                 personalbest_f(kk)=f(kk);
                 personalbest_x(kk)=x(kk);
             end
         end
         [groupbest_f0 i]=min(personalbest_f);
         
         if    groupbest_f0<groupbest_f
         groupbest_x=x(i,:);
         groupbest_f=groupbest_f0;
         end
         ddd(j_Num)=groupbest_f
    end
       str=num2str(groupbest_f);
        trans_x=groupbest_x;
        W1=trans_x(1,1:HiddenNum);
        B1=trans_x(1,HiddenNum+1:2*HiddenNum)'; 
        W2=trans_x(1,2*HiddenNum+1:3*HiddenNum)';           
        B2=trans_x(1,3*HiddenNum+1); 
    %the end of PSO
    %%
                                
    for i=1:MaxEpochs
        %%
        Hiddenout=logsig(SamIn*W1+repmat(B1',SamNum,1));
        Networkout=Hiddenout*W2+repmat(B2',SamNum,1);
        Error=Networkout-SamOut;                       
        SSE=sumsqr(Error)                          
    
        ErrHistory=[ SSE];
    
        if SSE<E0,break, end      
          dB2=zeros(OutDim,1);
          dW2=zeros(HiddenNum,OutDim);
                    for jj=1:HiddenNum  
                                 for k=1:SamNum
                                 dW2(jj,OutDim)=dW2(jj,OutDim)+Error(k)*Hiddenout(k,jj);
                                 end
                    end 
                   for k=1:SamNum
                                     dB2(OutDim,1)=dB2(OutDim,1)+Error(k);
                                 end        
         dW1=zeros(InDim,HiddenNum);
         dB1=zeros(HiddenNum,1);
      for ii=1:InDim
           for jj=1:HiddenNum
                      
                             for k=1:SamNum
                                     dW1(ii,jj)=dW1(ii,jj)+Error(k)*W2(jj,OutDim)*Hiddenout(k,jj)*(1-Hiddenout(k,jj))*(SamIn(k,ii));
                                     dB1(jj,1)=dB1(jj,1)+Error(k)*W2(jj,OutDim)*Hiddenout(k,jj)*(1-Hiddenout(k,jj));
    
                             end
                  end
      end
    
        W2=W2-lr*dW2;
        B2=B2-lr*dB2;
       
        W1=W1-lr*dW1;
        B1=B1-lr*dB1;
    end
    
    Hiddenout=logsig(SamIn*W1+repmat(B1',SamNum,1));
    Networkout=Hiddenout*W2+repmat(B2',SamNum,1);
        
    aa=postmnmx(Networkout,mint,maxt);             
    x=a;                                   
    newk=aa;                                    
              figure                           
    plot(x,d,'r-o',x,newk,'b--+') 
    legend('原始数据','训练后的数据');
    xlabel('x');ylabel('y');
    toc
    
          

    注:在(i5,8G,win7,64位)PC上的运行时间为30s左右。鉴于PSO带有概率性,可以多跑几次,看最佳的一次效果。

          

    转载于:https://www.cnblogs.com/jacksin/p/8835907.html

  • 相关阅读:
    第一只猫环境准备
    高阶Promise--async
    操作系统源码编译心得
    闭包的应用场景
    全栈微信小程序商城 学习笔记10.1 对更新收货地址接口做权限控制
    全栈微信小程序商城 学习笔记9.1 新建登录接口
    全栈微信小程序商城 学习笔记8.5 product分类商品接口编写
    全栈微信小程序商城 学习笔记8.4 category分类接口编写
    全栈微信小程序商城 学习笔记8.3 product最近新品接口编写
    全栈微信小程序商城 学习笔记8.2 theme详情接口编写
  • 原文地址:https://www.cnblogs.com/twodog/p/12137110.html
Copyright © 2011-2022 走看看