Training on your own dataset (using bottle as the example):
1. Prepare the data
Folder structure:

```
models
├── images
├── annotations
│   ├── xmls
│   └── trainval.txt
└── bottle
    ├── train_logs    (training folder)
    └── val_logs      (evaluation log folder)
```
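If you prefer to script the setup, here is a minimal Python sketch that creates the folders from the tree above. It assumes you run it from tensorflow/models (the `models` root); `make_layout` is a hypothetical helper, not part of the Object Detection API.

```python
import os

# Folders from the tree above, relative to tensorflow/models (assumed layout).
LAYOUT = [
    "images",
    os.path.join("annotations", "xmls"),
    os.path.join("bottle", "train_logs"),
    os.path.join("bottle", "val_logs"),
]


def make_layout(root):
    """Create every folder in LAYOUT under root and return their paths."""
    created = []
    for rel in LAYOUT:
        path = os.path.join(root, rel)
        os.makedirs(path, exist_ok=True)  # no error if it already exists
        created.append(path)
    return created
```

For example, call `make_layout('.')` from tensorflow/models; because of `exist_ok=True` it is safe to rerun.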
1) Download an official pre-trained model from the model zoo: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
Taking ssd_mobilenet_v1_coco as an example, copy the three model.ckpt* files from the archive into bottle.
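2) Put the JPG images into the images folder. File names must follow the pattern "name_number.jpg": the underscore is required, and numbering starts at 1 (e.g. bottle_1.jpg).
Annotate the images with labelImg (https://github.com/tzutalin/labelImg), save the generated XML files in annotations/xmls, and give each XML file the same name as its JPG.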
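Before generating records, it can help to check the naming rule and the JPG/XML pairing. A minimal sketch, assuming images in images/ and XML files in annotations/xmls/; `check_dataset` and the regular expression are my own approximation of the naming rule, not code taken from the API.

```python
import os
import re

# Assumed rule: letters/underscores, then "_<number>.jpg", e.g. bottle_1.jpg.
NAME_RE = re.compile(r"^([A-Za-z_]+)_([0-9]+)\.jpg$", re.IGNORECASE)


def check_dataset(image_dir, xml_dir):
    """Return a list of problems; an empty list means every JPG looks fine."""
    problems = []
    for fname in sorted(os.listdir(image_dir)):
        if not fname.lower().endswith(".jpg"):
            continue  # ignore non-image files
        if not NAME_RE.match(fname):
            problems.append("bad name: " + fname)
        stem = os.path.splitext(fname)[0]
        # Each image needs an annotation with the same base name.
        if not os.path.isfile(os.path.join(xml_dir, stem + ".xml")):
            problems.append("missing xml: " + stem + ".xml")
    return problems
```

From tensorflow/models, `check_dataset('images', 'annotations/xmls')` should return an empty list when the data is ready.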
3) Create annotations/trainval.txt with one image per line, in the form "image_name 1 1 1", for example:
bottle_1 1 1 1
bottle_2 1 1 1
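Writing trainval.txt by hand gets tedious for many images. The sketch below (hypothetical helper `write_trainval`) writes one `name 1 1 1` line per JPG. As far as I know, create_pet_tf_record.py only uses the first column, so the trailing `1 1 1` simply copies the format above.

```python
import os


def write_trainval(image_dir, out_path):
    """Write one 'name 1 1 1' line per .jpg in image_dir; return the line count."""
    stems = sorted(
        os.path.splitext(f)[0]
        for f in os.listdir(image_dir)
        if f.lower().endswith(".jpg")
    )
    with open(out_path, "w") as out:
        for stem in stems:
            # Image name without extension, then the placeholder columns.
            out.write(stem + " 1 1 1\n")
    return len(stems)
```

For example: `write_trainval('images', 'annotations/trainval.txt')`.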
4) Create object_detection/data/bottle_label_map.pbtxt with the following content:
```
item {
  id: 1
  name: 'bottle'
}
```
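With more than one class, the label map can be generated from a list of names. A sketch with a hypothetical `make_label_map`: IDs start at 1 because 0 is reserved for the background class, and the number of items must equal num_classes in the config. As I understand the pet script, each name should also match the file-name prefix of that class's images, since that prefix is what gets looked up in the label map.

```python
def make_label_map(class_names):
    """Build label-map text with consecutive ids starting at 1."""
    items = []
    for idx, name in enumerate(class_names, start=1):
        items.append("item {\n  id: %d\n  name: '%s'\n}\n" % (idx, name))
    return "".join(items)
```

For this tutorial, `make_label_map(['bottle'])` produces the four lines shown above.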
2. Generate the TFRecord data
```shell
# From tensorflow/models
python object_detection/create_pet_tf_record.py \
    --label_map_path=object_detection/data/bottle_label_map.pbtxt \
    --data_dir=`pwd` --output_dir=`pwd`
```
This produces pet_train.record and pet_val.record; move both into the bottle folder.
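To confirm the record files are not empty and to see how the images were split, you can count the records without TensorFlow. This sketch walks the TFRecord framing (8-byte little-endian length, 4-byte CRC of the length, the payload, 4-byte CRC of the payload) and does not verify the CRCs; `count_tfrecords` is a hypothetical helper.

```python
import struct


def count_tfrecords(path):
    """Count records in a TFRecord file; CRCs are skipped, not checked."""
    count = 0
    with open(path, "rb") as f:
        while True:
            header = f.read(8)
            if not header:
                break  # clean end of file
            if len(header) < 8:
                raise ValueError("truncated record header")
            (length,) = struct.unpack("<Q", header)
            # Length CRC (4 bytes) + payload + payload CRC (4 bytes).
            body = f.read(4 + length + 4)
            if len(body) < 4 + length + 4:
                raise ValueError("truncated record body")
            count += 1
    return count
```

For example, compare `count_tfrecords('bottle/pet_train.record')` with `count_tfrecords('bottle/pet_val.record')`.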
3. Prepare the config file
Copy object_detection/samples/configs/ssd_mobilenet_v1_pets.config to bottle/ssd_mobilenet_v1_bottle.config.
Make the following changes to ssd_mobilenet_v1_bottle.config (line numbers refer to the version of the sample config used here and may differ in other versions):
Line 9: num_classes: 1 (the number of item entries in bottle_label_map.pbtxt)
Line 158: fine_tune_checkpoint: "bottle/model.ckpt"
Line 177: input_path: "bottle/pet_train.record"
Lines 179 and 193: label_map_path: "object_detection/data/bottle_label_map.pbtxt"
Line 191: input_path: "bottle/pet_val.record"
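Because line numbers shift between versions of the sample config, it may be safer to apply the same edits by pattern. A sketch with a hypothetical `patch_config`; it assumes, as in the pets sample, that train_input_reader appears before eval_input_reader, so the first input_path is the training one and the second the validation one.

```python
import re


def patch_config(text, num_classes, checkpoint, train_record, val_record, label_map):
    """Apply the bottle edits to a pipeline config by pattern, not by line number.

    Assumes train_input_reader comes before eval_input_reader, so the 1st
    input_path is the training data and the 2nd is the validation data.
    """
    # Lambdas are used as replacements so backslashes in paths are not
    # interpreted as regex escape sequences.
    text = re.sub(r"num_classes: \d+", lambda m: "num_classes: %d" % num_classes, text)
    text = re.sub(r'fine_tune_checkpoint: "[^"]*"',
                  lambda m: 'fine_tune_checkpoint: "%s"' % checkpoint, text)
    text = re.sub(r'label_map_path: "[^"]*"',
                  lambda m: 'label_map_path: "%s"' % label_map, text)

    new_inputs = [train_record, val_record]
    seen = [0]  # mutable counter shared with the nested function

    def swap_input(match):
        i = seen[0]
        seen[0] += 1
        if i < len(new_inputs):
            return 'input_path: "%s"' % new_inputs[i]
        return match.group(0)  # leave any extra input_path untouched

    return re.sub(r'input_path: "[^"]*"', swap_input, text)
```

For example, read bottle/ssd_mobilenet_v1_bottle.config, pass its text with `1, 'bottle/model.ckpt', 'bottle/pet_train.record', 'bottle/pet_val.record', 'object_detection/data/bottle_label_map.pbtxt'`, and write the result back.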
4. Train
```shell
# From tensorflow/models
python object_detection/train.py --logtostderr \
    --pipeline_config_path=bottle/ssd_mobilenet_v1_bottle.config \
    --train_dir=bottle/train_logs 2>&1 | tee bottle/train_logs.txt &
```
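Since tee also writes the output to bottle/train_logs.txt, you can pull out the loss values without TensorBoard. This sketch assumes log lines of the form `global step 120: loss = 3.1415 (0.512 sec/step)`, which is what I recall the slim training loop behind train.py printing; check your own log and adjust `STEP_RE` if it differs. `read_losses` is a hypothetical helper.

```python
import re

# Assumed log format, e.g.:
#   INFO:tensorflow:global step 120: loss = 3.1415 (0.512 sec/step)
STEP_RE = re.compile(r"global step (\d+): loss = ([0-9.]+)")


def read_losses(log_path):
    """Return [(step, loss), ...] for every matching line in the log."""
    points = []
    with open(log_path) as f:
        for line in f:
            m = STEP_RE.search(line)
            if m:
                points.append((int(m.group(1)), float(m.group(2))))
    return points
```

For example, `read_losses('bottle/train_logs.txt')[-10:]` shows the ten most recent steps.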
5. Evaluate
```shell
# From tensorflow/models
python object_detection/eval.py --logtostderr \
    --pipeline_config_path=bottle/ssd_mobilenet_v1_bottle.config \
    --checkpoint_dir=bottle/train_logs --eval_dir=bottle/val_logs &
```
6. Visualize the logs
You can visualize the training logs while training is still running and watch the loss trend.
```shell
# From tensorflow/models
tensorboard --logdir bottle/train_logs/
```
Open ip:6006 in a browser to see the trends and the predictions on individual images.
7. Export the model
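```shell
# From tensorflow/models
python object_detection/export_inference_graph.py --input_type image_tensor \
    --pipeline_config_path bottle/ssd_mobilenet_v1_bottle.config \
    --trained_checkpoint_prefix bottle/train_logs/model.ckpt-8 \
    --output_directory bottle
```
The number after model.ckpt- is the training step of the checkpoint; replace 8 with a step that actually exists in bottle/train_logs.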
This generates bottle/frozen_inference_graph.pb.
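Export needs the prefix of an existing checkpoint. A stdlib sketch that picks the highest-step checkpoint in train_logs, assuming V2-format checkpoints that each come with a model.ckpt-<step>.index file; `latest_checkpoint_prefix` is hypothetical. In TensorFlow 1.x, tf.train.latest_checkpoint(train_dir) does a similar job by reading the checkpoint file in that directory.

```python
import os
import re

CKPT_RE = re.compile(r"^model\.ckpt-(\d+)\.index$")


def latest_checkpoint_prefix(train_dir):
    """Return e.g. 'bottle/train_logs/model.ckpt-1234' for the highest step, or None."""
    best = None
    for fname in os.listdir(train_dir):
        m = CKPT_RE.match(fname)
        # Compare steps numerically, so 120 beats 99 and 8.
        if m and (best is None or int(m.group(1)) > best):
            best = int(m.group(1))
    if best is None:
        return None
    return os.path.join(train_dir, "model.ckpt-%d" % best)
```

The returned string can be passed as --trained_checkpoint_prefix.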
8. Test on images
Run object_detection_tutorial.ipynb and change the paths in it accordingly,
or write your own inference script, for example tensorflow/models/object_detection/infer.py:
```python
import sys
sys.path.append('..')  # lets the object_detection package import from object_detection/
import time

import numpy as np
import tensorflow as tf
from PIL import Image
from matplotlib import pyplot as plt

from utils import label_map_util
from utils import visualization_utils as vis_util

PATH_TEST_IMAGE = sys.argv[1]
# Paths for the bottle example, relative to tensorflow/models/object_detection.
PATH_TO_CKPT = '../bottle/frozen_inference_graph.pb'
PATH_TO_LABELS = 'data/bottle_label_map.pbtxt'
NUM_CLASSES = 1  # must match num_classes in the training config
IMAGE_SIZE = (18, 12)

# Map class ids to display names.
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
    label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)

# Load the frozen inference graph.
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')

config = tf.ConfigProto()
config.gpu_options.allow_growth = True  # allocate GPU memory on demand

with detection_graph.as_default():
    with tf.Session(graph=detection_graph, config=config) as sess:
        start_time = time.time()
        print(time.ctime())
        # convert('RGB') drops any alpha channel; the model expects 3 channels.
        image = Image.open(PATH_TEST_IMAGE).convert('RGB')
        image_np = np.array(image).astype(np.uint8)
        image_np_expanded = np.expand_dims(image_np, axis=0)  # add batch dimension
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        scores = detection_graph.get_tensor_by_name('detection_scores:0')
        classes = detection_graph.get_tensor_by_name('detection_classes:0')
        num_detections = detection_graph.get_tensor_by_name('num_detections:0')
        (boxes, scores, classes, num_detections) = sess.run(
            [boxes, scores, classes, num_detections],
            feed_dict={image_tensor: image_np_expanded})
        print('{} elapsed time: {:.3f}s'.format(time.ctime(), time.time() - start_time))
        # Draw boxes and labels onto image_np in place.
        vis_util.visualize_boxes_and_labels_on_image_array(
            image_np,
            np.squeeze(boxes),
            np.squeeze(classes).astype(np.int32),
            np.squeeze(scores),
            category_index,
            use_normalized_coordinates=True,
            line_thickness=8)
        plt.figure(figsize=IMAGE_SIZE)
        plt.imshow(image_np)
        plt.show()  # without this, nothing is displayed when run as a script
```
Then run `python infer.py test_images/image1.jpg` from tensorflow/models/object_detection.
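infer.py only draws the detections. To use them in code, the raw outputs can be thresholded and converted to pixels. The model returns boxes as normalized [ymin, xmin, ymax, xmax], which is why use_normalized_coordinates=True is passed above. A sketch with a hypothetical `to_pixel_boxes`; the 0.5 default is my assumption, chosen to match what I believe is the default min_score_thresh of visualize_boxes_and_labels_on_image_array.

```python
def to_pixel_boxes(boxes, scores, classes, width, height, min_score=0.5):
    """Keep detections scoring >= min_score and convert boxes to pixel corners.

    boxes holds normalized [ymin, xmin, ymax, xmax] rows for one image
    (e.g. np.squeeze(boxes) in infer.py).
    Returns a list of (class_id, score, (left, top, right, bottom)).
    """
    results = []
    for box, score, cls in zip(boxes, scores, classes):
        if score < min_score:
            continue  # drop low-confidence detections
        ymin, xmin, ymax, xmax = box
        results.append((int(cls), float(score),
                        (int(round(xmin * width)), int(round(ymin * height)),
                         int(round(xmax * width)), int(round(ymax * height)))))
    return results
```

Inside infer.py, after sess.run, a call such as `to_pixel_boxes(np.squeeze(boxes), np.squeeze(scores), np.squeeze(classes), image.width, image.height)` would list the bottles found with their pixel coordinates.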