zoukankan      html  css  js  c++  java
  • RBF 和 BP 神经网络的比较

    Contents

    I. 清空环境变量

    clear all
    clc
    

    II. 训练集/测试集产生

    1. 导入数据

    load spectra_data.mat
    

    2. 随机产生训练集和测试集

    temp = randperm(size(NIR,1));
    % 训练集——50个样本
    P_train = NIR(temp(1:50),:)';
    T_train = octane(temp(1:50),:)';
    % 测试集——10个样本
    P_test = NIR(temp(51:end),:)';
    T_test = octane(temp(51:end),:)';
    N = size(P_test,2);
    

    III. 数据归一化,BP 网络需要归一化处理

    [p_train, ps_input] = mapminmax(P_train,0,1);
    p_test = mapminmax('apply',P_test,ps_input);
    
    [t_train, ps_output] = mapminmax(T_train,0,1);
    

    IV. RBF/BP神经网络创建及仿真测试

    1. 创建网络

    net_rbf = newrbe(P_train,T_train,30);
    net_bp = newff(p_train, t_train, 9);
    

    2. 设置 BP 网络 训练参数, RBF 网络不需要设置参数,除了spread

    net_bp.trainParam.epochs = 1000;
    net_bp.trainParam.goal = 1e-3;
    net_bp.trainParam.lr = 0.01;
    

    3. BP网络 训练, RBF 不需要训练

    net_bp = train(net_bp,p_train,t_train);
    

    4. 仿真测试

    T_sim_rbf = sim(net_rbf,P_test);
    t_sim_bp = sim(net_bp,p_test);
    

    5. 数据反归一化

    T_sim_bp = mapminmax('reverse',t_sim_bp,ps_output);
    

    V. 性能评价

    1. 相对误差error

    error_rbf = abs(T_sim_rbf - T_test)./T_test;
    error_bp = abs(T_sim_bp - T_test)./T_test;
    

    2. 决定系数R^2

    R2_rbf = (N * sum(T_sim_rbf .* T_test) - sum(T_sim_rbf) * sum(T_test))^2 / ((N * sum(T_sim_rbf.^2) - (sum(T_sim_rbf))^2) * (N * sum((T_test).^2) - (sum(T_test))^2));
    R2_bp = (N * sum(T_sim_bp .* T_test) - sum(T_sim_bp) * sum(T_test))^2 / ((N * sum(T_sim_bp.^2) - (sum(T_sim_bp))^2) * (N * sum((T_test).^2) - (sum(T_test))^2));
    

    3. 结果对比

    result = [T_test' T_sim_rbf' T_sim_bp' error_rbf' error_bp']
    
    result =
    
       83.4000   83.8939   83.3325    0.0059    0.0008
       86.6000   86.8234   86.8816    0.0026    0.0033
       88.5500   88.6082   88.1839    0.0007    0.0041
       88.7000   88.6681   88.9107    0.0004    0.0024
       88.4500   88.1451   88.2298    0.0034    0.0025
       86.1000   86.2906   86.3526    0.0022    0.0029
       88.1000   88.2477   87.3687    0.0017    0.0083
       88.7000   88.6095   88.9943    0.0010    0.0033
       86.5000   86.4926   86.5093    0.0001    0.0001
       85.4000   85.7454   86.0366    0.0040    0.0075
    
    

    VI. 绘图

    figure
    plot(1:N,T_test,'b:*',1:N,T_sim_rbf,'r-o')
    legend('真实值','预测值')
    xlabel('预测样本')
    ylabel('辛烷值')
    string = {'基于RBF网络测试集辛烷值含量预测结果对比';['R^2=' num2str(R2_rbf)]};
    title(string)
    figure
    plot(1:N,T_test,'b:*',1:N,T_sim_bp,'r-o')
    legend('真实值','预测值')
    xlabel('预测样本')
    ylabel('辛烷值')
    string = {'基于BP网络测试集辛烷值含量预测结果对比';['R^2=' num2str(R2_bp)]};
    title(string)
    

     

  • 相关阅读:
    第七周CorelDRAW课总结
    第七周CorelDRAW课总结
    hive基本操作与应用
    hive基本操作与应用
    hive基本操作与应用
    hive基本操作与应用
    Linux运维面试题:请简要说明Linux系统在目标板上的启动过程?
    Linux运维面试题:请简要说明Linux系统在目标板上的启动过程?
    arcserver开发小结(一)
    xgqfrms™, xgqfrms® : xgqfrms's offical website of GitHub!
  • 原文地址:https://www.cnblogs.com/momo072994MLIA/p/9494167.html
Copyright © 2011-2022 走看看