  • Machine Learning: Softmax classifier (one hidden layer)

    This program implements a softmax classifier with one hidden layer. The activation function is ReLU: f(x) = max(0, x).
    $f_1 = w_1 x + b_1$

    $h_1 = \max(0, f_1)$

    $f_2 = w_2 h_1 + b_2$

    $y_i = \dfrac{e^{f_{2,i}}}{\sum_j e^{f_{2,j}}}$
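
    The training code that follows minimizes the average cross-entropy of these probabilities plus an L2 penalty on the weights. As a sketch of the math behind its back-propagation step, here are the loss and gradients in the row-per-example layout the code uses (so $f_1 = x W_1 + b_1$ for a batch matrix $x$). The symbols $N$ (batch size), $t$ (one-hot label matrix, `batch_y` in the code), and $\lambda$ (the `reg` coefficient) are notation introduced here for illustration and do not appear in the original post.

```latex
\begin{aligned}
L &= -\frac{1}{N}\sum_{n=1}^{N}\sum_{i} t_{n,i}\,\log y_{n,i}
     + \frac{\lambda}{2}\left(\lVert W_1\rVert_F^2 + \lVert W_2\rVert_F^2\right) \\
\frac{\partial L}{\partial f_2} &= \frac{1}{N}\,(y - t) \\
\frac{\partial L}{\partial W_2} &= h_1^{\top}\,\frac{\partial L}{\partial f_2} + \lambda W_2,
\qquad
\frac{\partial L}{\partial b_2} = \sum_{n}\left(\frac{\partial L}{\partial f_2}\right)_{n} \\
\frac{\partial L}{\partial h_1} &= \frac{\partial L}{\partial f_2}\,W_2^{\top},
\qquad
\frac{\partial L}{\partial f_1} = \frac{\partial L}{\partial h_1}\odot \mathbb{1}[f_1 > 0] \\
\frac{\partial L}{\partial W_1} &= x^{\top}\,\frac{\partial L}{\partial f_1} + \lambda W_1,
\qquad
\frac{\partial L}{\partial b_1} = \sum_{n}\left(\frac{\partial L}{\partial f_1}\right)_{n}
\end{aligned}
```

    In the code these correspond to `dscores`, `dW2`, `db2`, `dhiddenval_1`, `df_1`, `dW1`, and `db1`, in that order.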

    
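    function Out=Softmax_Classifier_1(train_x,  train_y, opts)
    % Train a softmax classifier with one ReLU hidden layer using
    % mini-batch gradient descent.
    %   train_x : N-by-D matrix, one example per row
    %   train_y : N-by-K matrix of one-hot labels
    %   opts    : struct with fields step_size, reg, batchsize, numepochs,
    %             class (number of classes K) and hidden (hidden-layer size)
    
    % setting learning parameters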
    step_size=opts.step_size;
    reg=opts.reg;
    batchsize = opts.batchsize;
    numepochs = opts.numepochs;
    K=opts.class;
    h=opts.hidden;
    
    D=size(train_x, 2);
    W1=0.01*randn(D,h);
    b1=zeros(1,h);
    W2=0.01*randn(h, K);
    b2=zeros(1,K);
    
    
    loss(1 : numepochs)=0;
    
    num_examples=size(train_x, 1);
    numbatches = floor(num_examples / batchsize);  % whole batches only; a trailing partial batch is skipped
    
    for epoch=1:numepochs
    
        kk = randperm(num_examples);
        loss(epoch)=0;
    
        tic;  % start the per-epoch timer that toc reads at the end of the epoch
    
    
        for bat=1:numbatches
    
            batch_x = train_x(kk((bat - 1) * batchsize + 1 : bat * batchsize), :);
            batch_y = train_y(kk((bat - 1) * batchsize + 1 : bat * batchsize), :);
    
            %% forward
            f1=batch_x*W1+repmat(b1, batchsize, 1);
            hiddenval_1=max(0, f1);
            scores=hiddenval_1*W2+repmat(b2, batchsize, 1);
    
            %% the loss
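            % subtract each row's maximum before exponentiating so exp cannot
            % overflow; the resulting probabilities are unchanged
            exp_scores=exp(scores-repmat(max(scores, [], 2), 1, K));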
            dd=repmat(sum(exp_scores, 2), 1, K);
            probs=exp_scores./dd;
            correct_logprobs=-log(sum(probs.*batch_y, 2));
            data_loss=sum(correct_logprobs)/batchsize;
            reg_loss=0.5*reg*sum(sum(W1.*W1))+0.5*reg*sum(sum(W2.*W2));
            loss(epoch) =loss(epoch)+ data_loss + reg_loss;
    
            %% back propagation
            dscores = probs-batch_y;
            dscores=dscores/batchsize;
            dW2=hiddenval_1'*dscores;
            db2=sum(dscores, 1);  % sum over examples (rows), also when batchsize is 1
    
            dhiddenval_1=dscores*W2';
            mask=max(sign(hiddenval_1), 0);
            df_1=dhiddenval_1.*mask;
            dW1=batch_x'*df_1;
            db1=sum(df_1, 1);  % sum over examples (rows), also when batchsize is 1
    
            %% update
            dW2=dW2+reg*W2;
            dW1=dW1+reg*W1;
    
            W1=W1-step_size*dW1;
            b1=b1-step_size*db1;
    
            W2=W2-step_size*dW2;
            b2=b2-step_size*db2;
    
        end
    
        loss(epoch)=loss(epoch)/numbatches;
    
        if (mod(epoch, 10)==0)
            fprintf('epoch: %d, training loss is %f\n', epoch, loss(epoch));
        end
    
        toc;
    
    end
    
    Out.W1=W1;
    Out.b1=b1;
    Out.b2=b2;
    Out.W2=W2;
    Out.loss=loss;
    
    end
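
    The function above expects one-hot labels and an `opts` struct, and it returns only the learned parameters, so prediction is left to the caller. The following is a minimal sketch of both steps, not part of the original post: the toy labels, the option values, and the hand-built `Out` struct are all hypothetical placeholders chosen so the snippet runs on its own.

```matlab
% Toy integer labels in 1..K (hypothetical data, for illustration only).
labels = [1; 3; 2; 3];
K = 3;

% One-hot encode: row n gets a 1 in column labels(n).
train_y = zeros(numel(labels), K);
train_y(sub2ind(size(train_y), (1:numel(labels))', labels)) = 1;

% Options struct with the fields Softmax_Classifier_1 reads.
% The values are arbitrary placeholders, not recommendations from the post.
opts.step_size = 0.1;
opts.reg       = 1e-3;
opts.batchsize = 2;
opts.numepochs = 100;
opts.class     = K;
opts.hidden    = 2;

% Hand-built stand-in for a trained model (2 inputs, 2 hidden units,
% 3 classes). Normally Out would be returned by Softmax_Classifier_1.
Out.W1 = eye(2);            Out.b1 = [0 0];
Out.W2 = [1 0 0; 0 1 0];    Out.b2 = [0 0 0];

% Prediction: repeat the forward pass, then pick the highest-scoring class.
x      = [3 1; 0.5 2];
n      = size(x, 1);
h1     = max(0, x*Out.W1 + repmat(Out.b1, n, 1));
scores = h1*Out.W2 + repmat(Out.b2, n, 1);
[~, pred] = max(scores, [], 2);   % pred is [1; 2] for this stand-in model
```

    With real data, `Out = Softmax_Classifier_1(train_x, train_y, opts);` would replace the hand-built struct. Taking the arg max of `scores` is enough for prediction, because the softmax does not change which class scores highest.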
    
    
  • Original article: https://www.cnblogs.com/mtcnn/p/9412463.html