zoukankan      html  css  js  c++  java
  • 神经网络原理和实现(MATLAB)

     

    main.m

    % Driver script: trains a small fully connected network on MNIST and
    % reports the test-set error rate.
    clear;
    clc;
    % Loads train_x, train_y, test_x, test_y into the workspace.
    load mnist_uint8;

    % Transpose so that each column is one sample (nnff treats size(x, 2) as
    % the number of samples). Dividing by 255 presumably maps uint8 pixel
    % values into [0, 1] -- confirm against the contents of mnist_uint8.
    train_x = double(train_x')/255;
    test_x = double(test_x')/255;
    train_y = double(train_y');
    test_y = double(test_y');

    % Number of neurons per layer; the first and last entries are the input
    % and output sizes.
    net = nnsetup([784 200 10]);

    % Training options consumed by nntrain (defined elsewhere).
    opts.alpha = 1;
    opts.batchsize = 50;
    opts.numepochs = 20;

    net = nntrain(net, train_x, train_y, opts);
    % net.rL is filled in by nntrain; presumably the recorded loss curve.
    plot(net.rL);

    err = nnpredict(net, test_x, test_y);
    % nnpredict returns a fraction, so scale by 100 to print a percentage.
    % The format string must sit on one line with an explicit \n escape.
    fprintf('err is %.2f\n', err*100);


    nnsetup.m

    function net = nnsetup(layers)
    %NNSETUP Create a network struct and randomly initialise its weights.
    %   net = NNSETUP(layers) takes a vector of layer sizes, where layers(1)
    %   is the input size and layers(end) is the output size.
    %
    %   Fields written:
    %     net.size       - copy of the layers vector
    %     net.layers     - number of layers, numel(layers)
    %     net.inputs     - layers(1)
    %     net.outputs    - layers(end)
    %     net.layer{i}.w - weight matrix from layer i to layer i+1, of size
    %                      layers(i+1)-by-(layers(i)+1)
        net.size = layers;
        net.layers = numel(layers);
        net.inputs = layers(1);
        net.outputs = layers(net.layers);

        for i = 1:net.layers-1
            % One extra column holds the bias weight (nnff appends a row of
            % ones to the activations). Values are uniform in
            % [-0.5, 0.5) scaled down by the fan-in layers(i).
            net.layer{i}.w = (rand(layers(i+1), layers(i)+1) - 0.5) / layers(i);
        end
    end

    nnff.m

    function net = nnff(net, x)
    %NNFF Feed-forward pass: compute the activations of every layer.
    %   net = NNFF(net, x) stores x as the activation of layer 1 and then
    %   fills net.layer{i}.a for i = 2..net.layers. The number of columns of
    %   x is used as the sample count.
        numSamples = size(x, 2);
        biasRow = ones(1, numSamples);

        net.layer{1}.a = x;

        for src = 1:net.layers-1
            % Append a row of ones so the last column of w acts as the bias.
            augmented = [net.layer{src}.a; biasRow];
            % sigm is defined elsewhere; presumably the logistic sigmoid.
            net.layer{src+1}.a = sigm(net.layer{src}.w * augmented);
        end
    end

    nnbp.m

    function net = nnbp(net, y)
    %NNBP Back-propagation: compute the loss and deltas, then update weights.
    %   net = NNBP(net, y) expects the activations net.layer{i}.a to have
    %   been filled in already (see nnff) and y to hold one target column
    %   per sample.
    %
    %   Fields written:
    %     net.e             - output error, a_out - y
    %     net.L             - half the sum of squared errors divided by the
    %                         number of samples
    %     net.layer{k}.deta - error term (delta) of layer k, for k >= 2
    %     net.layer{k}.w    - weights after one gradient step with an
    %                         implicit learning rate of 1
        m = size(y, 2);
        nl = net.layers;

        net.e = net.layer{nl}.a - y;
        % Loss function.
        net.L = 1/2 * sum(net.e(:) .^ 2) / m;

        % Phase 1: compute every delta using the weights that produced the
        % current activations. No weight may be modified before the delta
        % that depends on it has been computed, otherwise the gradient of the
        % earlier layers is taken with respect to the wrong network.
        % The a .* (1 - a) factor is the derivative of a logistic activation.
        net.layer{nl}.deta = net.e .* net.layer{nl}.a .* (1 - net.layer{nl}.a);
        for k = nl-1:-1:2
            % Drop the bias column: the bias has no upstream neuron.
            wNoBias = net.layer{k}.w(:, 1:net.size(k));
            net.layer{k}.deta = (wNoBias' * net.layer{k+1}.deta) ...
                .* net.layer{k}.a .* (1 - net.layer{k}.a);
        end

        % Phase 2: apply the gradient step, averaged over the m samples.
        biasRow = ones(1, m);
        for k = 1:nl-1
            grad = net.layer{k+1}.deta * [net.layer{k}.a; biasRow]' ./ m;
            net.layer{k}.w = net.layer{k}.w - grad;
        end
    end

    nnpredict.m

    function er = nnpredict(net, test_x, test_y)
    %NNPREDICT Fraction of test samples the network classifies incorrectly.
    %   er = NNPREDICT(net, test_x, test_y) runs a forward pass on test_x and
    %   compares, column by column, the index of the largest output
    %   activation with the index of the largest entry of test_y.
        net = nnff(net, test_x);
        outputs = net.layer{net.layers}.a;

        [~, predicted] = max(outputs);
        [~, expected] = max(test_y);

        % Count the mismatching columns and normalise by the sample count.
        numWrong = nnz(predicted ~= expected);
        er = numWrong / size(test_y, 2);
    end
  • 相关阅读:
    mysql5.7慢查询开启配置
    easyui的datagrid删除一条记录后更新出问题
    easyui跨iframe属性datagrid
    struts2笔记12-声明式异常
    struts2笔记11-OGNL
    struts2笔记10-值栈
    linux命令学习03-grep
    struts2笔记09-动态方法调用
    1、GIT简介
    玩转Python语言之4:奇技淫巧
  • 原文地址:https://www.cnblogs.com/linyuanzhou/p/4876644.html
Copyright © 2011-2022 走看看