zoukankan      html  css  js  c++  java
  • 机器学习: Softmax Classifier (三个隐含层)

    程序实现含有三个隐含层的 softmax classifier，激活函数（activation function）为 ReLU：$f(x)=\max(0,x)$。

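    $f_1 = W_1 x + b_1$

    $h_1 = \max(0, f_1)$

    $f_2 = W_2 h_1 + b_2$

    $h_2 = \max(0, f_2)$

    $f_3 = W_3 h_2 + b_3$

    $h_3 = \max(0, f_3)$

    $f_4 = W_4 h_3 + b_4$

    $y_i = \dfrac{e^{f_{4,i}}}{\sum_j e^{f_{4,j}}}$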

    
    function Out=Softmax_Classifier_3(train_x,  train_y, opts)
    % SOFTMAX_CLASSIFIER_3  Train a softmax classifier with three ReLU hidden
    % layers by mini-batch gradient descent with L2 weight regularization.
    %
    %   Out = Softmax_Classifier_3(train_x, train_y, opts)
    %
    %   train_x : N-by-D matrix, one training example per row.
    %   train_y : N-by-K label matrix, one row per example. Assumed to be
    %             one-hot (or rows summing to 1) -- TODO confirm with callers.
    %   opts    : struct with fields
    %               step_size  - learning rate
    %               reg        - L2 strength (applied to W1..W4, not biases)
    %               batchsize  - examples per mini-batch
    %               numepochs  - number of passes over the data
    %               class      - number of classes K
    %               hidden_1, hidden_2, hidden_3 - hidden layer widths
    %
    %   Out : struct with weights W1..W4, biases b1..b4 (row vectors) and
    %         loss, a 1-by-numepochs vector of mean per-batch training loss
    %         (data loss + regularization loss) for each epoch.
    %
    %   Each epoch uses floor(N / batchsize) full mini-batches drawn from a
    %   fresh random permutation; leftover examples are skipped that epoch.
    %   Raises an error if batchsize exceeds N (no full batch possible).
    %
    %   Activation function: ReLU, h = max(0, f).

    % learning parameters
    step_size = opts.step_size;
    reg       = opts.reg;
    batchsize = opts.batchsize;
    numepochs = opts.numepochs;
    K  = opts.class;
    h1 = opts.hidden_1;
    h2 = opts.hidden_2;
    h3 = opts.hidden_3;

    D            = size(train_x, 2);
    num_examples = size(train_x, 1);

    % Integer number of full batches; the loss average below must divide by
    % the number of batches that actually ran.
    numbatches = floor(num_examples / batchsize);
    if numbatches < 1
        error('Softmax_Classifier_3:batchsize', ...
              'batchsize (%d) exceeds the number of training examples (%d).', ...
              batchsize, num_examples);
    end

    % small random weights break symmetry between units; biases start at zero
    W1 = 0.01*randn(D, h1);   b1 = zeros(1, h1);
    W2 = 0.01*randn(h1, h2);  b2 = zeros(1, h2);
    W3 = 0.01*randn(h2, h3);  b3 = zeros(1, h3);
    W4 = 0.01*randn(h3, K);   b4 = zeros(1, K);

    loss = zeros(1, numepochs);

    for epoch = 1:numepochs

        kk = randperm(num_examples);

        tic;
        fprintf('epoch %d:\n', epoch);

        for bat = 1:numbatches

            idx     = kk((bat - 1)*batchsize + 1 : bat*batchsize);
            batch_x = train_x(idx, :);
            batch_y = train_y(idx, :);

            %% forward pass
            f1 = batch_x*W1 + repmat(b1, batchsize, 1);
            hiddenval_1 = max(0, f1);
            f2 = hiddenval_1*W2 + repmat(b2, batchsize, 1);
            hiddenval_2 = max(0, f2);
            f3 = hiddenval_2*W3 + repmat(b3, batchsize, 1);
            hiddenval_3 = max(0, f3);
            scores = hiddenval_3*W4 + repmat(b4, batchsize, 1);

            %% loss
            % Softmax is invariant to subtracting a per-row constant; removing
            % the row max keeps exp() from overflowing to Inf (-> NaN probs).
            shifted    = scores - repmat(max(scores, [], 2), 1, K);
            exp_scores = exp(shifted);
            sum_exp    = sum(exp_scores, 2);
            probs      = exp_scores ./ repmat(sum_exp, 1, K);
            % Log-probabilities taken directly from the shifted scores, so an
            % underflowed probability never turns into log(0) = -Inf. This is
            % the cross-entropy whose gradient is (probs - batch_y).
            log_probs        = shifted - repmat(log(sum_exp), 1, K);
            correct_logprobs = -sum(log_probs .* batch_y, 2);
            data_loss = sum(correct_logprobs) / batchsize;
            reg_loss  = 0.5*reg*(sum(W1(:).^2) + sum(W2(:).^2) + ...
                                 sum(W3(:).^2) + sum(W4(:).^2));
            loss(epoch) = loss(epoch) + data_loss + reg_loss;

            %% back-propagation
            % Biases sum over dimension 1 explicitly: with batchsize == 1 a
            % plain sum() would collapse the 1-by-n row into a scalar.

            % output layer: gradient of the mean cross-entropy w.r.t. scores
            dscores = (probs - batch_y) / batchsize;
            dW4 = hiddenval_3'*dscores;
            db4 = sum(dscores, 1);

            % hidden layer 3 (ReLU passes gradient only where the unit fired)
            df_3 = (dscores*W4') .* (hiddenval_3 > 0);
            dW3  = hiddenval_2'*df_3;
            db3  = sum(df_3, 1);

            % hidden layer 2
            df_2 = (df_3*W3') .* (hiddenval_2 > 0);
            dW2  = hiddenval_1'*df_2;
            db2  = sum(df_2, 1);

            % hidden layer 1
            df_1 = (df_2*W2') .* (hiddenval_1 > 0);
            dW1  = batch_x'*df_1;
            db1  = sum(df_1, 1);

            %% parameter update (all gradients above use pre-update weights)
            W4 = W4 - step_size*(dW4 + reg*W4);
            b4 = b4 - step_size*db4;

            W3 = W3 - step_size*(dW3 + reg*W3);
            b3 = b3 - step_size*db3;

            W2 = W2 - step_size*(dW2 + reg*W2);
            b2 = b2 - step_size*db2;

            W1 = W1 - step_size*(dW1 + reg*W1);
            b1 = b1 - step_size*db1;

        end

        % mean loss over the batches that actually ran this epoch
        loss(epoch) = loss(epoch) / numbatches;

        fprintf('training loss is %f\n', loss(epoch));

        toc;

    end

    Out.W1 = W1;
    Out.W2 = W2;
    Out.W3 = W3;
    Out.W4 = W4;

    Out.b1 = b1;
    Out.b2 = b2;
    Out.b3 = b3;
    Out.b4 = b4;

    Out.loss = loss;
    
    
  • 相关阅读:
    Notification的使用
    Spring面向切面之AOP深入探讨
    使用注解配置Spring框架自动代理通知
    回顾Spring框架
    Spring利器之包扫描器
    Spring 核心概念以及入门教程
    Struts 2之动态方法调用,不会的赶紧来
    Struts2之过滤器和拦截器的区别
    Struts 2开讲了!!!
    Mybatis开篇以及配置教程
  • 原文地址:https://www.cnblogs.com/mtcnn/p/9412453.html
Copyright © 2011-2022 走看看