zoukankan      html  css  js  c++  java
  • 物体检测项目

    1、项目介绍

    1.1 项目架构设计        

            实现基于tensorflow的物体检测。项目框架主要分为三部分:数据采集层、深度模型层、用户层。其中,数据采集层用于对数据进行标记以及转换成TFRecords格式数据文件。深度模型层的功能是读取数据采集层输出的TFRecords数据进行数据的预处理以及对深度模型的训练,其中深度模型可以使用不同的框架(例如SSD、YOLO等),通过模型工厂进行选择,本项目中使用SSD物体检测框架。训练得到的模型通过tensorflow serving进行部署,提供给后台。用户层通过前端和后台业务交互得到想要的结果。项目结构如下:

                                                      图1 物体检测项目框架

    使用TensorflowServing进行模型部署有以下几个好处:

    1、可以进行模型的热更新:只要上传模型文件到服务器上即可,TensorFlow会自动识别模型并使用,不需要重启serving 服务。

    2、导出模型和使用模型进行解耦合

                                                                                            图2 TensorflowServing模型部署逻辑

     整个项目开发流程主要分为两大部分:

    1.模型的训练与测试

        训练

            数据集处理(将数据转换成TFRecords格式文件)

            数据读取

            preprocess(数据预处理)

            网络构建预测结果

            损失计算并训练

            模型保存

        测试
            测试数据

            preprocess(数据预处理)

            模型加载

            postprocess(预测结果后期处理)

            预测结果显示(matplotlib)

    2、模型部署与小程序

        模型导出

        TensorFlow Serving部署模型

        Serving客户端+Flask Web

        小程序前端

    1.2 项目代码训练架构设计

                                                         图2 项目代码训练架构设计

    其中:

    1.数据集工厂(data factory)

    为了使项目能够读取不同的数据集

    2.预处理工厂(preprocess factory)

    为了处理不同模型要求的处理需求

    3.模型工厂(model factory)

    为了项目训练数据能够使用不同的模型

    1.3 训练代码架构设计意义

    1.网络模型和网络模型之间不交叉,模型和数据之间解耦合,数据集与预处理逻辑之间解耦合;

    2.训练代码可以调用不同的模型与不同的数据集训练不同的模型结果。 

    2. 数据模块接口
            获取到的图片数据集,保存在IMAGE/commodity/JPEGImages文件下。使用图片标记工具(本项目使用labelimg)将图片进行标记,输出XML格式文件,保存在 IMAGE/commodity/Annotatons文件下。这样的数据集类似PASCAL VOC数据集,数据集的图片和标记文件分布在不同的文件中,并且图片和标签没有一一对应,后续项目中不方便处理,也不方便项目的解耦合。tensorflow提供了TFRecord个数来统一存储数据,TFRecord格式是一种将图像数据和标签数据存放在一起的二进制文件,在tensorflow中能够快速处理。因此项目中需要将数据集转换成TFRecords文件。TFRecord文件中的数据是通过tf.train.Example Protocol Buffer格式存储的。每个我想ample对应一张图片,其中包括图片的各种信息。特点是:

    1)体积小,消息大小只需要xml文件的1/10~1/3;

    2)解析速度快:解析速度比xml块20~100倍。

    其中,tf.train.Example的定义见本博客的《TFRecord数据处理》一节。

    2.1 数据转换成TFRecord格式文件

    2.1.1 转换步骤:

    1)设定每个tfrecord文件中保存多的样本个数

    2)读取每张图片内容以及xml文件

    3)将每次去读内容写入tfrecord文件

    2.1.2 数据转换成TFRecord文件

    代码结构如图所示:

                图3 图片转换成tfrecord文件

            其中,datasets文件夹下的utils存放读取数据的公用组件;dataset_config.py存放数据读取的配置;dataset_to_tfrecords.py为主要的数据转换逻辑。dataset_to_tfrecord.py文件执行dataset_to_tfrecords.py中的run()函数完成数据转换。具体代码如下:

    2.1.2.1 配置文件dataset_config.py如下:

    """
    数据集转换配置文件
    """
    
    # 指定原始图片的XML和图片的文件夹名字
    DIRECTORY_ANNOTATIONS = "Annotations/"
    DIRECTORY_IMAGES = "JPEGImages/"
    
    # 指定每个TFRecord文件存储example的数量
    SAMPLER_PER_FILES = 200
    
    # 定义字典,保存数据集的类别
    # 字典的key是类别,字典的value是一个元组
    # 元组的元素不能修改,元组中是类别代表的数字和类别
    VOC_LABELS = {
        'none': (0, 'Background'),
        'clothes': (1, 'clothes'),
        'pants': (2, 'pants'),
        'shoes': (3, 'shoes'),
        'watch': (4, 'watch'),
        'phone': (5, 'phone'),
        'audio': (6, 'audio'),
        'computer': (7, 'computer'),
        'books': (8, 'books')
    }

    2.1.2.2 utils文件下的dataset_utils.py文件中,编写定义tf Example需要的feature转换公式,代码如下:

    import tensorflow as tf
    
    
    # 生成整数型的属性
    def int64_feature(value):
        if not isinstance(value, list):
            value = [value]
        return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
    
    
    # 生成浮点型的属性
    def float_feature(value):
        if not isinstance(value, list):
            value = [value]
        return tf.train.Feature(float_list=tf.train.FloatList(value=value))
    
    
    # 生成字符串类型的属性
    def bytes_feature(value):
        if not isinstance(value, list):
            value = [value]
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

    2.1.2.3 dataset_to_tfrecords.py文件下主要编写编写转换逻辑,代码如下:

    import tensorflow as tf
    import os
    import xml.etree.ElementTree as ET
    from datasets.dataset_config import DIRECTORY_ANNOTATIONS, DIRECTORY_IMAGES, SAMPLER_PER_FILES, VOC_LABELS
    from datasets.utils.dataset_utils import int64_feature, float_feature, bytes_feature
    
    
    # 获取输出的TFRecord文件名字,格式如下:commodity_2018_train_xxx.tfrecord
    # xxx代表序号,从000开始
    def _get_output_filename(outputdir, dataset_name, fdx):
        """
        获取输出的TFRecord文件的名字
        :param outputdir: 输出路径
        :param dataset_name: 数据集名字
        :param fdx: 文件id
        :return:
        """
        return "%s/%s_%03d.tfrecord" % (outputdir, dataset_name, fdx)
    
    
    def _process_image(dataset_dir, image_name):
        """
        处理一张图片的数据:获取图片数据以及xml文件中的内容。根据需要获取
        :param dataset_dir: 数据集路径
        :param img_name: 图片名字
        :return:
        """
        # 图片路径 + 图片名字
        filename = dataset_dir + DIRECTORY_IMAGES + image_name + '.jpg'
    
        # 读取图片数据
        image_data = tf.gfile.FastGFile(filename, 'rb').read()
    
        # 读取xml数据,使用ET工具
        # 构造xml文件名字
        filename_xml = dataset_dir + DIRECTORY_ANNOTATIONS + image_name + '.xml'
    
        # 将文件内容转换成树状结构tree
        tree = ET.parse(filename_xml)
    
        # 获取root节点
        root = tree.getroot()
    
        # 获取root节点下面的子节点
        # 1、获取size信息
        size = root.find('size')
        # 把height、width、depth存放在一个shape里面
        shape = [int(size.find('height').text),
                 int(size.find('width').text),
                 int(size.find('depth').text)]
    
        # 用于存储object对应的label的编号
        labels = []
        labels_text = []
        difficults = []
        truncated = []
        bboxes = []
    
        # 2、获取 object信息
        for obj in root.findall('object'):
            # 解析每一个object,包含name、difficult、truncated、bndbox[xmin, ymin, xmax, ymax]
            # 取出label和与之对应的数字
            label = obj.find('name').text
            labels.append(int(VOC_LABELS[label][0]))
            labels_text.append(label.encode('ascii'))
    
            # 取出difficult
            if obj.find('difficult'):
                difficults.append(int(obj.find('difficult').text))
            else:
                # 不存在,默认difficult为0
                difficults.append(0)
    
            # 取出truncated
            if obj.find('truncated'):
                truncated.append(int(obj.find('truncated').text))
            else:
                # 不存在,默认truncated为0
                truncated.append(0)
    
            # 取出bndbox
            bbox = obj.find('bndbox')
            bboxes.append([float(bbox.find('ymin').text)/shape[0],
                           float(bbox.find('xmin').text) / shape[1],
                           float(bbox.find('ymax').text) / shape[0],
                           float(bbox.find('xmax').text) / shape[1]])
        return image_data, shape, labels, labels_text, difficults, truncated, bboxes
    
    
    def _convert_to_example(image_data, shape, labels, labels_text, difficults, truncated, bboxes):
        """
        将图片数据转换成example protocol buffer格式
        :param image_data:
        :param shape:
        :param labels:
        :param difficults:
        :param truncated:
        :param bboxes:
        :return:
        """
        # bboxes存储格式如下:[[a0, b0, c0, d0], [a1, b1, c1, d1]]转换成
        # ymin[a0, a1], xmin[b0, b1], ymax[c0, c1], xmax[d0, d1]
        ymin = []
        xmin = []
        ymax = []
        xmax = []
    
        for b in bboxes:
            ymin.append(b[0])
            xmin.append(b[1])
            ymax.append(b[2])
            xmax.append(b[3])
    
        # 将所有信息封装成example
        image_format = b'JPEG'
        example = tf.train.Example(features=tf.train.Features(feature={
            'image/height': int64_feature(shape[0]),
            'image/width': int64_feature(shape[1]),
            'image/channels': int64_feature(shape[2]),
            'image/shape': int64_feature(shape),
            'image/object/bbox/ymin': float_feature(ymin),
            'image/object/bbox/xmin': float_feature(xmin),
            'image/object/bbox/ymax': float_feature(ymax),
            'image/object/bbox/xmax': float_feature(xmax),
            'image/object/bbox/label': int64_feature(labels),
            'image/object/bbox/difficult': int64_feature(difficults),
            'image/object/bbox/truncated': int64_feature(truncated),
            'image/object/bbox/label_text': bytes_feature(labels_text),
            'image/format': bytes_feature(image_format),
            'image/encoded': bytes_feature(image_data)}))
        return example
    
    
    def _add_to_tfrecord(dataset_dir, image_name, tfrecord_writer):
        """
        添加一个图片文件和xml内容写入文件中
        :param dataset_dir: 数据集目录
        :param img_name: 图片名
        :param tfrecord_writer: 文件写入实例
        :return:
        """
        # 1、读取每张图片内容及其对应的xml文件的内容
        image_data, shape, labels, labels_text, difficults, truncated, bboxes = _process_image(dataset_dir, image_name)
    
        # 2、将每张图片的数据封装成一个example
        example = _convert_to_example(image_data, shape, labels, labels_text, difficults, truncated, bboxes)
    
        # 3、使用tfrecord_writer将example序列化结果写入TFRecord文件
        tfrecord_writer.write(example.SerializeToString())
        return None
    
    
    def run(dataset_dir, output_dir, dataset_name="data"):
        """
        运行转换代码逻辑:存入tfrecord文件,每个文件固定N个样本
        :param dataset_dir: 数据集目录
        :param output_dir: TFRecord存储目录
        :param dataset_name: 数据集名字,指定名字以及train_or_test
        :return:
        """
        # 1、判断数据集目录是否存在,不存在则创建一个目录
        if not tf.gfile.Exists(dataset_dir):
            tf.gfile.MakeDirs(dataset_dir)
        # 2、读取某个文件夹下的所有文件名字列表
        path = os.path.join(dataset_dir, DIRECTORY_ANNOTATIONS)
    
        # 读取所有文件,返回所有文件名字列表。但是会打乱顺序,需要使用sorted函数进行排序
        filenames = sorted(os.listdir(path))
    
        # 3、循环遍历列表,每N张图片和XML信息存储到一个tfrecord文件中
        i = 0
        fdx = 0
        while i < len(filenames):
            # 1、创建TFRecord文件
            tf_filename = _get_output_filename(output_dir, dataset_name, fdx)
    
            # 每N个文件存储一次
            # 新建tfrecord的存储器
            with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer:
                j = 0
                while i < len(filenames) and j < SAMPLER_PER_FILES:
                    print("转换图片进度%d/%d" % (i+1, len(filenames)))
    
                    # 取出图片以及xml的名字
                    single_filename = filenames[i]
                    image_name = single_filename[:-4]
    
                    # 读取图片和xml内容,存入图片,每次构造一个图片文件存储指定文件
                    _add_to_tfrecord(dataset_dir, image_name, tfrecord_writer)
    
                    i += 1
                    j += 1
    
                # 每N个数据,文件id增加计数
                fdx += 1
        print("数据集 %s 转换成功" % dataset_name)

    2.1.2.4 dataset_to_tfrecords.py文件代码

    from datasets import dataset_to_tfrecords
    
    if __name__ == '__main__':
        dataset_to_tfrecords.run('./IMAGE/commodity/', './IMAGE/tfrecords/commodity_tfrecords/', 'commodity_2018_train')

    为了实现数据格式的转换,需要在图3的IMAGE文件夹下分别放置如下目录:

    commodity/Annotations/

    commodity/JPEGImages/

    tfrecords/commodity_tfrecords/

    其中,commodity/Annotations/路径下存放标记过的xml格式文件;commodity/JPEGImages/路径下存放于xml格式对应的图片数据;tfrecords/commodity_tfrecords/路径用于存放转换好的tfrecord格式数据。

    2.2 TFRecord格式文件读取

    TFRecord文件读取有两种方法:

    1)使用tensorflow进行实现

    2)使用tensorflow.slim库进行实现

    本项目使用tensorflow.slim进行实现,具体步骤如下:

    1、定义解码器decoder

    decoder = tf.slim.tfexample_decoder.TFExampleDecoder()

    其中,定义解码器时,需要制定两个参数:keys_to_features,和items_to_handlers两个字典参数。key_to_features这个字典需要和TFrecord文件中定义的字典项匹配。items_to_handlers中的关键字可以是任意值,但是它的handler的初始化参数必须要来自于keys_to_features中的关键字。

    2、定义dataset

    dataset= tf.slim.dataset.Dataset()

    其中,定义dataset时需要将datasetsource、reader、decoder、num_samples等参数

    3、定义provider

    provider = slim.dataset_data_provider.DatasetDataProvider

    其中,需要的参数为:dataset, num_readers, reader_kwargs, shuffle, num_epochs,common_queue_capacity,common_queue_min, record_key=',seed, scope等。

    4、调用provider的get方法

    获取items_to_handlers中定义的关键字

    5、利用分好的batch建立一个prefetch_queue

    6、prefetch_queue中有一个dequeue的op,每执行一次dequeue则返回一个batch的数据。

    具体代码如下(这里先只介绍到通过provider的get函数获取数据,后面步骤5和步骤6的队列处理先不介绍,在实际项目代码中会使用到):

    import os
    import tensorflow as tf
    
    
    slim = tf.contrib.slim
    
    
    def get_dataset(dataset_dir):
        """
        获取commodity2018数据集
        :param dataset_dir: 数据集目录
        :return: Dataset
        """
        # 1.准备 tf.slim.dataset.Dataset()的参数
        # 1.1第一个参数:dataset
        file_pattern = os.path.join(dataset_dir, "commodity_2018_train_*.tfrecord")
    
        # 1.2第二个参数:reader
        reader = tf.TFRecordReader
    
        # 1.3第三个参数:decoder
        # 创建decoder需要两个参数:keys_to_features和items_to_handlers
        # 1.3.1 定义keys_to_features,反序列化的格式
        keys_to_features = {
            'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
            'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
            'image/height': tf.FixedLenFeature([1], tf.int64),
            'image/width': tf.FixedLenFeature([1], tf.int64),
            'image/channels': tf.FixedLenFeature([1], tf.int64),
            'image/shape': tf.FixedLenFeature([3], tf.int64),
            'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
            'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
            'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
            'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
            'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64),
            'image/object/bbox/difficult': tf.VarLenFeature(dtype=tf.int64),
            'image/object/bbox/truncated': tf.VarLenFeature(dtype=tf.int64),
        }
    
        # 1.3.2 items_to_handlers,反序列化成高级的格式
        items_to_handlers = {
            'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
            'shape': slim.tfexample_decoder.Tensor('image/shape'),
            'object/bbox': slim.tfexample_decoder.BoundingBox(
                ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'),
            'object/label': slim.tfexample_decoder.Tensor('image/object/bbox/label'),
            'object/difficult': slim.tfexample_decoder.Tensor('image/object/bbox/difficult'),
            'object/truncated': slim.tfexample_decoder.Tensor('image/object/bbox/truncated'),
        }
    
        # 1.3.3构造decoder
        decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
    
        # 2.tf.slim.dataset.Dataset()并返回
        return slim.dataset.Dataset(data_sources=file_pattern,
                                    reader=reader,
                                    decoder=decoder,
                                    num_samples=88,
                                    items_to_descriptions={
                                        'image': 'A color image of varying height and width.',
                                        'shape': 'Shape of the image',
                                        'object/bbox': 'A list of bounding boxes, one per each object.',
                                        'object/label': 'A list of labels, one per each object.'
                                    },  # 数据集返回的格式描述字典
                                    num_classes=8)
    from datasets.dataset_init import commodity_2018
    import tensorflow as tf
    
    slim = tf.contrib.slim
    
    if __name__ == '__main__':
        # 获取dataset
        dataset = commodity_2018.get_dataset("./IMAGE/tfrecords/commodity_tfrecords/")
    
        # 通过provider取出数据
        provider = slim.dataset_data_provider.DatasetDataProvider(dataset=dataset,
                                                                  num_readers=3)
    
        # 通过get方法获取指定名称的数据(名称在准备规范数据dataset时高级格式的名称,即items_to_handlers中定义的名称)
        [image, shape, bbox, label, difficult, truncated] = provider.get(
            ['image', 'shape', 'object/bbox', 'object/label', 'object/difficult', 'object/truncated'])
    
        print(image, shape, bbox, label, difficult, truncated)

    最后得到如下输出结果:

     

                                                                                                      图4 输出tfrecord文件

    2.3 数据模块接口——数据工厂的实现

    功能需求:

    1)原始数据集(图片+XML)转换成TFRecords文件格式

    2)读取TFRecords数据

    数据模块设计的目录如下:

                         图5 数据模块接口 

    其中:

    dataset_factory:数据模块工厂,找到不同的数据集读取逻辑;

    dataset_init:保存不同数据集的TFRecords格式读取功能;

    utils:数据模块的共用组件

    dataset_config‘:数据模块的一些数据集配置文件

    dataset_to_tfrecords:原始数据集格式转换逻辑

    2.3.1 格式转换

    上一节以及介绍了将数据集转换成TFRecord格式文件,这里就不再赘述。

    2.3.2 读取TFRecord文件数据

    2.3.2.1 读取代码框架设计

    数据模块需要实现对不同数据集类型进行读取操作,因此可以定义一个基类,同时不同数据集继承这个基类。类的设计如下:

                                                 图6 数据读取基类设计

    2.3.2.2 数据读取代码

    1.在dataset_utils.py中新建一个基类,该文件下的代码如下:

    import tensorflow as tf
    
    
    # 定义数据集TFRecord文件读取基类
    class TFRecordsReaderBase(object):
        """
        数据集读取基类
        """
        def __init__(self, param):
            # param是给不同数据集使用的属性配置
            self.param = param
    
        def get_dataset(self, train_or_test, dataset_dir):
            """
            获取数据
            :param train_or_test: 训练还是测试
            :param dataset_dir: 数据集目录
            :return:
            """
            return None

    2. 因为在读取TFRecord数据时,不同的数据集,都会有自己特有的参数(比如:文件名、样本数、类别数等)。因此在dataset_config.py文件中定义不同数据集的参数,作为继承类的参数。这里使用命名字典:

    """
    数据集读取
    """
    from collections import namedtuple
    
    # 创建命名字典,用于存放读取数据类中的param参数
    DataSetParams = namedtuple("DataSetParamters", ['FILE_PATTERN',
                                                    'NUM_CLASSES',
                                                    'SPLITS_TO_SIZES',
                                                    'ITEMS_TO_DESCRIPTIONS'
                                                    ])
    
    # 定义commodity_2018属性配置
    Cmd2018 = DataSetParams(
        FILE_PATTERN='commodity_2018_%s_*.tfrecord',
        NUM_CLASSES=8,
        SPLITS_TO_SIZES={
            'train': 88,
            'test': 0
        },
        ITEMS_TO_DESCRIPTIONS={
            'image': '图片数据',
            'shape': '图片形状',
            'object/bbox': '若干物体对象的bbox框组成的列表',
            'object/label': '若干物体对应的label编号'
        }
    )

    3. 继承基类来定义派生类用于处理不同数据集

    继承的基类存放在dataset/dataset_init/目录下。对于不同数据集,定义不同的文件继承基类,本项目值处理commodity数据集,因此仅创建commodity_2018.py继承基类,代码如下:

    import os
    import tensorflow as tf
    from datasets.utils import dataset_utils
    
    slim = tf.contrib.slim
    
    
    class CommodityTFRecords(dataset_utils.TFRecordsReaderBase):
        """
        商品数据集读取类
        """
        def __init__(self, param):
            self.param = param
    
        def get_dataset(self, train_or_test, dataset_dir):
            """
            获取commodity2018数据集
            :param train_or_test: train or test
            :param dataset_dir: 数据集目录
            :return:
            """
            # 参数检查,异常抛出
            if train_or_test not in ['train', 'test']:
                raise ValueError("训练/测试的名字 %s 错误" % train_or_test)
    
            if not tf.gfile.Exists(dataset_dir):
                raise ValueError("数据集目录 %s 不存在" % dataset_dir)
    
            # 1.准备 tf.slim.dataset.Dataset()的参数
            # 1.1第一个参数:dataset
            file_pattern = os.path.join(dataset_dir, self.param.FILE_PATTERN % train_or_test)
    
            # 1.2第二个参数:reader
            reader = tf.TFRecordReader
    
            # 1.3第三个参数:decoder
            # 创建decoder需要两个参数:keys_to_features和items_to_handlers
            # 1.3.1 定义keys_to_features,反序列化的格式
            keys_to_features = {
                'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
                'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
                'image/height': tf.FixedLenFeature([1], tf.int64),
                'image/width': tf.FixedLenFeature([1], tf.int64),
                'image/channels': tf.FixedLenFeature([1], tf.int64),
                'image/shape': tf.FixedLenFeature([3], tf.int64),
                'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
                'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
                'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
                'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
                'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64),
                'image/object/bbox/difficult': tf.VarLenFeature(dtype=tf.int64),
                'image/object/bbox/truncated': tf.VarLenFeature(dtype=tf.int64),
            }
    
            # 1.3.2 items_to_handlers,反序列化成高级的格式
            items_to_handlers = {
                'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
                'shape': slim.tfexample_decoder.Tensor('image/shape'),
                'object/bbox': slim.tfexample_decoder.BoundingBox(
                    ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'),
                'object/label': slim.tfexample_decoder.Tensor('image/object/bbox/label'),
                'object/difficult': slim.tfexample_decoder.Tensor('image/object/bbox/difficult'),
                'object/truncated': slim.tfexample_decoder.Tensor('image/object/bbox/truncated'),
            }
    
            # 1.3.3构造decoder
            decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
    
            # 2.tf.slim.dataset.Dataset()并返回
            return slim.dataset.Dataset(data_sources=file_pattern,
                                        reader=reader,
                                        decoder=decoder,
                                        num_samples=self.param.SPLITS_TO_SIZES[train_or_test],
                                        items_to_descriptions=self.param.ITEMS_TO_DESCRIPTIONS,  # 数据集返回的格式描述字典
                                        num_classes=self.param.NUM_CLASSES)

    2.3.3 定义数据工厂

    在datasets根目录下创建dataset_factory.py文件,定义数据工厂获取数据,代码如下:

    from datasets.dataset_init import commodity_2018
    from datasets.dataset_config import Cmd2018
    
    # 定义dataset种类的字典,目前只是有commodity数据集,后续可以添加
    datasets_maps = {
        'commodity_2018': commodity_2018.CommodityTFRecords
    }
    
    # 定义参数种类的字典,不同数据集,param参数不一样,目前只是有commodity的参数,后续可以添加
    param_map = {
        'commodity_2018': Cmd2018
    }
    
    
    def get_dataset(dataset_name, train_or_test, dataset_dir):
        """
        获取指定数据名称的数据文件
        :param dataset_name: 数据集名称(数据当中要存在
        :param train_or_test: train or test数据集
        :param dataset_dir: 数据集目录
        :return: Dataset 数据规范
        """
        if dataset_name not in datasets_maps:
            raise ValueError("数据集名称 %s 不存在" % dataset_name)
    
        param = param_map[dataset_name]
    
        return datasets_maps[dataset_name](param).get_dataset(train_or_test, dataset_dir)

    最终对外只提供dataset_factory.py文件用于读取TFRecord文件。

    3. 模型接口

    本项目使用SSD模型。

    项目文件结构如下:

              图7 网络模型接口文件格式

    其中的公共组件的源码都是已知的,本项目使用的ssd网络模型实现文件ssd_vgg_300.py相关代码都是现有代码。对于SSD模型以及其代码实现,将在另外章节介绍。

    3.1 网络工厂nets_factory实现

    类似数据工厂,我们定义模型工厂nets_factory.py文件,代码如下:

    from nets.nets_model import ssd_vgg_300
    
    nets_maps = {
        'ssd_vgg_300': ssd_vgg_300.SSDNet
    }
    
    
    def get_network(network_name):
        """
        获取不同网络模型
        :param network_name: 网络模型名称
        :return: 网络
        """
        if network_name not in nets_maps:
            raise ValueError("网络名称 %s 不存在" % network_name)
        
        return nets_maps[network_name]

    4.预处理模块

    目的:

    1)在图像的深度学习中,对输入数据进行数据增强(Data Augmentation),为了丰富图像的训练集,更好地提取图像特征,泛化模型(防止过拟合)。

    通过一系列图像的操作(比如:剪切、翻转、偏移、缩放等图像变换),增加数据集的大小,防止过拟合。

    2)还有一个根本目的就是把图片变成符合大小要求的格式:

    RCNN网络对于输入图片没有要求,但是网络当中卷积之前需要的大小为227×227;

    YOLO算法:输入图片大小为448×448;

    SSD算法:输入图片大小为300×300;

    4.1 预处理模块代码实现

    首先,预处理模块的结构如图所示:

                     图8 预处理模块结构

           

            其中,需要创建一个preprocessing目录,该目录下的文件用于数据预处理。该目录下的processing目录中的ssd_vgg_preprocessing.py是对于SSD模型的预处理的。如果后续需要增加网络模型,需要在这个文件夹下增加预处理的文件。utils中是预处理需要用到的公共组件。这些相关代码都是公开的代码,这里不做介绍。有了上面的基础文件,下面就来完成数据预处理工厂代码的编写,在preprocessing_factory.py文件中实现:

    from preprocessing.processing import ssd_vgg_preprocessing
    
    # 目前只有sdd_vgg_300,后续可以增加
    preprocessing_maps = {
        'ssd_vgg_300': ssd_vgg_preprocessing
    }
    
    
    def get_preprocessing(name, is_trainning=True):
        """
        预处理工厂获取不同的数据增强方法
        :param name: 预处理名称
        :param is_trainning: 是否是训练
        :return: 返回预处理的函数,后续再调用函数
        """
        if name not in preprocessing_maps:
            raise ValueError("数据预处理名称 %s 不存在" % name)
    
        # 定义一个预处理函数,用于函数返回,后续再调用该预处理函数
        def preprocessing_fn(image, labels, bboxes, out_shape,
                             data_format, **kwargs):
            return preprocessing_maps[name].preprocess_image(image, labels, bboxes, out_shape,
                                                             data_format=data_format,
                                                             is_training=is_trainning, **kwargs)
    
        return preprocessing_fn

    5.训练不同模块接口参数

    对于2、3、4章节,只是分别单独介绍了数据模块接口、模型接口以及数据预处理接口。现在需要统一每一个模块接口提供给训练的参数,整理成文档。这样以后就直接查看文档即可调用相关模块。总结如图9所示:

     

                                                                                                                       图9 训练不同模型参数

    6. 多GPU训练

    终于到了模型训练这一步了。这里介绍多GPU训练。

            对于深度学习来说,大量的计算量导致CPU会显得十分乏力耗时。所以需要GPU来进行提供帮助计算,那么他们的主要任务就是计算得出结果,与CPU之间会进行分工,CPU会做一些基本工作,变量存储,更新参数,输入数据变量等等。如图10所示。在TensorFlow当中会通过标号来区别不同的GPU和CPU,如 ,''/device:CPU:0", "/device:CPU:1","/device:GPU:0","/device:GPU:1","/device:GPU:2",那么这些标号都是程序自动给的编号,指的具体哪块计算设备。

                                                     图10 CPU与GPU之间的分工合作

    6.1 训练步骤

    • 步骤
      • 数据读取
      • preprocess(数据预处理)
      • 网络构建预测结果
      • 损失计算
      • 添加变量到TensorBoard
      • 模型训练、保存
    • 部署需求:训练整个模型需要在多GPU、多计算机的环境下进行

    那么接下来首先我们要讲模型训练的设备逻辑原理弄清楚,如图11所示:

                                                                                                  图11 模型训练的设备逻辑原理

           

            训练主要是在设备(GPU/CPU)上训练,但是如果我们利用目前简单的TensorFlow提供的API去进行指定设备训练会比较繁琐。所以在这里需要介绍一个TensorFlow提供的最新的专门用于多GPU,多计算机的设备部署模块——model_deploy。

    6.2 model_deploy介绍

    model_deploy位于TensorFlow slim模块的deployment目录下,可以使得用多个 GPU / CPU在同一台机器或多台机器上执行同步或异步训练变得更简单。可以从如下官方地址下载:

    https://github.com/tensorflow/models/blob/master/research/slim/deployment/model_deploy.py

    首先我们要介绍:

    replica:使用多机训练时,一台机器对应一个replica(复本);

    clone:由于tensorflow里多GPU训练一般都是每个GPU上都有完整的模型,各自进行前向传播计算,得到的梯度交给CPU平均后统一反向计算,每个GPU上的模型叫做一个clone;

    parameter server:多机训练时,计算梯度平均值并执行反向传播操作的参数,功能类似于单机多GPU的CPU;

    worker server:一般指单机多卡中的GPU,用于训练。

    6.2.1 DeploymentConfig

    1. DeploymentConfig为文件中的一个类,主要用于给变量配置选择的设备。

    • class DeploymentConfig(object):
      • 配置参数
      • num_clones=1:每一个计算设备上的模型克隆数(每台计算机的GPU/CPU总数)
      • clone_on_cpu=False:如果为True,将只在CPU上训练
      • replica_id=0:指定某个计算机去部署,默认第0台计算机(TensorFlow会给个默认编号)
      • num_replicas=1:多少台可用计算机
      • num_ps_tasks=0:用于参数服务器的计算机数量,0为不适用计算机作为参数服务器
      • worker_job_name='worker':工作服务器名称
      • ps_job_name='ps':参数服务器名称
    • config.variables_device()
      • 作为tf.device(func)的参数,返回默认创建变量的设备
      • 一般用于指定全局步数变量的设备,默认运行计算机的"/device:CPU:0"
    • config.inputs_device()
      • 作为tf.device(func)的参数,返回用于构建数据输入变量所在的设备。
      • 默认运行计算机的"/device:CPU:0"
    • config.optimizer_device()
      • 作为tf.device(func)的参数,返回学习率、优化器所在的设备。
      • 默认运行计算机的"/device:CPU:0"
    • config.clone_scope(self, clone_index):
      • 返回指定编号的设备命名空间
      • 按照这样编号,clone_0,clone_1...

    6.2.2 model_deploy定义的相关函数,主要用于为每一个clone创建一个复制的模型(在GPU)

    • model_deploy.create_clones(config, model_fn, args=None, kwargs=None):
      • 作用:每个clone创建一个复制的模型,给GPU进行clone模型
      • config:一个DeploymentConfig的配置对象
      • model_fn:用于回调的函数model_fn,
      • args=None, kwargs=None:回调函数model_fn的参数
      • 返回元组组成的列表,列表个数大小为指定的num_clones数量
        • Clone(outputs, scope, device)
          • outputs:网络模型的每一层节点
          • scope: 第i个GPU设备的命名空间,config.clone_scope(i)
          • clone_device:第i个GPU设备
    • model_deploy.optimize_clones(clones, optimizer,regularization_losses=None, **kwargs)
      • 作用:计算所有给定的clones的总损失以及每个需要优化的变量的总梯度
      • clones: 元组列表,每个元素Clone(outputs, scope, device)
      • optimizer:选择的优化器
      • **kwargs:可选参数,优化器优化的变量
      • 返回:
        • total_loss:总损失
        • grads_and_vars:每个需要优化变量的总梯度组成的列表

    源码介绍使用:

    # Set up DeploymentConfig
    config = model_deploy.DeploymentConfig(num_clones=2, clone_on_cpu=True)
    # Create the global step on the device storing the variables.
    with tf.device(config.variables_device()):
        global_step = slim.create_global_step()
    # Define the inputs
    with tf.device(config.inputs_device()):
        images, labels = LoadData(...)
        inputs_queue = slim.data.prefetch_queue((images, labels))
    # Define the optimizer.
    with tf.device(config.optimizer_device()):
        optimizer = tf.train.MomentumOptimizer(FLAGS.learning_rate, FLAGS.momentum)
    
    
    # Define the model including the loss.
    def model_fn(inputs_queue):
        images, labels = inputs_queue.dequeue()
        predictions = CreateNetwork(images)
        slim.losses.log_loss(predictions, labels)
    
    
    model_dp = model_deploy.deploy(config, model_fn, [inputs_queue],
                                   optimizer=optimizer)
    # Run training.
    slim.learning.train(model_dp.train_op, my_log_dir,
                        summary_op=model_dp.summary_op)

    6.3 训练逻辑

    1)DeploymentConfig

        需要在训练之前配置所有的设备信息

        定义全局步数

    2)获取图片队列

        在config.inputs_device()指定

    3)数据输入、网络计算结果、定义损失并复制模型到clones,添加变量到tensorboard

        model_deploy.create_clones

    4)定义学习率、优化器

       config.optimizer_device()指定

    5)计算所有GPU/CPU设备的平均损失和每个变量的梯度总和、定义训练OP、summaries OP

        model_deploy.optimize_clones

    6)配置训练的config,进行训练

    slim.learning.train

    代码框架如下:

     

    图中,pre_trained文件下存放的是预训练好的ssd_vgg_300网络的预训练模型,fine_tuning是训练存放模型的路径。

    根目录下的utils是公共组件,最后训练的文件是train_ssd_network.py。

    训练代码如下:

    """
    训练初始化参数
    
    PRE_TRAINED_PATH=./ckpt/pre_trained/ssd_vgg_300.ckpt
    TRAIN_MODEL_PDIR=./ckpt/fine_tuning/
    DATASET_DIR=./IMAGE/tfrecords/commodity_tfrecords/
    
    每批次训练样本数:32或者更小
    惩罚项:0.005
    学习率:0.001
    优化器选择:adam
    模型名称:ssd_vgg_300
    """
    
    import tensorflow as tf
    from datasets import dataset_factory
    from preprocessing import preprocessing_factory
    from nets import nets_factory
    from utils import train_tools
    from deployment import model_deploy
    
    slim = tf.contrib.slim
    
    DATA_FORMAT = 'NHWC'
    
    # 命令行参数
    # 设备相关的命令行参数
    tf.app.flags.DEFINE_integer('num_clones', 1, "可用GPU数量")
    tf.app.flags.DEFINE_boolean('clone_on_cpu', False, "是否只在CPU上运行")
    tf.app.flags.DEFINE_integer('replica_id', 0, "复本id")
    
    # 数据集相关命令行参数
    tf.app.flags.DEFINE_string('dataset_dir', ' ', "训练数据集目录")
    tf.app.flags.DEFINE_string('dataset_name', 'commodity_2018', "数据集名称")
    tf.app.flags.DEFINE_string('train_or_test', 'train', "训练还是测试")
    
    # 网络相关命令行参数
    tf.app.flags.DEFINE_string('network_name', 'ssd_vgg_300', "网络名称")
    tf.app.flags.DEFINE_integer('batch_size', 32, "每批次获取样本换数量")
    tf.app.flags.DEFINE_float('weight_decay', 0.0001, "网络误差惩罚项")
    
    # 训练相关参数
    tf.app.flags.DEFINE_string(
        'optimizer', 'rmsprop', '优化器种类 可选"adadelta", "adagrad", "adam","ftrl", "momentum", "sgd" or "rmsprop".')
    tf.app.flags.DEFINE_string(
        'learning_rate_decay_type', 'exponential', '学习率种类 "fixed", "exponential", "polynomial".')
    tf.app.flags.DEFINE_float('learning_rate', 0.01, '模型初始学习率')
    tf.app.flags.DEFINE_float('end_learning_rate', 0.0001, '模型终止学习率')
    
    tf.app.flags.DEFINE_integer('max_number_of_steps', None, '训练的最大步数')
    tf.app.flags.DEFINE_string('train_model_dir', ' ', '训练输出的模型目录')
    tf.app.flags.DEFINE_string('pre_trained_model', None, '预训练模型目录')
    
    FLAGS = tf.app.flags.FLAGS
    
    
    def main(_):
    
        if not FLAGS.dataset_dir:
            raise ValueError("必须指定一个TFRecord的数据集目录")
    
        # 设置打印级别
        tf.logging.set_verbosity(tf.logging.DEBUG)
    
        # 在默认图中进行训练
        with tf.Graph().as_default():
            # 1.DeploymentConfig配置
            deploy_config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones,
                                                          clone_on_cpu=FLAGS.clone_on_cpu,
                                                          replica_id=0,
                                                          num_replicas=1,
                                                          num_ps_tasks=0)
    
            # 在variables_device定义全局步长(网络训练一般都这么配置)
            with tf.device(deploy_config.variables_device()):
                global_step = tf.train.create_global_step()
    
            # 2.获取图片数据,做一些预处理
            # image, shape, bbox, label
            # 不是直接进行训练,而是需要进行正负样本标记(输出的anchor和GT进行IOU计算选择)
    
            # 2.1步骤如下:
            # (1)通过数据工厂获取DataSet规范,不是真正的数据,需要通过后续操作去获取数
            dataset = dataset_factory.get_dataset(dataset_name=FLAGS.dataset_name,
                                                  train_or_test=FLAGS.train_or_test,
                                                  dataset_dir=FLAGS.dataset_dir)
    
            # (2)通过网络计算获取的anchors结果
            # 通过网络工厂获取网络
            ssd_class = nets_factory.get_network(FLAGS.network_name)
    
            # 获取默认网络参数
            ssd_params = ssd_class.default_params._replace(num_classes=9)
    
            # 初始化网络init函数
            ssd_net = ssd_class(ssd_params)
    
            # 获取shape
            ssd_shape = ssd_net.params.img_shape
    
            # 获取anchors, SSD网络中6层的所有计算出来的默认候选框default boxes
            ssd_anchors = ssd_net.anchors(ssd_shape)
    
            # (3)获取预处理函数
            image_preprocessing_fn = preprocessing_factory.get_preprocessing(name=FLAGS.network_name,
                                                                             is_training=True)
    
            # 打印网络相关参数
            train_tools.print_configuration(ssd_params, dataset.data_sources)
    
            # 2.2
            # (1)通过slim.dataset_data_provider.DatasetDataProvider获取图像数据
            # (2)进行数据预处理
            # (3)对获取出来的GT标签和bbox进行编码
            # (4)获取的单个样本数据,要进行批处理以及返回队列
            with tf.device(deploy_config.inputs_device()):
                with tf.name_scope(FLAGS.network_name + "_data_provider"):
                    provider = slim.dataset_data_provider.DatasetDataProvider(
                        dataset,
                        num_readers=4,
                        common_queue_capacity=20 * FLAGS.batch_size,
                        common_queue_min=10 * FLAGS.batch_size,
                        shuffle=True)
    
                    # get获取数据(真正获取参数)
                    [image, shape, glabels, gbboxes] = provider.get(['image', 'shape', 'object/label', 'object/bbox'])
    
                    # 数据预处理 [?, ?, 3]-->[300, 300, 3]
                    image, glabels, gbboxes = image_preprocessing_fn(image, glabels, gbboxes, ssd_shape, DATA_FORMAT)
    
                    # 原始anchor boxes进行正负样本标记
                    # gclasses: 目标类别
                    # glocalizations: 目标类别的真实位置
                    # gscores: 目标结果(概率值)
                    gclasses, glocalizations, gscores = ssd_net.bboxes_encode(glabels, gbboxes, ssd_anchors)
    
                    # 批处理、队列处理
                    # tensor_list:tensor组成的类别 [tensor, tensor, tensor, ...]
                    # r是1个tensor组成的列表
                    r = tf.train.batch(tensors=train_tools.reshape_list([image, gclasses, glocalizations, gscores]),
                                       batch_size=FLAGS.batch_size,
                                       num_threads=4,
                                       capacity=5 * FLAGS.batch_size)
    
                    batch_queue = slim.prefetch_queue.prefetch_queue(r, capacity=deploy_config.num_clones)
    
            # 3.数据输入、网络计算结果、定义损失并复制模型到clones,添加变量到tensorboard
            summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
            # batch_shape:获取的默认队列大小,即上面r的大小
            batch_shape = [1] + 3 * [len(ssd_anchors)]
            update_ops, first_clone_scope, clones = train_tools.deploy_loss_summary(deploy_config,
                                                                                    batch_queue,
                                                                                    ssd_net,
                                                                                    summaries,
                                                                                    batch_shape,
                                                                                    FLAGS)
    
            # 4.定义学习率、优化器
            # 初始学习率:0.001
            # 终止学习率:0.0001
            # 优化器选择:adam
            with tf.device(deploy_config.optimizer_device()):
                # 定义学习率和优化器
                learning_rate = train_tools.configure_learning_rate(FLAGS, dataset.num_samples, global_step)
    
                # 定义优化器
                optimizer = train_tools.configure_optimizer(FLAGS, learning_rate)
    
                # 观察学习的变化情况添加到summaries中
                summaries.add(tf.summary.scalar('learning_rate', learning_rate))
    
            # 5.计算所有GPU/CPU设备的平均损失和每个变量的梯度总和、定义训练OP、summaries OP
            train_op, summaries_op = train_tools.get_trainop(optimizer,
                                                             summaries,
                                                             clones,
                                                             global_step,
                                                             first_clone_scope, update_ops)
    
            # 6.配置训练的config,进行训练
            # 6.1 配置config和saver
            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
            config = tf.ConfigProto(log_device_placement=False,  # 若果打印会有许多变量的设备信息出现
                                    gpu_options=gpu_options)
    
            saver = tf.train.Saver(max_to_keep=5,  # 默认保留最近几个模型文件
                                   keep_checkpoint_every_n_hours=1.0,
                                   write_version=2,
                                   pad_step_number=False)
    
            # 6.2 训练
            slim.learning.train(
                train_op,  # 训练优化器tensor
                logdir=FLAGS.train_model_dir,  # 模型存储目录
                master='',
                is_chief=True,
                init_fn=train_tools.get_init_fn(FLAGS),  # 初始化参数的逻辑,预训练模型的读取和微调模型判断
                summary_op=summaries_op,  # 摘要
                number_of_steps=FLAGS.max_number_of_steps,  # 最大步数
                log_every_n_steps=10,  # 打印频率
                save_summaries_secs=60,  # 保存摘要频率
                saver=saver,  # 保存模型参数
                save_interval_secs=600,  # 保存模型间隔
                session_config=config,  # 会话参数配置
                sync_optimizer=None)
    
    
    if __name__ == '__main__':
        tf.app.run()

    训练模型:

    训练的过程使用技嘉RTX2070Super显卡。

    切换到ObjectDetection目录,执行如下命令(参数可以自己设定):

    PRE_TRAINED_PATH=./ckpt/pre_trained/ssd_300_vgg.ckpt
    TRAIN_MODEL_DIR=./ckpt/fine_tuning/
    DATASET_DIR=./IMAGE/tfrecords/commodity_tfrecords/
    python train_ssd_network.py --train_model_dir=${TRAIN_MODEL_DIR} --dataset_dir=${DATASET_DIR} --dataset_name="commodity_2018" --train_or_test=train --model_name=ssd_vgg_300 --pre_trained_path=${PRE_TRAINED_PATH} --weight_decay=0.0005 --optimizer=adam --learning_rate=0.001 --batch_size=16

    此时可以学习。

    同时在ckpt/fine_tuning文件夹下,执行如下命令,可以使用tensorboard查看已经添加到tensorboard中的相关参数。

    tensorboard --logdir=./

    训练过程如下图所示:

    7.测试过程

    7.1测试流程

    1)测试数据准备

    2)preprocessing数据预处理--测试过程的数据预处理就是需要图片的resize

    3)模型加载

    4)postprocess(预测结果后期处理)--训练过程中是不需要后期处理的

        通过scores筛选bbox

        使用NMS筛选box

        注意bbox边界与原始图片的bbox,按需修改bbox

    5)预测结果显示(使用matplotlib)

    7.1 测试框架:

    其中,test文件夹用于测试使用,visualization.py文件里面是显示结果的代码,test_image.py文件中文最终存放的测试代码。

    7.2 测试代码

    7.2.1显示图片代码

    visualization.py中的显示结果的代码如下:

    import cv2
    import random
    
    import matplotlib.pyplot as plt
    import matplotlib.image as mpimg
    import matplotlib.cm as mpcm
    
    VOC_LABELS = {
        '0': 'Background',
        '1': 'clothes',
        '2': 'pants',
        '3': 'shoes',
        '4': 'watch',
        '5': 'phone',
        '6': 'audio',
        '7': 'computer',
        '8': 'books'
    }
    
    # =========================================================================== #
    # Matplotlib 显示图
    # =========================================================================== #
    def plt_bboxes(img, classes, scores, bboxes, figsize=(10,10), linewidth=1.5):
        """显示bounding boxes.
        """
        fig = plt.figure(figsize=figsize)
        plt.imshow(img)
        height = img.shape[0]
        width = img.shape[1]
        colors = dict()
        for i in range(classes.shape[0]):
            cls_id = int(classes[i])
            if cls_id >= 0:
                score = scores[i]
                if cls_id not in colors:
                    colors[cls_id] = (random.random(), random.random(), random.random())
                ymin = int(bboxes[i, 0] * height)
                xmin = int(bboxes[i, 1] * width)
                ymax = int(bboxes[i, 2] * height)
                xmax = int(bboxes[i, 3] * width)
                rect = plt.Rectangle((xmin, ymin), xmax - xmin,
                                     ymax - ymin, fill=False,
                                     edgecolor=colors[cls_id],
                                     linewidth=linewidth)
                plt.gca().add_patch(rect)
                class_name = str(cls_id)
                plt.gca().text(xmin, ymin - 2,
                               '{:s} | {:.3f}'.format(VOC_LABELS[class_name], score),
                               bbox=dict(facecolor=colors[cls_id], alpha=0.5),
                               fontsize=12, color='white')
    
        plt.show()

    7.2.1测试过程代码

    测试过程test_image.py代码如下:

    import numpy as np
    import tensorflow as tf
    from PIL import Image
    
    import sys
    sys.path.append('../')
    
    import matplotlib.pyplot as plt
    import matplotlib.image as mping
    import visualization
    from utils.basic_tools import np_methods
    
    slim = tf.contrib.slim
    
    from nets import nets_factory
    from preprocessing import preprocessing_factory
    
    # 1.定义输入图片数据的占位符
    image_input = tf.placeholder(tf.uint8, shape=[None, None, 3])
    
    # 定义输出形状,元组表示
    net_shape = (300, 300)
    
    data_format = 'NHWC'
    
    # 2.数据输入预处理工厂,进行预处理
    preprocessing_fn = preprocessing_factory.get_preprocessing('ssd_vgg_300', is_training=False)
    image_Pre, _, _, bbox_img = preprocessing_fn(image_input, None, None, net_shape, data_format)
    
    # image_Pre是三维形状--->(300, 300, 3)
    # 卷积神经网络要求都是四维的数据计算
    # 维度的扩充--->(1, 300, 300, 3)
    image_4d = tf.expand_dims(image_Pre, 0)
    
    # 3.定义SSD模型,并输出预测结果
    # 网络工厂获取
    ssd_class = nets_factory.get_network('ssd_vgg_300')
    ssd_params = ssd_class.default_params._replace(num_classes=9)
    
    reuse = True if 'ssd_net' in locals() else False
    
    # 初始化网络
    ssd_net = ssd_class(ssd_params)
    
    ssd_anchors = ssd_net.anchors(net_shape)
    
    # 通过网络的方法获取结果
    # 使用slim指定公有参数
    with slim.arg_scope(ssd_net.arg_scope(data_format=data_format)):
        predictions, localizations, _, _ = ssd_net.net(image_4d, is_training=False, reuse=reuse)
    
    
    config = tf.ConfigProto(log_device_placement=False)
    sess = tf.InteractiveSession(config=config)
    sess.run(tf.global_variables_initializer())
    
    ckpt_filepath = '../ckpt/fine_tuning/model.ckpt-103480'
    
    saver = tf.train.Saver()
    saver.restore(sess, ckpt_filepath)
    
    # 会话运行图片,输出结束
    # 读取一张图片
    img = Image.open('../IMAGE/commodity/JPEGImages/000080.jpg').convert('RGB')
    
    img = np.array(img)
    
    i, p, l, box_img = sess.run([image_4d, predictions, localizations, bbox_img], feed_dict={image_input:img})
    
    # 进行结果筛选
    classes, scores, bboxes = np_methods.ssd_bboxes_select(
        p, l, ssd_anchors, select_threshold=0.5, img_shape=(300, 300),
        num_classes=9, decode=True
    )
    
    # bbox边框不能超过原图片,默认原图的相对于bbox大小比例 [0, 0, 1, 1]
    bboxes = np_methods.bboxes_clip(box_img, bboxes)
    
    # 根据 scores 从大到小排序,并改变classes rbboxes的顺序
    classes, scores, bboxes = np_methods.bboxes_sort(classes, scores, bboxes, top_k=400)
    
    # 使用nms算法筛选bbox
    classes, scores, bboxes = np_methods.bboxes_nms(classes, scores, bboxes, nms_threshold=.45)
    
    # 根据原始图片的bbox,修改所有bbox的范围[.0, .0, .1, .1]
    bboxes = np_methods.bboxes_resize(box_img, bboxes)
    
    visualization.plt_bboxes(img, classes, scores, bboxes)

    测试中使用训练得到的ckpt/fine_tuning/model.ckpt-103480文件中的参数进行。测试结果如下图所示:

  • 相关阅读:
    9-day9-生成器-列表解析-三元表达式-
    8-day8-列表解析-装饰器-迭代器
    7-day7-闭包函数-装饰器-函数2
    6-day6-函数-1
    5-day5-字符编码-函数-文件操作
    hive 跨年周如何处理
    nginx 安装部署
    logstash 读取kafka output ES
    leedcode 001 之 Two Sum 42.20% Easy
    大数据调度与数据质量的重要性
  • 原文地址:https://www.cnblogs.com/xjlearningAI/p/12459468.html
Copyright © 2011-2022 走看看