zoukankan      html  css  js  c++  java
  • 卷积神经网络通俗解读

    转载自:https://blog.csdn.net/dong_lxkm/article/details/80575207

    一、前言

        最近一直在研究深度学习,联想起之前所学,感叹数学是一门朴素而神奇的科学。F=G*m1*m2/r²万有引力描述了宇宙星河运转的规律,E=mc²描述了恒星发光的奥秘,V=H*d哈勃定律描述了宇宙膨胀的奥秘,自然界的大部分现象和规律都可以用数学函数来描述,也就是可以求得一个函数。

        神经网络(《简单又复杂的人工神经网络》)可以逼近任何连续的函数,那么神经网络就有无限的泛化能力。对于大部分分类问题而言,本质就是求得一个函数y=f(x),例如:对于图像识别而言就是求得一个以像素张量为自变量的函数y=F(像素张量),其中y=猫、狗、花、汽车等等;对于文本情感分析而言,就是为了求得一个以词向量或者段落向量为自变量的函数y=F(词向量),其中y=正面、负面等等……

    二、导读

        本篇博客包括以下内容:

        1、卷积神经网络的原理

        2、基于dl4j定型一个卷积神经网络来进行手写数字识别

    三、卷积神经网络原理

        下面左边有个9*9的网格,红色填充的部分构成了数字7,把红色部分填上1,空白部分填上0,就构成了一个二维矩阵,传统做法可以用求向量距离,如果数字全部都标准的写在网格中相同的位置,那么肯定是准确的,但是,实际上数字7在书写的过程中,可能偏左一点、偏右一点,变形扭曲一点,这时候就难以识别。另外,一幅图片的像素点的数量是巨大的,例如一幅50*50的图片将有2500个像素点,每个像素点有R、G、B三个维度的颜色,那么输入参数的个数有7500个,这个运算量是巨大的。

        215936_Pa7H_1778239.png              220327_n8Z6_1778239.png

        那么就需要有一个抽象特征、降低数据维度的方法,这就说到了卷积运算,用一个小于图片的卷积核扫过整幅图片求点积。卷积的过程看下图。图片来源于https://my.oschina.net/u/876354/blog/1620906

        002928_hnHI_876354.gif

        卷积运算的过程在于寻找图片中的显著特征,并达到降维的目的,整个过程相当于一个函数扫过另一个函数,扫过时两个函数的积分重叠部分并没改变图片的特征形状,并可以降低维度,另外还可以分区块来提取特征,并且拼接特征。

    convgaus

        为了进一步降低维度,引入了池化,池化的方式有很多,如最大值,平均值。下图展示了一个步长为2的2*2最大池化过程,用一个2*2的方块扫描过,求Max,总共扫描4次,4次扫描的最大值分别是6、8、3、4。

    maxpool

        最后,经过多层卷积和池化之后,会得到一个矩阵,该矩阵作为一个全连接网络的输入,在逼近一个函数,就识别出数字了,以上图得到的6、8、3、4为例,全连接网络求一个函数。

    231633_dBzj_1778239.png

    四、deeplearning4j手写体识别

        1、先下载mnist数据集,地址如下:

           http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz

        2、解压(我解压在E盘)

        3、训练网络,评估(一些比较难的部分都做了注释)

    1.  
      public class MnistClassifier {
    2.  
       
    3.  
      private static final Logger log = LoggerFactory.getLogger(MnistClassifier.class);
    4.  
      private static final String basePath = "E:";
    5.  
       
    6.  
      public static void main(String[] args) throws Exception {
    7.  
      int height = 28;
    8.  
      int width = 28;
    9.  
      int channels = 1; // 这里有没有复杂的识别,没有分成红绿蓝三个通道
    10.  
      int outputNum = 10; // 有十个数字,所以输出为10
    11.  
      int batchSize = 54;//每次迭代取54张小批量来训练,可以查阅神经网络的mini batch相关优化,也就是小批量求平均梯度
    12.  
      int nEpochs = 1;//整个样本集只训练一次
    13.  
      int iterations = 1;
    14.  
       
    15.  
      int seed = 1234;
    16.  
      Random randNumGen = new Random(seed);
    17.  
       
    18.  
      File trainData = new File(basePath + "/mnist_png/training");
    19.  
      FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
    20.  
      ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); //以父级目录名作为分类的标签名
    21.  
      ImageRecordReader trainRR = new ImageRecordReader(height, width, channels, labelMaker);//构造图片读取类
    22.  
      trainRR.initialize(trainSplit);
    23.  
      DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRR, batchSize, 1, outputNum);
    24.  
       
    25.  
      // 把像素值区间 0-255 压缩到0-1 区间
    26.  
      DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
    27.  
      scaler.fit(trainIter);
    28.  
      trainIter.setPreProcessor(scaler);
    29.  
       
    30.  
       
    31.  
      // 向量化测试集
    32.  
      File testData = new File(basePath + "/mnist_png/testing");
    33.  
      FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
    34.  
      ImageRecordReader testRR = new ImageRecordReader(height, width, channels, labelMaker);
    35.  
      testRR.initialize(testSplit);
    36.  
      DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, batchSize, 1, outputNum);
    37.  
      testIter.setPreProcessor(scaler); // same normalization for better results
    38.  
       
    39.  
      log.info("Network configuration and training...");
    40.  
      Map<Integer, Double> lrSchedule = new HashMap<>();//设定动态改变学习速率的策略,key表示小批量迭代到几次
    41.  
      lrSchedule.put(0, 0.06);
    42.  
      lrSchedule.put(200, 0.05);
    43.  
      lrSchedule.put(600, 0.028);
    44.  
      lrSchedule.put(800, 0.0060);
    45.  
      lrSchedule.put(1000, 0.001);
    46.  
       
    47.  
      MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
    48.  
      .seed(seed)
    49.  
      .iterations(iterations)
    50.  
      .regularization(true).l2(0.0005)
    51.  
      .learningRate(.01)
    52.  
      .learningRateDecayPolicy(LearningRatePolicy.Schedule)
    53.  
      .learningRateSchedule(lrSchedule)
    54.  
      .weightInit(WeightInit.XAVIER)
    55.  
      .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
    56.  
      .updater(Updater.NESTEROVS)
    57.  
      .list()
    58.  
      .layer(0, new ConvolutionLayer.Builder(5, 5)
    59.  
      .nIn(channels)
    60.  
      .stride(1, 1)
    61.  
      .nOut(20)
    62.  
      .activation(Activation.IDENTITY)
    63.  
      .build())
    64.  
      .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
    65.  
      .kernelSize(2, 2)
    66.  
      .stride(2, 2)
    67.  
      .build())
    68.  
      .layer(2, new ConvolutionLayer.Builder(5, 5)
    69.  
      .stride(1, 1)
    70.  
      .nOut(50)
    71.  
      .activation(Activation.IDENTITY)
    72.  
      .build())
    73.  
      .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
    74.  
      .kernelSize(2, 2)
    75.  
      .stride(2, 2)
    76.  
      .build())
    77.  
      .layer(4, new DenseLayer.Builder().activation(Activation.RELU)
    78.  
      .nOut(500).build())
    79.  
      .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
    80.  
      .nOut(outputNum)
    81.  
      .activation(Activation.SOFTMAX)
    82.  
      .build())
    83.  
      .setInputType(InputType.convolutionalFlat(28, 28, 1))
    84.  
      .backprop(true).pretrain(false).build();
    85.  
       
    86.  
      MultiLayerNetwork net = new MultiLayerNetwork(conf);
    87.  
      net.init();
    88.  
      net.setListeners(new ScoreIterationListener(10));
    89.  
      log.debug("Total num of params: {}", net.numParams());
    90.  
       
    91.  
      // 评估测试集
    92.  
      for (int i = 0; i < nEpochs; i++) {
    93.  
      net.fit(trainIter);
    94.  
      Evaluation eval = net.evaluate(testIter);
    95.  
      log.info(eval.stats());
    96.  
      trainIter.reset();
    97.  
      testIter.reset();
    98.  
      }
    99.  
      ModelSerializer.writeModel(net, new File(basePath + "/minist-model.zip"), true);//保存训练好的网络
    100.  
      }
    101.  
      }
    1. 运行main方法,得到如下评估结果:

     # of classes:    10
     Accuracy:        0.9897
     Precision:       0.9897
     Recall:          0.9897
     F1 Score:        0.9896

        整个效果还比较好,保存好训练的网络,便可以用于手写体数据的识别了,下一篇博客将介绍怎么加载定型的网络,配合springMVC来开发一个手写体识别的应用。

  • 相关阅读:
    页面加载完没有其他操作的情况下直接获取音频时长为NAN问题
    关于mysql的一些操作
    阿里云服务器登录不上 提示:之前用于连接到 (公网ip) 的凭据无法工作(1核1G) 以及阿里云新版本安全组策略没有开启80端口导致网站只能ping通 访问不到的问题
    微信浏览器禁止页面下拉查看网址(不影响页面内部scroll)
    2018年11月17号第一次参加源创会记录
    使用了eclipse10年之后,我终于投向了IDEA
    spring/spring boot/spring cloud书籍推荐
    python数据库连接例子
    Spring Cloud Eureka配置文件例子与较为详细说明
    spring源代码下载并导入eclipse技巧
  • 原文地址:https://www.cnblogs.com/jfdwd/p/11174567.html
Copyright © 2011-2022 走看看