zoukankan      html  css  js  c++  java
  • mmdetection 测试代码(一个 py 文件搞定)

    mmdet(mmdetection)开源项目使用者众多,但初学者往往不知如何调用:能够训练,却常常为测试代码犯难。趁着项目需要,我编写了一份测试代码,供大家学习。

    准备:

    1.config  模型配置(参数)的 .py 文件

    2.checkpoint  训练好的权重.pth文件

    3.classes  类别名称列表;若无,则用类别编号(从 1 开始)代替

    4.测试图片路径  该文件夹存放需要测试的图片,也可以分多个子文件夹存放,代码会自动寻找(只处理不含子文件夹的最底层文件夹中的图片)

    详细如下:

    from mmdet.apis import inference_detector, init_detector
    import cv2
    import os
    import numpy as np
    import pandas as pd

    from tqdm import tqdm
    import time
    # Default to GPU 7, but respect a device selection already present in the
    # environment (e.g. `CUDA_VISIBLE_DEVICES=0 python test.py`).
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "7")


    class Model():
        """Thin wrapper around an mmdetection detector for image/folder testing.

        Builds the detector once, runs inference on single images, filters the
        raw per-class results by score and draws the surviving boxes.
        """

        def __init__(self, root_config, root_checkpoint, **kwargs):
            """Build the detector.

            Args:
                root_config: path to the mmdet config (.py) file.
                root_checkpoint: path to the trained weights (.pth) file.
                **kwargs: optional ``thr_ok`` (score threshold, default 0.05)
                    and ``classes`` (sequence of class names indexed by the
                    detector's class index; ``None`` -> 1-based integer ids).
            """
            self.model = init_detector(root_config, root_checkpoint)  # build the model
            self.thr_ok = kwargs.get('thr_ok', 0.05)
            self.classes = kwargs.get('classes', None)
            self.color = self.get_color()
            # Accepted image extensions. The (misspelled) attribute name is kept
            # because callers such as ``main`` read ``model.img_foramt``.
            self.img_foramt = ['.jpg', '.JPG', '.bmp', '.png']

        def get_color(self):
            """Return the named colour table as BGR tuples (OpenCV channel order)."""
            return dict(red=(0, 0, 255),
                        green=(0, 255, 0),
                        blue=(255, 0, 0),
                        cyan=(255, 255, 0),
                        yellow=(0, 255, 255),
                        magenta=(255, 0, 255),
                        white=(255, 255, 255),
                        black=(0, 0, 0))

        def model_test(self, result, img_name, classes, thr_ok=0.05):
            """Convert raw detector output into a DataFrame of detections.

            Args:
                result: per-class sequence of ``(k, 5)`` arrays
                    ``[x1, y1, x2, y2, score]``. A ``(bbox_result, segm_result)``
                    tuple, as produced by mask models, is also accepted; only the
                    boxes are used.
                img_name: image name stored in every output row.
                classes: class names indexed by class position, or ``None`` to
                    report 1-based integer class ids.
                thr_ok: keep only boxes whose score is strictly greater than this.

            Returns:
                DataFrame with columns ``img_name``, ``cats``, ``bbox`` (list of 4
                coordinates rounded to 2 decimals) and ``score``.
            """
            if isinstance(result, tuple):
                # Mask models return (bbox_result, segm_result).
                result = result[0]
            records = []
            for cls_idx, boxes in enumerate(result):  # one entry per category
                category = classes[cls_idx] if classes is not None else cls_idx + 1
                for box in boxes:
                    conf = float(box[4])
                    if conf > thr_ok:
                        coord = [round(float(v), 2) for v in box[:4]]
                        records.append({'img_name': img_name, 'cats': category,
                                        'bbox': coord, 'score': conf})
            return pd.DataFrame(records, columns=['img_name', 'cats', 'bbox', 'score'])

        def single_test(self, root_img):
            """Run the detector on one image file and return its raw result.

            Raises:
                FileNotFoundError: if the image cannot be read.
            """
            img = cv2.imread(root_img)
            if img is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise FileNotFoundError('cannot read image: {}'.format(root_img))
            return inference_detector(self.model, img)

        def run(self, img_root):
            """Detect objects in one image.

            Returns:
                Four parallel lists: image names, categories, boxes and scores.
            """
            model_result = self.single_test(img_root)
            img_name = self.get_strfile(img_root, pos=-1)
            result_df = self.model_test(model_result, img_name, self.classes, thr_ok=self.thr_ok)
            return self.pd2lst(result_df)

        def pd2lst(self, result_df):
            """Split a ``model_test`` DataFrame into four parallel lists.

            Returns:
                ``(img_name_lst, cat_lst, box_lst, score_lst)``; all empty when
                the frame has no rows.
            """
            return (result_df['img_name'].tolist(),
                    result_df['cats'].tolist(),
                    result_df['bbox'].tolist(),
                    result_df['score'].tolist())

        def draw_bbox(self, img, cat_lst, box_lst, score_lst,
                      bbox_color='green',
                      text_color='green',
                      thickness=1,
                      font_scale=0.5
                      ):
            """Draw boxes and ``category:score`` labels onto ``img`` in place.

            Colours are names from ``get_color``. Returns the same image.
            """
            bbox_bgr = self.color[bbox_color]  # look the colours up once
            text_bgr = self.color[text_color]
            for cat, box, score in zip(cat_lst, box_lst, score_lst):
                x1, y1, x2, y2 = (int(v) for v in box)  # truncate like astype(int32)
                cv2.rectangle(img, (x1, y1), (x2, y2), bbox_bgr, thickness=thickness)
                label_text = '{}:{}'.format(cat, round(score, 4))
                cv2.putText(img, label_text, (x1, y1 - 2), cv2.FONT_HERSHEY_COMPLEX,
                            font_scale, text_bgr)
            return img

        def get_strfile(self, file_str, pos=-1):
            """Return path component ``pos`` of ``file_str`` (-1: file name, -2: parent).

            Both ``\\`` and ``/`` are treated as separators, so Windows, POSIX and
            mixed paths all work.
            """
            return file_str.replace('\\', '/').split('/')[pos]

        def build_dir(self, out_dir):
            """Create ``out_dir`` (and parents) if missing; return it."""
            os.makedirs(out_dir, exist_ok=True)  # no exists/makedirs race
            return out_dir

        def show_img(self, img):
            """Display an OpenCV (BGR) image with matplotlib."""
            import matplotlib.pyplot as plt
            if img.ndim == 3 and img.shape[2] == 3:
                img = img[..., ::-1]  # BGR -> RGB, which matplotlib expects
            plt.imshow(img)
            plt.show()

    def get_files_root(root):
        """Return the leaf directories under ``root`` (folders without subfolders).

        If ``root`` has no subdirectories, ``[root]`` is returned. Intermediate
        folders are not included, so images stored directly in a folder that also
        contains subfolders are not picked up by ``main``.

        Raises:
            FileNotFoundError: if ``root`` is not an existing directory.
        """
        if not os.path.isdir(root):
            raise FileNotFoundError('not a directory: {}'.format(root))
        leaves = []
        # followlinks=True keeps descending into symlinked directories, as the
        # previous hand-written traversal did.
        for dirpath, dirnames, _ in os.walk(root, followlinks=True):
            dirnames.sort()  # deterministic traversal / output order
            if not dirnames:
                leaves.append(dirpath)
        return leaves

    def build_files(root):
        """Return full paths of the entries directly inside ``root`` that are not regular files.

        Despite the name, regular files are excluded; in practice the result is
        the immediate subdirectories of ``root``, in ``os.listdir`` order.
        """
        candidates = (os.path.join(root, entry) for entry in os.listdir(root))
        return [path for path in candidates if not os.path.isfile(path)]


    def single_main(model, root_img, work_dir):
        """Detect objects in one image and save an annotated copy.

        The output is written to ``work_dir/<parent folder name>/<image name>``.
        Only the parent folder's name is used, so same-named folders in
        different branches share one output folder.

        Args:
            model: a ``Model`` instance.
            root_img: path to the input image.
            work_dir: root output directory.

        Raises:
            FileNotFoundError: if the image cannot be read.
            IOError: if the annotated image cannot be written.
        """
        _, cat_lst, box_lst, score_lst = model.run(root_img)
        img = cv2.imread(root_img)
        if img is None:
            raise FileNotFoundError('cannot read image: {}'.format(root_img))
        img = model.draw_bbox(img, cat_lst, box_lst, score_lst)
        out_dir = model.build_dir(os.path.join(work_dir, model.get_strfile(root_img, pos=-2)))
        out_path = os.path.join(out_dir, model.get_strfile(root_img, pos=-1))
        # cv2.imwrite reports failure through its return value instead of raising.
        if not cv2.imwrite(out_path, img):
            raise IOError('failed to write image: {}'.format(out_path))

    def main(root, model, work_dir):
        """Annotate every image in the leaf folders under ``root``.

        A file counts as an image when its extension, compared
        case-insensitively, is listed in ``model.img_foramt``. Annotated copies
        are written below ``work_dir`` by ``single_main``. Prints the number of
        processed images.
        """
        exts = {ext.lower() for ext in model.img_foramt}  # build once, O(1) lookups
        num = 0
        for file_path in tqdm(get_files_root(root)):
            for name in tqdm(os.listdir(file_path), leave=False):
                if os.path.splitext(name)[1].lower() in exts:
                    single_main(model, os.path.join(file_path, name), work_dir)
                    num += 1

        print('num of images:', num)



    if __name__ == '__main__':
        root_config = '/data/sdv3/tangjun/xmtm/xmtm_pointer/code/model_new_meter/parameters.py'
        root_checkpoint = '/data/sdv3/tangjun/xmtm/xmtm_pointer/code/model_new_meter/model.pth'
        # Folder with the test images; only its leaf sub-folders are scanned.
        root = '/data/sdv3/tangjun/xmtm/xmtm_pointer/data/data_0512/data_step_two/train_step2/train'
        work_dir = '/data/sdv3/cj/First_Blood/tj/code/mmdet50/123/78/90'
        info = {'classes': None}  # None -> categories are reported as 1-based ids

        # perf_counter is monotonic, so the elapsed time cannot jump with clock changes.
        time_start = time.perf_counter()
        model = Model(root_config, root_checkpoint, **info)  # build the detector once
        # To predict a single image instead of a folder:
        #     img_name_lst, cat_lst, box_lst, score_lst = model.run(root_img)

        main(root, model, work_dir)

        time_gap = time.perf_counter() - time_start
        print('time gap:', time_gap)











    处理算法通用的辅助的code,如读取txt文件,读取xml文件,将xml文件转换成txt文件,读取json文件等
  • 相关阅读:
    SpringBoot-Maven打包压缩瘦身
    Docker安装Jenkins
    Spring Boot 微服务应用集成Prometheus + Grafana 实现监控告警
    Spring Boot构建 RESTful 风格应用
    SpringMVC 中 @ControllerAdvice 注解
    Spring Boot 整合 Freemarker
    Spring Boot中的静态资源文件
    SpringBoot配置文件 application.properties,yaml配置
    代码质量管理-安全问题
    8.Python基础 面向对象的基本概念
  • 原文地址:https://www.cnblogs.com/tangjunjun/p/14847291.html
Copyright © 2011-2022 走看看