zoukankan      html  css  js  c++  java
  • 使用Tensorflow训练自己的数据

    训练自己的数据集(以bottle为例):

     

    1、准备数据

    文件夹结构:
    models
    ├── images
    ├── annotations
    │ ├── xmls
    │ └── trainval.txt
    └── bottle
    ├── train_logs 训练文件夹
    └── val_logs 日志文件夹

    1)、下载官方预训练模型: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md 
    ssd_mobilenet_v1_coco为例,将压缩包内model.ckpt*的三个文件复制到bottle内

    2)、准备jpg图片数据,放入images文件夹(图片文件命名要求“名字+下划线+编号.jpg”,必须使用下划线,编号从1开始) 
    使用https://github.com/tzutalin/labelImg工具对图片进行标注,生成xml文件放置xmls文件夹,并保持xml和jpg命名相同 
    3)、新建 bottle/trainval.txt 文件,内容为(图片名 1 1 1),每行一个文件,如:

    bottle_1 1 1 1
    bottle_2 1 1 1

    4)、新建object_detection/data/bottle_label_map.pbtxt,内容如下

    item {
        id: 1
        name: 'bottle'
    }
     

    2、生成数据

    # From tensorflow/models
    python object_detection/create_pet_tf_record.py 
    --label_map_path=object_detection/data/bottle_label_map.pbtxt 
    --data_dir=`pwd` 
    --output_dir=`pwd`

    得到 pet_train.record 和 pet_val.record 移动至bottle文件夹

     

    3、准备conf文件

    复制object_detection/samples/configs/ssd_mobilenet_v1_pets.config到 /bottle/ssd_mobilenet_v1_bottle.config 
    对ssd_mobilenet_v1_bottle.config文件进行一下修改:

    修改第9行为 num_classes: 1,此数值代表bottle_label_map.pbtxt文件配置item的数量
    修改第158行为 fine_tune_checkpoint: "bottle/model.ckpt"
    修改第177行为 input_path: "bottle/pet_train.record"
    修改第179行和193行为 label_map_path: "object_detection/data/bottle_label_map.pbtxt"
    修改第191行为 input_path: "bottle/pet_val.record"
     

    4、训练

    # From tensorflow/models
    python object_detection/train.py 
    --logtostderr 
    --pipeline_config_path=bottle/ssd_mobilenet_v1_bottle.config 
    --train_dir=bottle/train_logs 
    2>&1 | tee bottle/train_logs.txt &
     

    5、验证

    # From tensorflow/models
    python object_detection/eval.py 
    --logtostderr 
    --pipeline_config_path=bottle/ssd_mobilenet_v1_bottle.config 
    --checkpoint_dir=bottle/train_logs 
    --eval_dir=bottle/val_logs &
     

    6、可视化log

    可一边训练一边可视化训练的log,可看到Loss趋势。

    tensorboard --logdir train_logs/

    浏览器访问 ip:6006,可看到趋势以及具体image的预测结果

     

    7、导出模型

    # From tensorflow/models
    python object_detection/export_inference_graph.py 
    --input_type image_tensor 
    --pipeline_config_path bottle/ssd_mobilenet_v1_bottle.config 
    --trained_checkpoint_prefix bottle/train_logs/model.ckpt-8 
    --output_directory bottle

    生成 bottle/frozen_inference_graph.pb 文件

     

    8、测试图片

    运行object_detection_tutorial.ipynb并修改其中的各种路径即可 
    或自写编译inference脚本,如tensorflow/models/object_detection/infer.py:

    import sys
    sys.path.append('..')
    import os
    import time
    import tensorflow as tf
    import numpy as np
    from PIL import Image
    from matplotlib import pyplot as plt
    from utils import label_map_util
    from utils import visualization_utils as vis_util
    PATH_TEST_IMAGE = sys.argv[1]
    PATH_TO_CKPT = 'VOC2012/frozen_inference_graph.pb'
    PATH_TO_LABELS = 'VOC2012/pascal_label_map.pbtxt'
    NUM_CLASSES = 21
    IMAGE_SIZE = (18, 12)
    label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
    categories = label_map_util.convert_label_map_to_categories(
    label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
    category_index = label_map_util.create_category_index(categories)
    detection_graph = tf.Graph()
    with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
    serialized_graph = fid.read()
    od_graph_def.ParseFromString(serialized_graph)
    tf.import_graph_def(od_graph_def, name='')
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with detection_graph.as_default():
    with tf.Session(graph=detection_graph, config=config) as sess:
    start_time = time.time()
    print(time.ctime())
    image = Image.open(PATH_TEST_IMAGE)
    image_np = np.array(image).astype(np.uint8)
    image_np_expanded = np.expand_dims(image_np, axis=0)
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
    scores = detection_graph.get_tensor_by_name('detection_scores:0')
    classes = detection_graph.get_tensor_by_name('detection_classes:0')
    num_detections = detection_graph.get_tensor_by_name('num_detections:0')
    (boxes, scores, classes, num_detections) = sess.run(
    [boxes, scores, classes, num_detections],
    feed_dict={image_tensor: image_np_expanded})
    print('{} elapsed time: {:.3f}s'.format(time.ctime(), time.time() - start_time))
    vis_util.visualize_boxes_and_labels_on_image_array(
    image_np, np.squeeze(boxes), np.squeeze(classes).astype(np.int32), np.squeeze(scores),
    category_index, use_normalized_coordinates=True, line_thickness=8)
    plt.figure(figsize=IMAGE_SIZE)
    plt.imshow(image_np)

    运行infer.py test_images/image1.jpg即可

  • 相关阅读:
    JS-记住用户名【cookie封装引申】
    JS-cookie封装
    JS-比较函数中嵌套函数,可以排序【对象数组】
    JS-随机div颜色
    JS-过滤敏感词【RegExp】
    JS-提取字符串—>>普通方法VS正则表达式
    CSS- ie6,ie7,ie8 兼容性写法,CSS hack写法
    JS-【同页面多次调用】轮播特效封装-json传多个参数
    JS-【同页面多次调用】tab选项卡封装
    Redis主从同步
  • 原文地址:https://www.cnblogs.com/leedaily/p/8286981.html
Copyright © 2011-2022 走看看