zoukankan      html  css  js  c++  java
  • 【Deep Learning】BP网络手写识别

    手写数字识别

    1. BP神经网络

    1.1 BP算法原理


    BP即反向传播算法。利用输出后的误差计算上一层的误差,以此类推,直到输入层,然后再动态调节各层之间的连接权值来减小误差。

    1.2 三层网络结构

    1.2 Sigmoid激活函数

    Sigmoid : f(x)=1/(1+e^−x)

    f(x)的取值范围为0~1;

    1.3 反向误差的计算及权值调整

    BP神经网络的数学原理及其算法实现

    2. 基于BP的手写数字识别

    2.1 格式化输入

    选取Mnist数据集中的0~9手写数字图片各28张。

    将图像矩阵转化为BP网络可接受的一维数组形式:

    //获取像素数组
        private double[] getImagePixel(String image) throws Exception {
            File file = new File(image);
            BufferedImage bi = null;
            try {
                bi = ImageIO.read(file);
            } catch (Exception e) {
                e.printStackTrace();
            }
            int width = bi.getWidth();
            int height = bi.getHeight();
            double[] vector = new double[width / 2 * height / 2];
            int k = 0;
            for (int i = 0; i < width; i += 2) {
                for (int j = 0; j < height; j += 2) {
                    int whiteNum = 0;
                    for (int m = 0; m < 2; m++) {
                        for (int n = 0; n < 2; n++) {
                            if (isWhite(bi, i + m, j + n)) {
                                whiteNum++;
                            }
                        }
                    }
                    vector[k++] = whiteNum / 4;
                }
            }
            // System.out.println(Arrays.toString(vector));
            return vector;
        }
    

    为了减小运算量,选取一个2×2的矩阵区域,统计区域

    内白色像素点的个数,将数据量缩小为原来的1/4。

    //二值化并归一化数据
        private boolean isWhite(BufferedImage bi, int i, int j) {
            int pixel = bi.getRGB(i, j);
            int[] rgb = new int[3];
            rgb[0] = (pixel & 0xff0000) >> 16;
            rgb[1] = (pixel & 0xff00) >> 8;
            rgb[2] = (pixel & 0xff);
            double d = (double) ((rgb[0] * 38 + rgb[1] * 75 + rgb[2] * 15) >> 7) > 100 ? 1 : 0;
            if (d == 1) {
                return true;
            } else
                return false;
        }
    

    2.创建神经网络并训练网络

    对网络进行20000次迭代训练或直到误差小于0.001为止

    String trianDataPath = "C:\Users\Administrator\Desktop\or_perceptron.nnet";
            String imgPath = "data/train2/";
            int maxLearn = 50000;
            double maxError = 0.0001;
            Stack<Long> stack = new Stack<>();
            stack.push(System.currentTimeMillis());
            System.out.println("->初始化多层网络...");
            MultiLayerPerceptron myMlPerceptron = new MultiLayerPerceptron(TransferFunctionType.SIGMOID, 100, 20, 10);
            BackPropagation bp = new BackPropagation();
    
            bp.setMaxIterations(maxLearn);
            bp.setMaxError(maxError);
            bp.setLearningRate(0.5);
            bp.setMinErrorChange(0.0001);
            //bp.setBatchMode(true);
            // bp.setMinErrorChangeIterationsLimit(1000);
            myMlPerceptron.setLearningRule(bp);
            System.out.println("->初始化完成");
            LearningRule learningRule = myMlPerceptron.getLearningRule();
            learningRule.addListener(new LearningEventListener() {
                @Override
                public void handleLearningEvent(LearningEvent event) {
                    BackPropagation bp = (BackPropagation) event.getSource();
                    int iteration = bp.getCurrentIteration();
                    if (event.getEventType() != LearningEvent.Type.LEARNING_STOPPED && iteration % 100 == 0) {
                        System.out.print("->学习次数: " + iteration + ",当前误差: " + bp.getTotalNetworkError());
                        System.out.println(",用时:"+(System.currentTimeMillis()-stack.pop())+" ms");
                        stack.push(System.currentTimeMillis());
                    }
                }
            });
            System.out.println("->创建数据集...");
            DataSet trainingSet = new DataSet(100, 10);
            // 添加训练数据到数据集
            File f = new File(imgPath);
            File[] list = f.listFiles();
            ImageToVector itv = new ImageToVector();
            for (int i = 0; i < list.length; i++) {
                String fileName = list[i].getName();
                double[] input = itv.imageToVector(imgPath + fileName, 20, 20);
                trainingSet.addRow(input, getTarget(getNumber(fileName)));
            }
            System.out.println("->数据集创建完成");
            System.out.println("->开始学习...");
            myMlPerceptron.learn(trainingSet);
            System.out.println("->学习完成");
            System.out.println("->保存学习数据...");
            myMlPerceptron.save(trianDataPath);
            System.out.println("->保存完成...");

    2.3 测试网络

    测试共选取4990张图片,识别正确2646张,正确率:0.530

    2605。识别率较低。

  • 相关阅读:
    redis的安装
    thinkphp5学习
    php数组排序和查找的算法
    phprpc的简单使用
    apache学习教程
    mysql的存储过程,函数,事件,权限,触发器,事务,锁,视图,导入导出
    php设计模式八-----装饰器模式
    php设计模式七 ---组合模式
    64bit ubuntu14.04编译PlatinumKit出现的arm-linux-androideabi-g++: not found错误解决方法
    TS相关知识点
  • 原文地址:https://www.cnblogs.com/cnsec/p/13286769.html
Copyright © 2011-2022 走看看