zoukankan      html  css  js  c++  java
  • yolo源码解析(二)

    五 读取数据pascal_voc.py文件解析

    我们在YOLENet类中定义了两个占位符,一个是输入图片占位符,一个是图片对应的标签占位符,如下:

    复制代码
            #输入图片占位符 [NONE,image_size,image_size,3]
            self.images = tf.placeholder(
                tf.float32, [None, self.image_size, self.image_size, 3],
                name='images')
            #设置标签占位符 [None,S,S,5+C]  即[None,7,7,25]
            self.labels = tf.placeholder(
                tf.float32,
                [None, self.cell_size, self.cell_size, 5 + self.num_class])        
    复制代码

    而pascal_voc.py文件的目的就是为了准备数据,赋值给占位符。在pascal_voc.py文件中定义了一个pascal_voc,该类包含了类初始化函数(__init__()),准备数据函数(prepare()),读取batch大小的图片以及图片对应的标签(get())等函数。

    复制代码
    import os
    import xml.etree.ElementTree as ET
    import numpy as np
    import cv2
    import pickle
    import copy
    import yolo.config as cfg
    
    
    '''
    VOC2012数据集处理
    '''
    
    class pascal_voc(object):
    复制代码

    1、类初始化函数

    复制代码
    '''
    VOC2012数据集处理
    '''
    
    class pascal_voc(object):
        '''
        VOC2012数据集处理的类,主要用来获取训练集图片文件,以及生成对应的标签文件
        '''
        def __init__(self, phase, rebuild=False):
            '''
            准备训练或者测试的数据
            
            args:
                phase:传入字符串 'train':表示训练
                                  'test':测试
                rebuild:是否重新创建数据集的标签文件,保存在缓存文件夹下
            '''
            #VOCdevkit文件夹路径
            self.devkil_path = os.path.join(cfg.PASCAL_PATH, 'VOCdevkit')
            #VOC2012文件夹路径
            self.data_path = os.path.join(self.devkil_path, 'VOC2012')
            #catch文件所在路径
            self.cache_path = cfg.CACHE_PATH
            #批大小
            self.batch_size = cfg.BATCH_SIZE
            #图像大小
            self.image_size = cfg.IMAGE_SIZE
            #单元格大小S
            self.cell_size = cfg.CELL_SIZE
            #VOC 2012数据集类别名
            self.classes = cfg.CLASSES
            #类别名->索引的dict
            self.class_to_ind = dict(zip(self.classes, range(len(self.classes))))
            ##图片是否采用水平镜像扩充训练集?
            self.flipped = cfg.FLIPPED
            #训练或测试?
            self.phase = phase
            #是否重新创建数据集标签文件
            self.rebuild = rebuild
            #从gt_labels加载数据,cursor表明当前读取到第几个
            self.cursor = 0
            #存放当前训练的轮数
            self.epoch = 1
            #存放数据集的标签 是一个list 每一个元素都是一个dict,对应一个图片 
            #如果我们在配置文件中指定flipped=True,则数据集会扩充一倍,每一张原始图片都有一个水平对称的镜像文件
            #      imname:图片路径 
            #      label:图片标签
            #      flipped:图片水平镜像?
            self.gt_labels = None
            #加载数据集标签  初始化gt_labels
            self.prepare()
    复制代码

    2、prepare()所有数据准备函数

    prepare()函数调用load_labels()函数,加载所有数据集的标签,保存在遍历gt_labels集合中,如果在配置文件中指定了水平镜像,则追加一倍的训练数据集。

    复制代码
        def prepare(self):
            '''
            初始化数据集的标签,保存在变量gt_labels中
            
            return:
                gt_labels:返回数据集的标签 是一个list  每一个元素对应一张图片,是一个dict                       
                                         imname:图片文件路径
                                         label:图片文件对应的标签 [7,7,25]的矩阵
                                         flipped:是否使用水平镜像? 设置为False
            '''
            #加载数据集的标签
            gt_labels = self.load_labels()
            #如果水平镜像,则追加一倍的训练数据集
            if self.flipped:
                print('Appending horizontally-flipped training examples ...')
                #深度拷贝
                gt_labels_cp = copy.deepcopy(gt_labels)
                #遍历每一个图片标签
                for idx in range(len(gt_labels_cp)):
                    #设置flipped属性为True
                    gt_labels_cp[idx]['flipped'] = True
                    #目标所在格子也进行水平镜像 [7,7,25]
                    gt_labels_cp[idx]['label'] =
                        gt_labels_cp[idx]['label'][:, ::-1, :]
                    for i in range(self.cell_size):
                        for j in range(self.cell_size):
                            #置信度==1,表示这个格子有目标
                            if gt_labels_cp[idx]['label'][i, j, 0] == 1:
                                #中心的x坐标水平镜像
                                gt_labels_cp[idx]['label'][i, j, 1] = 
                                    self.image_size - 1 -
                                    gt_labels_cp[idx]['label'][i, j, 1]
                #追加数据集的标签   后面的是由原数据集标签扩充的水平镜像数据集标签
                gt_labels += gt_labels_cp
            #打乱数据集的标签
            np.random.shuffle(gt_labels)
            self.gt_labels = gt_labels
            return gt_labels
    复制代码

    3、get()批量数据读取函数

    get()函数用在训练的时候,每次从gt_labels集合随机读取batch大小的图片以及图片对应的标签。

    复制代码
        def get(self):
            '''
            加载数据集 每次读取batch大小的图片以及图片对应的标签
            
            return:
                images:读取到的图片数据 [45,448,448,3]
                labels:对应的图片标签 [45,7,7,25]
            '''
            #[45,448,448,3]
            images = np.zeros(
                (self.batch_size, self.image_size, self.image_size, 3))
            #[45,7,7,25]
            labels = np.zeros(
                (self.batch_size, self.cell_size, self.cell_size, 25))
            count = 0
            #一次加载batch_size个图片数据
            while count < self.batch_size:
                #获取图片路径
                imname = self.gt_labels[self.cursor]['imname']
                #是否使用水平镜像?
                flipped = self.gt_labels[self.cursor]['flipped']
                #读取图片数据
                images[count, :, :, :] = self.image_read(imname, flipped)
                #读取图片标签
                labels[count, :, :, :] = self.gt_labels[self.cursor]['label']
                count += 1
                self.cursor += 1
                #如果读取完一轮数据,则当前cursor置为0,当前训练轮数+1
                if self.cursor >= len(self.gt_labels):
                    #打乱数据集
                    np.random.shuffle(self.gt_labels)
                    self.cursor = 0                
                    self.epoch += 1
            return images, labels
    复制代码

    4、image_read()函数读取图片

    图片读取函数,先读取图片,然后缩放,转换为RGB格式,再对数据进行归一化处理。

    复制代码
        def image_read(self, imname, flipped=False):
            '''
            读取图片
            
            args:
                imname:图片路径
                flipped:图片是否水平镜像处理? 
                
            return:
                image:图片数据 [448,448,3]
            '''
            #读取图片数据
            image = cv2.imread(imname)
            #缩放处理
            image = cv2.resize(image, (self.image_size, self.image_size))
            #BGR->RGB  uint->float32
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
            #归一化处理 [-1.0,1.0]
            image = (image / 255.0) * 2.0 - 1.0
            #宽倒序  即水平镜像
            if flipped:
                image = image[:, ::-1, :]
            return image
    复制代码

    5、load_labels()加载标签函数

    复制代码
        def load_labels(self):
            '''
            加载数据集标签
            
            return:
                gt_labels:是一个list  每一个元素对应一张图片,是一个dict                       
                                         imname:图片文件路径
                                         label:图片文件对应的标签 [7,7,25]的矩阵
                                         flipped:是否使用水平镜像? 设置为False   
            '''
            #缓冲文件名:即用来保存数据集标签的文件
            cache_file = os.path.join(
                self.cache_path, 'pascal_' + self.phase + '_gt_labels.pkl')
    
            #文件存在,且不重新创建则直接读取
            if os.path.isfile(cache_file) and not self.rebuild:
                print('Loading gt_labels from: ' + cache_file)
                with open(cache_file, 'rb') as f:
                    gt_labels = pickle.load(f)
                return gt_labels
    
            print('Processing gt_labels from: ' + self.data_path)
    
            #如果缓冲文件目录不存在,创建
            if not os.path.exists(self.cache_path):
                os.makedirs(self.cache_path)
                
            #获取训练测试集的数据文件名
            if self.phase == 'train':
                txtname = os.path.join(
                    self.data_path, 'ImageSets', 'Main', 'trainval.txt')
            #获取测试集的数据文件名
            else:
                txtname = os.path.join(
                    self.data_path, 'ImageSets', 'Main', 'test.txt')
            with open(txtname, 'r') as f:
                self.image_index = [x.strip() for x in f.readlines()]
    
            #存放图片的标签,图片路径,是否使用水平镜像?
            gt_labels = []
            #遍历每一张图片的信息
            for index in self.image_index:
                #读取每一张图片的标签label [7,7,25]
                label, num = self.load_pascal_annotation(index)
                if num == 0:
                    continue
                #图片文件路径
                imname = os.path.join(self.data_path, 'JPEGImages', index + '.jpg')
                #保存该图片的信息
                gt_labels.append({'imname': imname,
                                  'label': label,
                                  'flipped': False})
            print('Saving gt_labels to: ' + cache_file)
            #保存
            with open(cache_file, 'wb') as f:
                pickle.dump(gt_labels, f)
            return gt_labels
    复制代码

    6、load_pascal_annotation()函数

    复制代码
        def load_pascal_annotation(self, index):
            """
            Load image and bounding boxes info from XML file in the PASCAL VOC
            format.
            
            args:
                index:图片文件的index
                
            return :
                label:标签 [7,7,25] 
                          0:1:置信度,表示这个地方是否有目标
                          1:5:目标边界框  目标中心,宽度和高度(这里是实际值,没有归一化)
                          5:25:目标的类别
                len(objs):objs对象长度
            """
            #获取图片文件名路径
            imname = os.path.join(self.data_path, 'JPEGImages', index + '.jpg')
            #读取数据
            im = cv2.imread(imname)
            #宽和高缩放比例
            h_ratio = 1.0 * self.image_size / im.shape[0]
            w_ratio = 1.0 * self.image_size / im.shape[1]
            # im = cv2.resize(im, [self.image_size, self.image_size])
            #用于保存图片文件的标签
            label = np.zeros((self.cell_size, self.cell_size, 25))
            #图片文件的标注xml文件
            filename = os.path.join(self.data_path, 'Annotations', index + '.xml')
            tree = ET.parse(filename)
            objs = tree.findall('object')
    
            for obj in objs:
                bbox = obj.find('bndbox')
                # Make pixel indexes 0-based  当图片缩放到image_size时,边界框也进行同比例缩放
                x1 = max(min((float(bbox.find('xmin').text) - 1) * w_ratio, self.image_size - 1), 0)
                y1 = max(min((float(bbox.find('ymin').text) - 1) * h_ratio, self.image_size - 1), 0)
                x2 = max(min((float(bbox.find('xmax').text) - 1) * w_ratio, self.image_size - 1), 0)
                y2 = max(min((float(bbox.find('ymax').text) - 1) * h_ratio, self.image_size - 1), 0)
                #根据图片的分类名 ->类别index 转换
                cls_ind = self.class_to_ind[obj.find('name').text.lower().strip()]
                #计算边框中心点x,y,w,h(没有归一化)
                boxes = [(x2 + x1) / 2.0, (y2 + y1) / 2.0, x2 - x1, y2 - y1]
                #计算当前物体的中心在哪个格子中
                x_ind = int(boxes[0] * self.cell_size / self.image_size)
                y_ind = int(boxes[1] * self.cell_size / self.image_size)
                #表明该图片已经初始化过了
                if label[y_ind, x_ind, 0] == 1:
                    continue
                #置信度,表示这个地方有物体
                label[y_ind, x_ind, 0] = 1
                #物体边界框
                label[y_ind, x_ind, 1:5] = boxes
                #物体的类别
                label[y_ind, x_ind, 5 + cls_ind] = 1
    
            return label, len(objs)
    复制代码

    六 训练网络

    模型训练包含于train.py文件,Solver类的train()方法之中,训练部分只需要看懂了初始化参数,整个结构就很清晰了。

    复制代码
    import os
    import argparse
    import datetime
    import tensorflow as tf
    import yolo.config as cfg
    from yolo.yolo_net import YOLONet
    from utils.timer import Timer
    from utils.pascal_voc import pascal_voc
    
    slim = tf.contrib.slim
    
    '''
    用来训练YOLO网络模型
    '''
    
    class Solver(object):
        '''
        求解器的类,用于训练YOLO网络
        '''
    复制代码

    1、类初始化函数

    复制代码
       def __init__(self, net, data):
            '''
            构造函数,加载训练参数
            
            args:
                net:YOLONet对象
                data:pascal_voc对象
            '''
            #yolo网络
            self.net = net
            #voc2012数据处理
            self.data = data
            #检查点文件路径
            self.weights_file = cfg.WEIGHTS_FILE
            #训练最大迭代次数
            self.max_iter = cfg.MAX_ITER
            #初始学习率
            self.initial_learning_rate = cfg.LEARNING_RATE
            ##退化学习率衰减步数
            self.decay_steps = cfg.DECAY_STEPS
            #衰减率
            self.decay_rate = cfg.DECAY_RATE
            self.staircase = cfg.STAIRCASE
            ##日志文件保存间隔步
            self.summary_iter = cfg.SUMMARY_ITER
            ##模型保存间隔步
            self.save_iter = cfg.SAVE_ITER
            
            #输出文件夹路径
            self.output_dir = os.path.join(
                cfg.OUTPUT_DIR, datetime.datetime.now().strftime('%Y_%m_%d_%H_%M'))
            if not os.path.exists(self.output_dir):
                os.makedirs(self.output_dir)
            #保存配置信息
            self.save_cfg()
            #指定保存的张量 这里指定所有变量
            self.variable_to_restore = tf.global_variables()        
            self.saver = tf.train.Saver(self.variable_to_restore, max_to_keep=None)
            #指定保存的模型名称
            self.ckpt_file = os.path.join(self.output_dir, 'yolo.cpkt')
            #合并所有的summary
            self.summary_op = tf.summary.merge_all()
            #创建writer,指定日志文件路径,用于写日志文件
            self.writer = tf.summary.FileWriter(self.output_dir, flush_secs=60)
    
            #创建变量,保存当前迭代次数
            self.global_step = tf.train.create_global_step()
            #退化学习率
            self.learning_rate = tf.train.exponential_decay(
                self.initial_learning_rate, self.global_step, self.decay_steps,
                self.decay_rate, self.staircase, name='learning_rate')
            #创建求解器
            self.optimizer = tf.train.GradientDescentOptimizer(
                learning_rate=self.learning_rate)
            # create_train_op that ensures that when we evaluate it to get the loss,
            # the update_ops are done and the gradient updates are computed.
            self.train_op = slim.learning.create_train_op(
                self.net.total_loss, self.optimizer, global_step=self.global_step)
    
            #设置GPU使用资源
            gpu_options = tf.GPUOptions()
            #按需分配GPU使用的资源
            config = tf.ConfigProto(gpu_options=gpu_options)
            self.sess = tf.Session(config=config)
            
            #运行图之前,初始化变量
            self.sess.run(tf.global_variables_initializer())
    
            #恢复模型
            if self.weights_file is not None:
                print('Restoring weights from: ' + self.weights_file)
                self.saver.restore(self.sess, self.weights_file)
    
            #将图写入日志文件
            self.writer.add_graph(self.sess.graph)
    复制代码

     2、train()训练函数

    复制代码
     def train(self):
            '''
            开始训练
            '''
            #训练时间
            train_timer = Timer()
            #数据集加载时间
            load_timer = Timer()
    
            #开始迭代
            for step in range(1, self.max_iter + 1):
                #计算每次迭代加载数据的起始时间
                load_timer.tic()
                #加载数据集 每次读取batch大小的图片以及图片对应的标签
                images, labels = self.data.get()
                #计算这次迭代加载数据集所使用的时间
                load_timer.toc()
                
                feed_dict = {self.net.images: images,
                             self.net.labels: labels}
    
                #迭代summary_iter次,保存一次日志文件,迭代summary_iter*10次,输出一次的迭代信息
                if step % self.summary_iter == 0:
                    if step % (self.summary_iter * 10) == 0:
                        #计算每次迭代训练的起始时间
                        train_timer.tic()
                        loss = 0.0001  
                        #开始迭代训练,每一次迭代后global_step自加1
                        summary_str, loss, _ = self.sess.run(
                            [self.summary_op, self.net.total_loss, self.train_op],
                            feed_dict=feed_dict)
                        #输出信息
                        log_str = '{} Epoch: {}, Step: {}, Learning rate: {}, Loss: {:5.3f}
    Speed: {:.3f}s/iter,Load: {:.3f}s/iter, Remain: {}'.format(
                            datetime.datetime.now().strftime('%m-%d %H:%M:%S'),
                            self.data.epoch,
                            int(step),
                            round(self.learning_rate.eval(session=self.sess), 6),
                            loss,
                            train_timer.average_time,
                            load_timer.average_time,
                            train_timer.remain(step, self.max_iter))
                        print(log_str)
    
                    else:
                        #计算每次迭代训练的起始时间
                        train_timer.tic()           
                        #开始迭代训练,每一次迭代后global_step自加1
                        summary_str, _ = self.sess.run(
                            [self.summary_op, self.train_op],
                            feed_dict=feed_dict)  
                        #计算这次迭代训练所使用的时间
                        train_timer.toc()
                        
                    #将summary写入文件
                    self.writer.add_summary(summary_str, step)
    
                else:
                    #计算每次迭代训练的起始时间
                    train_timer.tic()
                    #开始迭代训练,每一次迭代后global_step自加1
                    self.sess.run(self.train_op, feed_dict=feed_dict)
                    #计算这次迭代训练所使用的时间
                    train_timer.toc()
    
                #没迭代save_iter次,保存一次模型
                if step % self.save_iter == 0:
                    print('{} Saving checkpoint file to: {}'.format(
                        datetime.datetime.now().strftime('%m-%d %H:%M:%S'),
                        self.output_dir))
                    self.saver.save(
                        self.sess, self.ckpt_file, global_step=self.global_step)
    复制代码

    3、保存配置参数

    复制代码
        def save_cfg(self):
            '''
            保存配置信息
            '''
            with open(os.path.join(self.output_dir, 'config.txt'), 'w') as f:
                cfg_dict = cfg.__dict__
                for key in sorted(cfg_dict.keys()):
                    if key[0].isupper():
                        cfg_str = '{}: {}
    '.format(key, cfg_dict[key])
                        f.write(cfg_str)
    复制代码

    train.py文件除了上面介绍的求解器Solver这个类外,还包含了两个函数,一个是update_config_paths()函数,这个函数主要使用了设定数据集路径,以及检查点文件路径。

    复制代码
    def update_config_paths(data_dir, weights_file):
        '''
        数据集路径,和模型检查点文件路径
        
        args:
            data_dir:数据文件夹  数据集放在pascal_voc目录下  
            weights_file:检查点文件名 该文件放在数据集目录下的weights文件夹下
            
        '''
        cfg.DATA_PATH = data_dir                                                   #数据所在文件夹
        cfg.PASCAL_PATH = os.path.join(data_dir, 'pascal_voc')                     #VOC2012数据所在文件夹
        cfg.CACHE_PATH = os.path.join(cfg.PASCAL_PATH, 'cache')                    #保存生成的数据集标签缓冲文件所在文件夹
        cfg.OUTPUT_DIR = os.path.join(cfg.PASCAL_PATH, 'output')                   #保存生成的网络模型和日志文件所在的文件夹
        cfg.WEIGHTS_DIR = os.path.join(cfg.PASCAL_PATH, 'weights')                 #检查点文件所在的目录
    
        cfg.WEIGHTS_FILE = os.path.join(cfg.WEIGHTS_DIR, weights_file)
    复制代码

    我们主要来说一下另一个函数main()函数,先解析命令行参数,然后创建YOLONet、pascal_voc、Solver对象,最后开始训练。

    复制代码
    def main():
        #创建一个解析器对象,并告诉它将会有些什么参数。当程序运行时,该解析器就可以用于处理命令行参数。
        #https://www.cnblogs.com/lovemyspring/p/3214598.html
        parser = argparse.ArgumentParser()
        #定义参数
        parser.add_argument('--weights', default="YOLO_small.ckpt", type=str)    #权重文件名
        parser.add_argument('--data_dir', default="data", type=str)              #数据集路径
        parser.add_argument('--threshold', default=0.2, type=float)
        parser.add_argument('--iou_threshold', default=0.5, type=float)
        parser.add_argument('--gpu', default='', type=str)
        #定义了所有参数之后,你就可以给 parse_args() 传递一组参数字符串来解析命令行。默认情况下,参数是从 sys.argv[1:] 中获取
        #parse_args() 的返回值是一个命名空间,包含传递给命令的参数。该对象将参数保存其属性
        args = parser.parse_args()
    
        #判断是否是使用gpu
        if args.gpu is not None:
            cfg.GPU = args.gpu
    
        #设定数据集路径,以及检查点文件路径
        if args.data_dir != cfg.DATA_PATH  and args.data_dir is not None:
            update_config_paths(args.data_dir, args.weights)
    
        #设置环境变量
        os.environ['CUDA_VISIBLE_DEVICES'] = cfg.GPU
    
        #创建YOLO网络对象
        yolo = YOLONet()
        #数据集对象
        pascal = pascal_voc('train')
        #求解器对象
        solver = Solver(yolo, pascal)
    
        print('Start training ...')
        #开始训练
        solver.train()
        print('Done training.')
    复制代码

    我们执行如下代码,开始训练网络:

    if __name__ == '__main__':
        tf.reset_default_graph()
        # python train.py --weights YOLO_small.ckpt --gpu 0
        main()
  • 相关阅读:
    css 布局方式
    初识cv
    CSS 样式表{二}
    获取设备通讯录信息
    iOS Block界面反向传值小demo
    在iOS中如何正确的实现行间距与行高
    iOS开发- 获取本地视频文件
    view围绕圆心自转
    监测网络状态
    简单的九宫格算法与使用
  • 原文地址:https://www.cnblogs.com/sddai/p/10288096.html
Copyright © 2011-2022 走看看