zoukankan      html  css  js  c++  java
  • 简洁的BP及RBF神经网络代码

    BP神经网络

    function [W,err]=BPTrain(data,label,hiddenlayers,nodes,type,W)
    %BPTRAIN Train a fully connected tanh back-propagation network with
    %online (per-sample) gradient descent.
    %
    %   [W,err] = BPTrain(data,label,hiddenlayers,nodes,1)     create + train
    %   [W,err] = BPTrain(data,label,hiddenlayers,nodes,0,W)   continue training W
    %
    %   data         : dim x n matrix, one sample per column
    %   label        : m x n matrix of targets, one column per sample
    %                  (m = 1 for the single-output case). The output unit is
    %                  tanh, so targets should lie in (-1,1).
    %   hiddenlayers : number of hidden layers
    %   nodes        : vector with the number of units of each hidden layer
    %                  (bias units not included); only used when type==1
    %   type         : 1 = create random weights, then train
    %                  otherwise = keep training the weights passed in as W
    %   W            : (type ~= 1 only) cell array returned by a previous call
    %
    %   W{i} maps layer i, augmented with a leading bias unit of value 1, to
    %   layer i+1. Row 1 of every W{i} therefore holds the bias weights.
    %   err(t) is the summed absolute training error of epoch t.
    %
    %   Activation: tanh(x) = (e^x - e^-x) / (e^x + e^-x), whose derivative
    %   is tanh'(x) = sech(x)^2, with sech(x) = 2 / (e^x + e^-x).
    if type==1
        % Input and hidden layers get one extra unit for the bias.
        nodes=[size(data,1);nodes(:)];
        nodes=[nodes+1;size(label,1)];
        layers=hiddenlayers+2;
        W=cell(1,layers-1);
        for i=1:layers-2
            % The bias unit of layer i+1 is a constant, not an output of
            % W{i}, hence nodes(i+1)-1 columns.
            W{i}=rand(nodes(i),nodes(i+1)-1);
        end
        W{layers-1}=rand(nodes(layers-1),nodes(layers));
    else
        if nargin<6 || isempty(W)
            error('BPTrain:noWeights', ...
                'Continuing training (type ~= 1) requires the weights W as 6th argument.');
        end
        layers=numel(W)+1;
    end

    % Termination: stop after maxiter epochs or once the epoch error is
    % no larger than epsilon.
    maxiter=2000;
    lr=0.1;        % learning rate
    epsilon=0.1;
    err=zeros(1,maxiter);
    epochErr=inf;  % not named "error" so the built-in stays callable
    iter=0;
    y=cell(1,layers);   % y{i}: output of layer i (bias-augmented for i<layers)
    v=cell(1,layers);   % v{i}: pre-activation of layer i
    tic
    while iter<maxiter && epochErr>epsilon
        iter=iter+1;
        epochErr=0;
        for k=1:size(data,2)
            % Forward pass.
            y{1}=data(:,k);
            for i=1:layers-1
                y{i}=[1;y{i}];
                v{i+1}=W{i}'*y{i};
                y{i+1}=tanh(v{i+1});
            end
            % Backward pass.
            e=label(:,k)-y{layers};
            epochErr=epochErr+sum(abs(e));
            delta=e.*sech(v{layers}).^2;
            for i=layers-1:-1:1
                Wold=W{i};
                W{i}=W{i}+lr.*(y{i}*delta');
                if i>1
                    % Propagate through the pre-update weights and skip
                    % row 1: the bias unit has no incoming connections.
                    delta=sech(v{i}).^2.*(Wold(2:end,:)*delta);
                end
            end
        end
        err(iter)=epochErr;
    end
    err=err(1:iter);
    toc
    
    测试代码

    function res=BPTest(W,data)
    %BPTEST Forward-propagate samples through a network trained by BPTrain.
    %   W    : cell array of weight matrices; W{i} maps layer i (with a
    %          leading bias unit of value 1) to layer i+1
    %   data : dim x n matrix, one sample per column
    %   res  : m x n network outputs, one column per sample
    %          (1 x n for a single-output network)
    n=size(data,2);
    res=zeros(size(W{end},2),n);
    for k=1:n
        y=data(:,k);
        % Every layer, including the last, is bias-augmented then tanh'ed;
        % looping over all of W avoids relying on a leaked loop index.
        for i=1:numel(W)
            y=tanh(W{i}'*[1;y]);
        end
        res(:,k)=y;
    end

    % RBF network training. traintype, traindata, trainlabel, scale and K are
    % presumably inputs of the enclosing training function, whose header is
    % missing from this listing -- TODO confirm.
    %   traindata  : n x dim, one sample per ROW (pdist2/kmeans convention)
    %   trainlabel : targets with one row per training sample
    % The fitted model is stored in globals and read back by RBFTest.
    global rbf_sigma;
    global rbf_center;
    global rbf_weight;
    switch traintype
        case 'data'
            % Exact interpolation: one kernel centred on every training sample.
            traindist=pdist2(traindata,traindata);
            rbf_sigma=max(traindist(:))/(scale.^2);
            rbf_center=traindata;
            Phi=exp(-traindist./rbf_sigma);
            % Solve Phi*w = trainlabel; mldivide is more accurate than inv().
            rbf_weight=Phi\trainlabel;
        case 'cluster'
            % K centres from k-means, weights by linear least squares.
            [~,C]=kmeans(traindata,K,'emptyaction','singleton');
            traindist=pdist2(traindata,C);
            Cdist=pdist2(C,C);
            rbf_sigma=max(Cdist(:))/(scale.^2);
            rbf_center=C;
            Phi=exp(-traindist./rbf_sigma);
            % Least-squares solution of Phi*w = trainlabel (QR-based),
            % avoiding the squared condition number of inv(Phi'*Phi).
            rbf_weight=Phi\trainlabel;
        case 'descend'
            % Previously an empty branch that silently left the globals stale.
            error('RBF:notImplemented', ...
                'traintype ''descend'' (gradient descent) is not implemented.');
        otherwise
            error('RBF:badTrainType', ...
                'Unknown traintype ''%s''; expected ''data'' or ''cluster''.', ...
                char(traintype));
    end
    测试代码

    function predict=RBFTest(data)
    %RBFTEST Evaluate the RBF network stored in the globals rbf_sigma,
    %rbf_center and rbf_weight on new samples.
    %   data    : p x dim, one sample per row (same layout as the centres)
    %   predict : network outputs, one row per sample
    %   The kernel must be the one used in training, where the design
    %   matrix is built as Phi = exp(-dist./rbf_sigma) with dist the
    %   Euclidean distance from pdist2. Using a different width here
    %   (e.g. 2*rbf_sigma) would evaluate a different model than the one
    %   the weights were fitted to.
    global rbf_sigma;
    global rbf_center;
    global rbf_weight;

    testdist=pdist2(data,rbf_center);

    predict=exp(-testdist./rbf_sigma)*rbf_weight;






  • 相关阅读:
    WPF 命令基础
    委托 C#
    Volley网络请求框架的基本用法
    MailOtto 实现完美预加载以及源码解读
    Android_时间服务
    Android_Chronometer计时器
    Android_Json实例
    完结篇
    就快完结篇
    MySQL 选出日期时间最大的一条记录
  • 原文地址:https://www.cnblogs.com/yangykaifa/p/7243667.html
Copyright © 2011-2022 走看看