zoukankan      html  css  js  c++  java
  • CNN图像识别的经典模型简述

    0. 雁字无多

        我这两天要是再不学习一下之后可能时间就很少了,期末作业是做了一些了。但是,今天接到一个大任务,今年实验室招标和项目我负责。就在我写这段文字的时候,我老板进来了。正好,我就继续写吧。可是又说了半天招标的事情,说XX所好赚钱好赚钱,流程怎么怎么跑,要去招标局和公司做啥做啥。

        这周换了一本书开始看,觉得没有上一本写的好。现在感觉这些书对我帮助不是很大了,书其实没有一些CSDN上的精品博客写的好。今天想写一下CNN的经典网络吧,做图像识别分类的。要具体的讲,自2012年开始,基于CNN用于打ILSVRC图像识别大赛的几大经典网络分别是:AlexNet,VGGNets,GoogLeNet&Inception和ResNets。这几大具有统治力的模型也是每年一更新不断刷新识别率与准确率。具体每个网络的结构、异同点不妨CSDN去看看,可以根据实际需求选择一个。

        下面我也就个人的理解对这几大网络做一个背景的简要描述。

        AlexNet 在2012年被提出是被认为开创了深度学习的时代,大数据、GPU、ReLu函数(加快收敛速度)和dropout(防止过拟合)等技术也都是为它的出现奠定了基础。它包含5层卷积层和3层全连接层如图1,这是网上的一张经典图形。因为ILSVRC是个千分类问题,所以全连接层最后的softmax输出为1000维向量。不过现在实际应用中现在很难看到AlexNet,毕竟它只是作为LeNet网络的一个历史突破性版本,它具有历史里程意义,但技术价值已经不如后面的新贵们了。

    图1 AlexNet

        VGGNet 是在AlexNet基础之上提出的,网络结构相似。不同点在于,VGGNet是在每一层卷积层上连续卷积2~4次,结构上与AlexNet相比并没有做什么改变,在此不再赘述。VGGNet对深度学习最大的贡献莫过于:不考虑其他因素(计算存储等),CNN网络的深度不断加深(增加卷积次数,通俗讲就是使计算复杂化)可以提升准确率。这也成为了现在做机器学习提升效果的核心方法之一:加深网络、数据增强模型融合。另外调参也很重要。虽然VGGNet参数很多,但由于设计上的很多细节,它能够很快收敛,具体原因我需要提一下吗....好吧,主要是卷积核大小比较小。最后还需要说一点,无脑的加深网络深度是不能得到效果的线性提升,这是因为参数变多之后误差函数的梯度弥散问题,导致网络在训练时无法收敛(无论从顺着哪个参数的梯度进行下降都无法使误差减小,因为梯度求导几乎为0)。比如VGGNet在超过20层后就效果就会下降。这个梯度弥散问题在后续的ResNets得到了一定程度的解决。

        GoogLeNet 与VGGNet在相比,在内存和计算消耗方面有非常大的优势。AlexNet增多卷积次数,不可避免使其增多了权值参数个数,它共有6000万个权值参数,是AlexNet的三倍以上;而GoogleLeNet只有500万个参数,所以在内存较小的移动端GoogLeNet有着更广泛的应用。当时提出GoogLeNet,就是因为Google的学者们虽然同意像AlexNet那样加深网络深度可以提高识别准确率,但如何更加有效的加深网络深度,使其能够在保证准确率提升的前提下尽可能的“轻量”,所以开启了名为Inception的项目工程。GoogLeNet就是InceptionV1版本,现在已经发布到InceptionV4版本——结合ResNets的GoogLeNet。具体的发展与区别可以百度了解一下。结构图是实在很复杂如图2。

    图2 GoogLeNet

        ResNets 是2016年微软亚研院的何凯明博士及其团队提出的深度残差网络。其最大的亮点就是通过设计的残差网络结构,避免了随着随着网络层数加深而产生的梯度消失或梯度爆炸的问题(该问题被称为深度网络的退化问题,degradation problem),不但能使深度神经网络的收敛速度更快、精度更高,而且让加深网络深度来提高网络效果成为可能。在学习的时候,很多博客和教材会把ResNets的网络结构同VGG进行对比学习,如图3。

     

    图3 ReNets

    1. 写得相思几许

        这几大网络github上有很好的项目工程,python3.6+tensorflow1.7亲测可用。贴一个链接吧:https://github.com/MachineLP/train_arch

        这一小节还是得弄一点干货吧哈哈哈,那就上一个我自己调通了把数据集下载这些都合成在一个py文件的GoogLeNet吧。InceptionV3,代码参考上述链接。

        

    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    import argparse
    import os.path
    import re
    import sys
    import tarfile
    import numpy as np
    from six.moves import urllib
    import tensorflow as tf


    FLAGS = None
    DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'

    # NodeLookup类负责将分类器输出的类别编号与人类可读的标签名称对应起来
    class NodeLookup(object):
      def __init__(self,
                   label_lookup_path=None,
                   uid_lookup_path=None):
        if not label_lookup_path:
          label_lookup_path = os.path.join(
              FLAGS.model_dir, 'imagenet_2012_challenge_label_map_proto.pbtxt')
        if not uid_lookup_path:
          uid_lookup_path = os.path.join(
              FLAGS.model_dir, 'imagenet_synset_to_human_label_map.txt')
        self.node_lookup = self.load(label_lookup_path, uid_lookup_path)

    # 为每一个softmax节点读取人类可读的类标英文名字
      def load(self, label_lookup_path, uid_lookup_path):
        if not tf.gfile.Exists(uid_lookup_path):
          tf.logging.fatal('File does not exist %s', uid_lookup_path)
        if not tf.gfile.Exists(label_lookup_path):
          tf.logging.fatal('File does not exist %s', label_lookup_path)

        # Loads mapping from string UID to human-readable string
        proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()
        uid_to_human = {}
        p = re.compile(r'[nd]*[ S,]*')
        for line in proto_as_ascii_lines:
          parsed_items = p.findall(line)
          uid = parsed_items[0]
          human_string = parsed_items[2]
          uid_to_human[uid] = human_string

        # Loads mapping from string UID to integer node ID.
        node_id_to_uid = {}
        proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
        for line in proto_as_ascii:
          if line.startswith('  target_class:'):
            target_class = int(line.split(': ')[1])
          if line.startswith('  target_class_string:'):
            target_class_string = line.split(': ')[1]
            node_id_to_uid[target_class] = target_class_string[1:-2]

        # Loads the final mapping of integer node ID to human-readable string
        node_id_to_name = {}
        for key, val in node_id_to_uid.items():
          if val not in uid_to_human:
            tf.logging.fatal('Failed to locate: %s', val)
          name = uid_to_human[val]
          node_id_to_name[key] = name
        return node_id_to_name


      def id_to_string(self, node_id):
        if node_id not in self.node_lookup:
          return ''
        return self.node_lookup[node_id]


    # 从protocol buffer文件中反序列化出inception-v3模型及参数
    def create_graph():
      # Creates graph from saved graph_def.pb.
      with tf.gfile.FastGFile(os.path.join(
          FLAGS.model_dir, 'classify_image_graph_def.pb'), 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(graph_def, name='')

       
    # 使用v3模型对image图片进行分类,并输出top5置信度的类别预测
    def run_inference_on_image(image):
      if not tf.gfile.Exists(image):
        tf.logging.fatal('File does not exist %s', image)
      image_data = tf.gfile.FastGFile(image, 'rb').read()

      # Creates graph from saved GraphDef.
      create_graph()

      with tf.Session() as sess:
        # Some useful tensors:
        # 'softmax:0': A tensor containing the normalized prediction across
        #   1000 labels.
        # 'pool_3:0': A tensor containing the next-to-last layer containing 2048
        #   float description of the image.
        # 'DecodeJpeg/contents:0': A tensor containing a string providing JPEG
        #   encoding of the image.
        # Runs the softmax tensor by feeding the image_data as input to the graph.

        softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
        predictions = sess.run(softmax_tensor,
                               {'DecodeJpeg/contents:0': image_data})
        predictions = np.squeeze(predictions)

        # Creates node ID --> English string lookup.
        node_lookup = NodeLookup()
        top_k = predictions.argsort()[-FLAGS.num_top_predictions:][::-1]
        for node_id in top_k:
          human_string = node_lookup.id_to_string(node_id)
          score = predictions[node_id]
          print('%s (score = %.5f)' % (human_string, score))

    # 下载模型存档并解压
    def maybe_download_and_extract():
      dest_directory = FLAGS.model_dir
      if not os.path.exists(dest_directory):
        os.makedirs(dest_directory)
      filename = DATA_URL.split('/')[-1]
      filepath = os.path.join(dest_directory, filename)
      if not os.path.exists(filepath):

        def _progress(count, block_size, total_size):
          sys.stdout.write(' >> Downloading %s %.1f%%' % (
              filename, float(count * block_size) / float(total_size) * 100.0))
          sys.stdout.flush()
        filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
        print()
        statinfo = os.stat(filepath)
        print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
      tarfile.open(filepath, 'r:gz').extractall(dest_directory)


    def main(_):
      maybe_download_and_extract()
      image = (FLAGS.image_file if FLAGS.image_file else
               os.path.join(FLAGS.model_dir, 'cropped_panda.jpg'))
      run_inference_on_image(image)


    if __name__ == '__main__':

      parser = argparse.ArgumentParser()
      # classify_image_graph_def.pb:
      #   Binary representation of the GraphDef protocol buffer.
      # imagenet_synset_to_human_label_map.txt:
      #   Map from synset ID to a human readable string.
      # imagenet_2012_challenge_label_map_proto.pbtxt:
      #   Text representation of a protocol buffer mapping a label to synset ID.
      parser.add_argument(
          '--model_dir',
          type=str,
          default='/tmp/imagenet',
          help="""
          Path to classify_image_graph_def.pb,
          imagenet_synset_to_human_label_map.txt, and
          imagenet_2012_challenge_label_map_proto.pbtxt.
          """

      )

      parser.add_argument(
          '--image_file',
          type=str,
          default='',
          help='Absolute path to image file.'
      )

      parser.add_argument(
          '--num_top_predictions',
          type=int,
          default=5,
          help='Display this many predictions.'
      )

      FLAGS, unparsed = parser.parse_known_args()
      tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

        运行效果如下:

    图4 运行结果

        简要分析下吧。程序使用InceptionV3的存档模型对一张可爱的滚滚(熊猫)图片进行识别。结果显示该图片属于panda的概率为89.1%,Indir的概率为0.8%,lesser panda的概率为0.3%等等。

        哈哈哈哈,当把照片换成我家小天使时。小家伙你也太鼠头鼠脑的了吧哈哈哈:

        

  • 相关阅读:
    WEB前端开发工具的初识
    Linux常见问题的处理方法(长期更新)
    eclipse+tomcat开发web项目
    Android适配--百分比的适配
    TCP通信丢包原因总结
    android 实现类似微信缓存和即时更新好友头像
    是否应该开启WebView硬件加速?
    Android通过浏览器打开app页面并且传递值
    设计模式(五)责任链模式
    Android 最新 Support V4 包大拆分有用吗?
  • 原文地址:https://www.cnblogs.com/catallen/p/9146315.html
Copyright © 2011-2022 走看看