zoukankan      html  css  js  c++  java
  • 转载——关于bp神经网络

    一、BP神经网络的概念

        BP神经网络是一种多层的前馈神经网络,其主要的特点是:信号是前向传播的,而误差是反向传播的。具体来说,对于如下的只含一个隐层的神经网络模型:
    (三层BP神经网络模型)
    BP神经网络的过程主要分为两个阶段,第一阶段是信号的前向传播,从输入层经过隐含层,最后到达输出层;第二阶段是误差的反向传播,从输出层到隐含层,最后到输入层,依次调节隐含层到输出层的权重和偏置,输入层到隐含层的权重和偏置。

    二、BP神经网络的流程

        在知道了BP神经网络的特点后,我们需要依据信号的前向传播和误差的反向传播来构建整个网络。

    1、网络的初始化

        假设输入层的节点个数为,隐含层的节点个数为,输出层的节点个数为。输入层到隐含层的权重,隐含层到输出层的权重为,输入层到隐含层的偏置为,隐含层到输出层的偏置为。学习速率为,激励函数为。其中激励函数为取Sigmoid函数。形式为:

    2、隐含层的输出

        如上面的三层BP网络所示,隐含层的输出

    3、输出层的输出

    4、误差的计算

        我们取误差公式为:
    其中为期望输出。我们记,则可以表示为
    以上公式中,

    5、权值的更新

        权值的更新公式为:
    这里需要解释一下公式的由来:
    这是误差反向传播的过程,我们的目标是使得误差函数达到最小值,即,我们使用梯度下降法:
    • 隐含层到输出层的权重更新
    则权重的更新公式为:
    • 输入层到隐含层的权重更新
    其中
     
    则权重的更新公式为:

    6、偏置的更新

        偏置的更新公式为:
    • 隐含层到输出层的偏置更新
    则偏置的更新公式为:
    • 输入层到隐含层的偏置更新
    其中
     
    则偏置的更新公式为:

    7、判断算法迭代是否结束

        有很多的方法可以判断算法是否已经收敛,常见的有指定迭代的代数,判断相邻的两次误差之间的差别是否小于指定的值等等。

    三、实验的仿真

        在本试验中,我们利用BP神经网络处理一个四分类问题,最终的分类结果为:

    MATLAB代码

    主程序
    [plain] view plaincopy在CODE上查看代码片派生到我的代码片
     
    1. %% BP的主函数  
    2.   
    3. % 清空  
    4. clear all;  
    5. clc;  
    6.   
    7. % 导入数据  
    8. load data;  
    9.   
    10. %从1到2000间随机排序  
    11. k=rand(1,2000);  
    12. [m,n]=sort(k);  
    13.   
    14. %输入输出数据  
    15. input=data(:,2:25);  
    16. output1 =data(:,1);  
    17.   
    18. %把输出从1维变成4维  
    19. for i=1:2000  
    20.     switch output1(i)  
    21.         case 1  
    22.             output(i,:)=[1 0 0 0];  
    23.         case 2  
    24.             output(i,:)=[0 1 0 0];  
    25.         case 3  
    26.             output(i,:)=[0 0 1 0];  
    27.         case 4  
    28.             output(i,:)=[0 0 0 1];  
    29.     end  
    30. end  
    31.   
    32. %随机提取1500个样本为训练样本,500个样本为预测样本  
    33. trainCharacter=input(n(1:1600),:);  
    34. trainOutput=output(n(1:1600),:);  
    35. testCharacter=input(n(1601:2000),:);  
    36. testOutput=output(n(1601:2000),:);  
    37.   
    38. % 对训练的特征进行归一化  
    39. [trainInput,inputps]=mapminmax(trainCharacter');  
    40.   
    41. %% 参数的初始化  
    42.   
    43. % 参数的初始化  
    44. inputNum = 24;%输入层的节点数  
    45. hiddenNum = 50;%隐含层的节点数  
    46. outputNum = 4;%输出层的节点数  
    47.   
    48. % 权重和偏置的初始化  
    49. w1 = rands(inputNum,hiddenNum);  
    50. b1 = rands(hiddenNum,1);  
    51. w2 = rands(hiddenNum,outputNum);  
    52. b2 = rands(outputNum,1);  
    53.   
    54. % 学习率  
    55. yita = 0.1;  
    56.   
    57. %% 网络的训练  
    58. for r = 1:30  
    59.     E(r) = 0;% 统计误差  
    60.     for m = 1:1600  
    61.         % 信息的正向流动  
    62.         x = trainInput(:,m);  
    63.         % 隐含层的输出  
    64.         for j = 1:hiddenNum  
    65.             hidden(j,:) = w1(:,j)'*x+b1(j,:);  
    66.             hiddenOutput(j,:) = g(hidden(j,:));  
    67.         end  
    68.         % 输出层的输出  
    69.         outputOutput = w2'*hiddenOutput+b2;  
    70.           
    71.         % 计算误差  
    72.         e = trainOutput(m,:)'-outputOutput;  
    73.         E(r) = E(r) + sum(abs(e));  
    74.           
    75.         % 修改权重和偏置  
    76.         % 隐含层到输出层的权重和偏置调整  
    77.         dw2 = hiddenOutput*e';  
    78.         db2 = e;  
    79.           
    80.         % 输入层到隐含层的权重和偏置调整  
    81.         for j = 1:hiddenNum  
    82.             partOne(j) = hiddenOutput(j)*(1-hiddenOutput(j));  
    83.             partTwo(j) = w2(j,:)*e;  
    84.         end  
    85.           
    86.         for i = 1:inputNum  
    87.             for j = 1:hiddenNum  
    88.                 dw1(i,j) = partOne(j)*x(i,:)*partTwo(j);  
    89.                 db1(j,:) = partOne(j)*partTwo(j);  
    90.             end  
    91.         end  
    92.           
    93.         w1 = w1 + yita*dw1;  
    94.         w2 = w2 + yita*dw2;  
    95.         b1 = b1 + yita*db1;  
    96.         b2 = b2 + yita*db2;    
    97.     end  
    98. end  
    99.   
    100. %% 语音特征信号分类  
    101. testInput=mapminmax('apply',testCharacter',inputps);  
    102.   
    103. for m = 1:400  
    104.     for j = 1:hiddenNum  
    105.         hiddenTest(j,:) = w1(:,j)'*testInput(:,m)+b1(j,:);  
    106.         hiddenTestOutput(j,:) = g(hiddenTest(j,:));  
    107.     end  
    108.     outputOfTest(:,m) = w2'*hiddenTestOutput+b2;  
    109. end  
    110.   
    111. %% 结果分析  
    112. %根据网络输出找出数据属于哪类  
    113. for m=1:400  
    114.     output_fore(m)=find(outputOfTest(:,m)==max(outputOfTest(:,m)));  
    115. end  
    116.   
    117. %BP网络预测误差  
    118. error=output_fore-output1(n(1601:2000))';  
    119.   
    120. k=zeros(1,4);    
    121. %找出判断错误的分类属于哪一类  
    122. for i=1:400  
    123.     if error(i)~=0  
    124.         [b,c]=max(testOutput(i,:));  
    125.         switch c  
    126.             case 1   
    127.                 k(1)=k(1)+1;  
    128.             case 2   
    129.                 k(2)=k(2)+1;  
    130.             case 3   
    131.                 k(3)=k(3)+1;  
    132.             case 4   
    133.                 k(4)=k(4)+1;  
    134.         end  
    135.     end  
    136. end  
    137.   
    138. %找出每类的个体和  
    139. kk=zeros(1,4);  
    140. for i=1:400  
    141.     [b,c]=max(testOutput(i,:));  
    142.     switch c  
    143.         case 1  
    144.             kk(1)=kk(1)+1;  
    145.         case 2  
    146.             kk(2)=kk(2)+1;  
    147.         case 3  
    148.             kk(3)=kk(3)+1;  
    149.         case 4  
    150.             kk(4)=kk(4)+1;  
    151.     end  
    152. end  
    153.   
    154. %正确率  
    155. rightridio=(kk-k)./kk  

    激活函数
    [plain] view plaincopy在CODE上查看代码片派生到我的代码片
     
    1. %% 激活函数  
    2. function [ y ] = g( x )  
    3.     y = 1./(1+exp(-x));  
    4. end  
  • 相关阅读:
    实验证明:ObjectiveC++ 完美支持 ARC
    用 Java 实现的日志切割清理工具
    数字电视,方便了谁
    商品EAN13条码的生成
    关于错误“Cannot connect to the Citrix MetaFrame server.Can't assign requested address”的解决方法
    "加载类型库/dll时出错" 的解决方法
    解决连接SQL Server 2000的TCP/IP错误的Bug
    电脑自动关机之CPU风扇烧坏
    winrar 8 注册方法
    电脑死机之CPU温度过高
  • 原文地址:https://www.cnblogs.com/lianjiehere/p/4618525.html
Copyright © 2011-2022 走看看