zoukankan      html  css  js  c++  java
  • 神经网络入门 第6章 识别手写字体

     前言

        神经网络是一种很特别的解决问题的方法。本书将用最简单易懂的方式与读者一起从最简单开始,一步一步深入了解神经网络的基础算法。本书将尽量避开让人望而生畏的名词和数学概念,通过构造可以运行的Java程序来实践相关算法。

        关注微信号“逻辑编程"来获取本书的更多信息。

        这一章节我们将会解决一个真正的问题:手写字体识别。我们将识别像下面图中这样的手写数字。

     

        在开始之前,我们先要准备好相应的测试数据。我们不能像前边那样简单的产生手写字体,毕竟我们自己还不知道如何写出一个产生手写字体的算法。训练要达到一定的精度需要较多的训练数据。还好,前人栽树后人乘凉,先驱们已经收集了宝贵的训练材料。MNIST就是一个广泛使用的数据集。不但可以拿来用,我们还可以从网站上看到别人的识别准确率。这样我们就有了很好的参照。MNIST包含一套训练数据和一套测试数据,分别来自不同的人群的手写。

        MNIST网站: http://yann.lecun.com/exdb/mnist/

    这个数据集是写在特定的二进制文件中的,并非普通图片格式。每个图片数据由28*28个像素组成。每个像素1个字节表示颜色灰度级。MNIST网站上有具体的介绍。

        我们写一个类来完成数据集的读取工作,并提供接口返回指定的训练或者测试数据。具体代码不做分析,仅将代码附在下面,供读者使用。代码执行前要先下载数据文件并保留GZIP格式。代码执行后将随机抽取20个生成PNG图片供读者自己查看和验证数据内容。

        下面我们写个测试类来识别手写字体。我们使用MNIST库的60000训练数据来反复训练我们的神经网络。每轮训练后使用MNIST库的10000个测试数据来测试识别率。

        下面是代码:

    
    
    package com.luoxq.ann;

    import java.util.Arrays;
    import java.util.Random;

    public class MnistTest {


    public static void main(String... args) {
    int[] shape = {28 * 28, 10};
    NeuralNetwork nn = new NeuralNetwork(shape);
    Mnist mnist = new Mnist();
    mnist.load();
    mnist.shuffle();
    System.out.println("Shape: " + Arrays.toString(shape));
    System.out.println("Initial correct rate: " + test(nn, mnist));
    int epochs = 1000;
    double rate = 0.5;
    System.out.println("Learning rate: " + rate);
    System.out.println("Epoch,Time,Correctness ----------------------");
    long time = System.currentTimeMillis();
    Mnist.Data[] data = mnist.getTrainingSlice(0, 60000);
    for (int epoch = 1; epoch <= epochs; epoch++) {
    for (int sample = 0; sample < data.length; sample++) {
    nn.train(data[sample].input, data[sample].output, rate);
    }
    long seconds = (System.currentTimeMillis() - time) / 1000;
    System.out.println(epoch + ", " + seconds + ", " +
    test(nn, mnist));
    }
    }

    private static int test(NeuralNetwork nn, Mnist mnist) {
    int correct = 0;
    Mnist.Data[] data = mnist.getTestSlice(0, 10000);
    for (int sample = 0; sample < data.length; sample++) {
    if (max(nn.f(data[sample].input)) == data[sample].label) {
    correct++;
    }
    }
    return correct;
    }

    private static int max(double[] d) {
    double max = d[0];
    int idx = 0;
    for (int i = 1; i < d.length; i++) {
    if (max < d[i]) {
    max = d[i];
    idx = i;
    }
    }
    return idx;
    }
    }
    
    

        我们先用一个10个神经元的单层神经网络试试看。结果出乎意外的好。我们很快就获得了超过90%的正确率。单层网络几乎就是对每个数字的像素分布做简单统计。能获得如此高的识别率,还是很神奇的。 在达到90%之后再训练已经效果不大,达到饱和了。我们必须换一种方法来做了。 

    Shape: [784, 10]

    Initial correct rate: 1373

    Learning rate: 0.5

    Epoch,Time,Correctness

    ----------------------

    1, 4, 6429

    2, 8, 7663

    3, 13, 8963

    4, 17, 9029

    5, 22, 9016

    6, 27, 9062

    7, 31, 9063

    8, 36, 9066

    9, 41, 9072

    10, 45, 9057

    11, 50, 9084

    12, 55, 9072

    13, 61, 9062

    14, 66, 9050

    15, 70, 9077

    16, 75, 9052

    17, 79, 9068

    18, 84, 9055

    19, 88, 9060

    20, 93, 9064

        那么我们来使用三层神经网络试一试。在试了几个不同的中间层大小和学习率参数之后,我找到了下面这个较好的参数组合:

    Shape: [784, 50, 10]

    Initial correct rate: 944

    Learning rate: 1.0

    Epoch,Time,Correctness

    ----------------------

    1, 24, 7459

    2, 59, 9232

    3, 99, 9313

    4, 131, 9379

    5, 153, 9412

    6, 176, 9443

    7, 200, 9412

    8, 226, 9447

    9, 248, 9462

    10, 269, 9461

    11, 290, 9465

    12, 314, 9493

    13, 343, 9477

    14, 368, 9499

    15, 392, 9502

    16, 420, 9509

    17, 447, 9482

    18, 472, 9508

    19, 496, 9491

    20, 518, 9536

    21, 545, 9523

    22, 569, 9549

    23, 593, 9527

    24, 618, 9527

    25, 643, 9520

    26, 667, 9513

    27, 689, 9507

    28, 712, 9527

    29, 734, 9501

    30, 758, 9521

    31, 781, 9508

    32, 804, 9534

    33, 827, 9534

    34, 850, 9550

    35, 875, 9569

        我们很快达到了95%以上的正确率。可见多层网络相对单层神经网络还是有优势的。虽然这个正确率还达不到产品水平,但是作为初次尝试结果还是很不错的。

        下面是MNIST文件读取源代码:

    package com.luoxq.ann;

    import javax.imageio.ImageIO;
    import java.awt.image.BufferedImage;
    import java.io.DataInputStream;
    import java.io.File;
    import java.io.FileInputStream;
    import java.util.Random;
    import java.util.zip.GZIPInputStream;

    /**
    * Created by luoxq on 17/4/15.
    */
    public class Mnist {


    static class Data {
    public byte[] data;
    public int label;
    public double[] input;
    public double[] output;
    }

    public static void main(String... args) throws Exception {
    Mnist mnist = new Mnist();
    mnist.load();
    System.out.println("Data loaded.");
    Random rand = new Random(System.nanoTime());
    for (int i = 0; i < 20; i++) {
    int idx = rand.nextInt(60000);
    Data d = mnist.getTrainingData(idx);
    BufferedImage img = new BufferedImage(28, 28, BufferedImage.TYPE_INT_RGB);
    for (int x = 0; x < 28; x++) {
    for (int y = 0; y < 28; y++) {
    img.setRGB(x, y, toRgb(d.data[y * 28 + x]));
    }
    }
    File output = new File(i + "_" + d.label + ".png");
    if (!output.exists()) {
    output.createNewFile();
    }
    ImageIO.write(img, "png", output);
    }
    }

    static int toRgb(byte bb) {
    int b = (255 - (0xff & bb));
    return (b << 16 | b << 8 | b) & 0xffffff;
    }


    Data[] trainingSet;
    Data[] testSet;

    public void shuffle() {
    Random rand = new Random();
    for (int i = 0; i < trainingSet.length; i++) {
    int x = rand.nextInt(trainingSet.length);
    Data d = trainingSet[i];
    trainingSet[i] = trainingSet[x];
    trainingSet[x] = trainingSet[i];
    }
    }

    public Data getTrainingData(int idx) {
    return trainingSet[idx];
    }

    public Data[] getTrainingSlice(int start, int count) {
    Data[] ret = new Data[count];
    System.arraycopy(trainingSet, start, ret, 0, count);
    return ret;
    }

    public Data getTestData(int idx) {
    return testSet[idx];
    }

    public Data[] getTestSlice(int start, int count) {
    Data[] ret = new Data[count];
    System.arraycopy(testSet, start, ret, 0, count);
    return ret;
    }


    public void load() {
    trainingSet = load("train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz");
    testSet = load("t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz");
    if (trainingSet.length != 60000 || testSet.length != 10000) {
    throw new RuntimeException("Unexpected training/test data size: " + trainingSet.length + "/" + testSet.length);
    }
    }

    private Data[] load(String imgFile, String labelFile) {
    byte[][] images = loadImages(imgFile);
    byte[] labels = loadLabels(labelFile);
    if (images.length != labels.length) {
    throw new RuntimeException("Images and label doesn't match: " + imgFile + " " + labelFile);
    }
    int len = images.length;
    Data[] data = new Data[len];
    for (int i = 0; i < len; i++) {
    data[i] = new Data();
    data[i].data = images[i];
    data[i].label = 0xff & labels[i];
    data[i].input = dataToInput(images[i]);
    data[i].output = labelToOutput(labels[i]);
    }
    return data;
    }

    private double[] labelToOutput(byte label) {
    double[] o = new double[10];
    o[label] = 1;
    return o;
    }

    private double[] dataToInput(byte[] b) {
    double[] d = new double[b.length];
    for (int i = 0; i < b.length; i++) {
    d[i] = (b[i] & 0xff) / 255.0;
    }
    return d;
    }

    private byte[][] loadImages(String imgFile) {
    try (DataInputStream in = new DataInputStream(new GZIPInputStream(new FileInputStream(imgFile)));) {
    int magic = in.readInt();
    if (magic != 0x00000803) {
    throw new RuntimeException("wrong magic: 0x" + Integer.toHexString(magic));
    }
    int count = in.readInt();
    int rows = in.readInt();
    int cols = in.readInt();
    if (rows != 28 || cols != 28) {
    throw new RuntimeException("Unexpected row and col count: " + rows + "x" + cols);
    }
    byte[][] data = new byte[count][rows * cols];
    for (int i = 0; i < count; i++) {
    in.readFully(data[i]);
    }
    return data;
    } catch (Exception ex) {
    throw new RuntimeException("Failed to read file: " + imgFile, ex);
    }
    }

    private byte[] loadLabels(String labelFile) {
    try (DataInputStream in = new DataInputStream(new GZIPInputStream(new FileInputStream(labelFile)));) {
    int magic = in.readInt();
    if (magic != 0x00000801) {
    throw new RuntimeException("wrong magic: 0x" + Integer.toHexString(magic));
    }
    int count = in.readInt();
    byte[] data = new byte[count];
    in.readFully(data);
    return data;
    } catch (Exception ex) {
    throw new RuntimeException("Failed to read file: " + labelFile, ex);
    }
    }

    }

    欢迎关注订阅号逻辑编程阅读更多内容。

  • 相关阅读:
    Django学习:博客分类统计(14)
    Django学习:上下篇博客和按日期分类(13)
    Django学习:分页优化(12)
    Django学习:shell命令行模式以及分页(11)
    Django学习:博客页面的响应式布局(10)
    Django学习:响应式导航条(9)
    八、Django学习:使用css美化页面
    七、Django学习:模板嵌套
    js日期使用总结
    Vue 的数据劫持 + 发布订阅
  • 原文地址:https://www.cnblogs.com/javadaddy/p/6748664.html
Copyright © 2011-2022 走看看