录:
1神经网络模型
2公式推导
3程序实现
1神经网络模型:
假设:输入为xi,m维,输出为y,1维,隐含层神经元n维,输入为xj,输出为xj',训练的元素组数为kk组。
数学描述如下:
目标函数:
(编辑器不支持公式好麻烦,玛德,手推拍照吧)
3程序实现:
根据上述原理,编写matlab程序,输入为3维,中间神经元9个,输出为一个;
数据采用《matlab在数学建模中的应用的》卓金武的数据:
通过人数、机动车数量、公路面积,来预测公路客运量:
代码如下:
clc clear tic SamNum=20; HiddenNum=9; InDim=3; OutDim=1; a=[20.55 22.44 25.37 27.13 29.45 30.10 30.96 34.06 36.42 38.09 39.13 39.99 ... 41.93 44.59 47.30 52.89 55.73 56.76 59.17 60.63]; b=[0.6 0.75 0.85 0.9 1.05 1.35 1.45 1.6 1.7 1.85 2.15 2.2 2.25 2.35 2.5 2.6... 2.7 2.85 2.95 3.1]; c=[0.09 0.11 0.11 0.14 0.20 0.23 0.23 0.32 0.32 0.34 0.36 0.36 0.38 0.49 ... 0.56 0.59 0.59 0.67 0.69 0.79]; d=[5126 6217 7730 9145 10460 11387 12353 15750 18304 19836 21024 19490 20433 ... 22598 25107 33442 36836 40548 42927 43462]; p=[a;b;c]; t=[d]; [SamIn,minp,maxp,tn,mint,maxt]=premnmx(p,t); NoiseVar=0.01; Noise=NoiseVar*randn(1,SamNum); SamOut=tn + Noise; SamIn=SamIn'; SamOut=SamOut'; MaxEpochs=50000; lr=0.045; E0=0.65*10^(-5); W1=0.5*rand(InDim,HiddenNum)-0.1; B1=0.5*rand(HiddenNum,1)-0.1; W2=0.5*rand(HiddenNum,OutDim)-0.1; B2=0.5*rand(OutDim,1)-0.1; for i=1:MaxEpochs %% Hiddenout=logsig(SamIn*W1+repmat(B1',SamNum,1)); Networkout=Hiddenout*W2+repmat(B2',SamNum,1); Error=Networkout-SamOut; SSE=sumsqr(Error) ErrHistory=[ SSE]; if SSE<E0,break, end dB2=zeros(OutDim,1); dW2=zeros(HiddenNum,OutDim); for jj=1:HiddenNum for k=1:SamNum dW2(jj,OutDim)=dW2(jj,OutDim)+Error(k)*Hiddenout(k,jj); end end for k=1:SamNum dB2(OutDim,1)=dB2(OutDim,1)+Error(k); end dW1=zeros(InDim,HiddenNum); dB1=zeros(HiddenNum,1); for ii=1:InDim for jj=1:HiddenNum for k=1:SamNum dW1(ii,jj)=dW1(ii,jj)+Error(k)*W2(jj,OutDim)*Hiddenout(k,jj)*(1-Hiddenout(k,jj))*(SamIn(k,ii)); dB1(jj,1)=dB1(jj,1)+Error(k)*W2(jj,OutDim)*Hiddenout(k,jj)*(1-Hiddenout(k,jj)); end end end W2=W2-lr*dW2; B2=B2-lr*dB2; W1=W1-lr*dW1; B1=B1-lr*dB1; end Hiddenout=logsig(SamIn*W1+repmat(B1',SamNum,1)); Networkout=Hiddenout*W2+repmat(B2',SamNum,1); aa=postmnmx(Networkout,mint,maxt); x=1990:2009; newk=aa; figure plot(x,d,'r-o',x,newk,'b--+') legend('原始数据','训练后的数据'); xlabel('年份');ylabel('客运量/万人'); toc
注:程序在电脑(i5,8G,win7,64)运行时间16s左右。