zoukankan      html  css  js  c++  java
  • 使用TensorFlow Object Detection API+Google ML Engine训练自己的手掌识别器

      上次使用Google ML Engine跑了一下TensorFlow Object Detection API中的Quick Start(http://www.cnblogs.com/take-fetter/p/8384564.html),但是遇到了很多错误,索性放弃了

    这两天直接开始从自己的数据集开始制作手掌识别器。先放运行结果吧

       

     所有代码文件可在https://github.com/takefetter/hand-detection查看,欢迎star和issue

    使用前所需要的准备:1.clone tensorflow models(site:https://github.com/tensorflow/models)

              2.在model/research目录下运行setup.py安装object detection API

              3.其余必要条件:安装tensorflow(版本需大于等于1.4),opencv-python等必须的package

              4.安装Google Cloud SDK,激活免费试用300美金(需要一张信用卡来验证)和在命令行中使用gcloud init设置等

    •  准备数据集

      (关于手的图片的dataset仍旧使用的dlib训练(site:http://www.cnblogs.com/take-fetter/p/8321158.html)中的Hand Images Databases - https://www.mutah.edu.jo/biometrix/hand-images-databases.html提供的数据集,只不过这次使用了WEHI系列的图片(MOHI的图片我也试过,导入后会导致standard-gpu版的训练无法进行(内存不足)),作为示例目前我只使用了1-50人的共计250张图片)

       tensorflow训练的数据集需为TFRecord格式,我们需要对训练数据进行标注,但是我并没有找到直接可以标注生成的工具,还好有工具可以生成Pascal VOC格式的xml文件      https://github.com/tzutalin/labelImg,推荐将图片文件放于research/images中,保存xml文件夹位于research/images/xmls中

    根据你要训练的数据集,创建.pbtxt文件

    • 转换为tfrecord格式

       完成图片标注后在xmls文件夹中运行xml_to_csv.py即可生成csv文件,再通过create_hand_tfrecord.py即可将图片转换为hand.record文件

       需要注意的是,如果你需要训练的数据集和我这里的不一样的话,create_hand_tfrecord.py的todo部分需要与你的.pbtxt文件内的内容一致

       (方法参考至https://github.com/datitran/raccoon_dataset 使用本作者的文件还可以完成划分测试集和分析数据等功能,当然我这里并没有使用)

    •  下载预训练模型

       重新开始一个模型的训练时间是很长的时间,而tensorflow model zoo为我们提供好了预训练的模型(site:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md#coco-trained-models-coco-models),选择并下载一个 我选择的是

    速度最快的ssd_mobilenet_v1,下载后解压可找到3个含有ckpt的文件,如图

      之后还需下载并配置model对应的config文件(https://github.com/tensorflow/models/tree/master/research/object_detection/samples/configs)并修改文件中的内容

    需要修改的地方有:

    1. num_classes: 改为pbtxt中类的数目
    2. PATH_TO_BE_CONFIGURED的部分改为相应的目录
    3. num_steps定义了学习的上限 默认是200000 可自己更改,训练过程中也可以随时停止
    • 上传文件并在Google Cloud Platform中训练

      1.上传3个ckpt文件以及config文件和.record文件 

          到google cloud控制台-存储目录下,创建存储分区(这里使用takefetter_hand_detector),并新建data文件夹,拖拽上传到该目录中完成后的目录和文件如下

    + takefetter_hand_detector/
      + data/
        - ssd_mobilenet_v1_hand.config
        - model.ckpt.index
        - model.ckpt.meta
        - model.ckpt.data-00000-of-00001
        - hand_label_map.pbtxt
        - hand.record

      2. 打包tf slim和object detection

         在research目录下运行

    python setup.py sdist
    (cd slim && python setup.py sdist)

      3.创建机器学习任务

        在research目录下运行此命令 开始训练

    gcloud ml-engine jobs submit training `whoami`_object_detection_`date +%s` 
        --runtime-version 1.4 
        --job-dir=gs://takefetter_hand_detector/train 
        --packages dist/object_detection-0.1.tar.gz,slim/dist/slim-0.1.tar.gz 
        --module-name object_detection.train 
        --region us-central1 
        --config object_detection/samples/cloud/cloud.yml 
        -- 
        --train_dir=gs://takefetter_hand_detector/train 
        --pipeline_config_path=gs://takefetter_hand_detector/data/ssd_mobilenet_v1_hand.config

    需要注意的地方有 

    1. windows下需要放在同一行运行 并删除
    2. cloud.yml文件中的内容可以自行更改,我这里的设置为
      trainingInput:
        runtimeVersion: "1.4"
        scaleTier: CUSTOM
        masterType: standard_gpu
        workerCount: 2
        workerType: standard_gpu
        parameterServerCount: 2
        parameterServerType: standard

    在提交任务后在 机器学习引擎-作业中即可看到具体情况,每运行几千次后在 takefetter_hand_detector/train中存储对应cheakpoint的文件 如图

    之后下载需要的cheak的3个文件 复制到research目录下(这里用30045的3个文件作为示例),并将research/object_detectIon目录下的export_inference_graph.py复制到research目录下 运行例如

    python object_detection/export_inference_graph.py 
        --input_type image_tensor 
        --pipeline_config_path object_detection/samples/configs/ssd_mobilenet_v1_hand.config 
        --trained_checkpoint_prefix model.ckpt-30045 
        --output_directory exported_graphs

    在运行完成后research目录中会生成文件夹exported_graphs_30045 包含的文件如图所示

    拷贝frozen_inference_graph.pb和pbtxt文件到test/hand_inference_graph文件夹,并运行hand_detector.py 即可得到如文章开头的结果

    后记:

    1.如果需要视频实时的hand tracking,可使用https://github.com/victordibia/handtracking 在我的渣本上FPS太低了......

    2.我目前使用的数据集还是较小训练次数也比较少,很容易出现一些误识别的情况,之后还会加大数据集和训练次数

    3.换用其他model应该也会显著改善识别精确度

    4.遇到任何问题,欢迎提问(虽然感觉大多数stack overflow都有)

    5.本地训练要好很多,如果使用在Google Cloud训练中可能会遇到问题,但是解决方法是将tensorflow版本改为1.2,但是1.2版本的object detection在准备阶段就会遇到问题,目前来看确实无解。(毕竟API Caller)

    6.本地训练建议使用tensorflow版本为1.2

    感谢:

    1. https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/running_pets.md
    2. https://github.com/victordibia/handtracking
    3. https://pythonprogramming.net/testing-custom-object-detector-tensorflow-object-detection-api-tutorial/?completed=/training-custom-objects-tensorflow-object-detection-api-tutorial/
    4. https://github.com/datitran/raccoon_dataset
    5. https://www.mutah.edu.jo/biometrix/hand-images-databases.html
  • 相关阅读:
    Redis Cluster 剔除节点失败
    redis cluster 常用操作
    pika版本特性研究
    ueditor的集成
    pyhon类
    python之eval简述
    Python:list,tuple
    Python函数式编程学习:lambda, map, reduce, filter、sorted()、lambda、decorator
    Python中字典详解
    Python调用(运行)外部程序
  • 原文地址:https://www.cnblogs.com/take-fetter/p/8438747.html
Copyright © 2011-2022 走看看