  • TensorFlow Lite Android face detection demo

    https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md

    Setting up the TensorFlow and Object Detection API environment

    https://www.jianshu.com/p/286b8163da29

    Installing Bazel

    The links above cover the environment needed to train a TensorFlow object detection model and to convert it from TensorFlow to TensorFlow Lite; the concrete steps follow below.

    Download Android Studio and import the TFLite Android demo from the TensorFlow source tree; it lives under tensorflow/contrib/lite/examples/android.

    Build the project and let Gradle download the required dependencies. With a good network connection the packages download automatically and the build succeeds. Test the generated APK, which does generic object detection by default; to get face detection, all we need to do is swap in our own model.

    The demo's detection model is TensorFlow's quantized SSD-MobileNet, trained on the COCO dataset. We can use it as the starting point for transfer learning to obtain a face detection model.

    The pretrained model can be downloaded from the TensorFlow Object Detection model zoo.

    For training data we can use the WIDER FACE dataset; after downloading it, convert the images and annotations into TFRecord format for training, along the lines of the sketch below.
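
    A minimal conversion sketch for TF 1.x, assuming the standard tf.Example feature keys expected by the Object Detection API; parsing of the WIDER FACE annotation file (wider_face_train_bbx_gt.txt) is elided and the output path is a placeholder:

    import tensorflow as tf

    def create_example(img_path, boxes, width, height):
        # boxes: list of (xmin, ymin, xmax, ymax) in pixels for one image.
        with tf.gfile.GFile(img_path, 'rb') as f:
            encoded_jpg = f.read()
        feature = {
            'image/encoded': tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[encoded_jpg])),
            'image/format': tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[b'jpeg'])),
            'image/height': tf.train.Feature(
                int64_list=tf.train.Int64List(value=[height])),
            'image/width': tf.train.Feature(
                int64_list=tf.train.Int64List(value=[width])),
            # Box corners are stored normalized to [0, 1].
            'image/object/bbox/xmin': tf.train.Feature(
                float_list=tf.train.FloatList(value=[b[0] / float(width) for b in boxes])),
            'image/object/bbox/ymin': tf.train.Feature(
                float_list=tf.train.FloatList(value=[b[1] / float(height) for b in boxes])),
            'image/object/bbox/xmax': tf.train.Feature(
                float_list=tf.train.FloatList(value=[b[2] / float(width) for b in boxes])),
            'image/object/bbox/ymax': tf.train.Feature(
                float_list=tf.train.FloatList(value=[b[3] / float(height) for b in boxes])),
            # A single class: "face", label id 1.
            'image/object/class/text': tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[b'face'] * len(boxes))),
            'image/object/class/label': tf.train.Feature(
                int64_list=tf.train.Int64List(value=[1] * len(boxes))),
        }
        return tf.train.Example(features=tf.train.Features(feature=feature))

    writer = tf.python_io.TFRecordWriter('/home/kai/tensorflow/face/train.record')
    # for img_path, boxes, w, h in parse_wider_annotations(...):  # parsing elided
    #     writer.write(create_example(img_path, boxes, w, h).SerializeToString())
    writer.close()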

    In the research directory, find the config file for this model: object_detection/samples/configs/ssd_mobilenet_v1_0.75_depth_quantized_300x300_pets_sync.config

    Edit the TFRecord, label map, and fine-tune checkpoint paths inside it before training.
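
    The fields to touch look roughly like this (a sketch; the paths are placeholders for this walkthrough, and num_classes drops to 1 since we only detect faces):

    model {
      ssd {
        num_classes: 1
        ...
      }
    }
    train_config {
      fine_tune_checkpoint: "/home/kai/tensorflow/face/model.ckpt"
      ...
    }
    train_input_reader {
      tf_record_input_reader {
        input_path: "/home/kai/tensorflow/face/train.record"
      }
      label_map_path: "/home/kai/tensorflow/face/face_label_map.pbtxt"
    }

    The label map itself (face_label_map.pbtxt, a placeholder name) declares the single class:

    item {
      id: 1
      name: 'face'
    }

    Then start training: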

    python train.py \
            --logtostderr \
            --train_dir=/home/kai/tensorflow/face/ \
            --pipeline_config_path=/home/kai/tensorflow/face/ssd_mobilenet_v1_0.75_depth_quantized_300x300_pets_sync.config

    Stop training once the loss is low enough; this leaves us with a checkpoint. We now have a model that can detect faces, but running it on a phone with TensorFlow Lite still takes some extra work.

    TensorFlow Lite is Google's lightweight deep learning framework for mobile devices. It uses quantized kernels and related techniques to make models smaller and faster, and therefore better suited to mobile use.

    First, convert the checkpoint into a frozen pb graph that TensorFlow Lite can consume:

    python object_detection/export_tflite_ssd_graph.py \
    --pipeline_config_path=$CONFIG_FILE \
    --trained_checkpoint_prefix=$CHECKPOINT_PATH \
    --output_directory=$OUTPUT_DIR \
    --add_postprocessing_op=true

    Make sure you use export_tflite_ssd_graph rather than export_inference_graph; the former appends the TFLite_Detection_PostProcess custom op, and without it the resulting pb cannot be converted in the next step.

    Once you have tflite_graph.pb, use TOCO to convert the pb model into a .tflite model.

    Run the following from the TensorFlow source directory:

    bazel run -c opt tensorflow/contrib/lite/toco:toco -- \
    --input_file=$OUTPUT_DIR/tflite_graph.pb \
    --output_file=$OUTPUT_DIR/detect.tflite \
    --input_shapes=1,300,300,3 \
    --input_arrays=normalized_input_image_tensor \
    --output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \
    --inference_type=QUANTIZED_UINT8 \
    --mean_values=128 \
    --std_values=128 \
    --change_concat_input_ranges=false \
    --allow_custom_ops

    If no errors are reported, a detect.tflite file is produced in OUTPUT_DIR; this is the TFLite model.
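
    Before wiring the model into the app, it is worth a quick sanity check from Python (a minimal sketch; in the TF 1.x releases this post targets, the interpreter lives under tf.contrib.lite, in later releases under tf.lite):

    import numpy as np
    import tensorflow as tf

    # Load the converted model and allocate its tensors.
    interpreter = tf.contrib.lite.Interpreter(model_path="detect.tflite")
    interpreter.allocate_tensors()

    # Expect one 1x300x300x3 uint8 input and four outputs
    # (boxes, classes, scores, number of detections).
    print(interpreter.get_input_details())
    print(interpreter.get_output_details())

    # Run inference once on a dummy quantized frame.
    dummy = np.zeros((1, 300, 300, 3), dtype=np.uint8)
    interpreter.set_tensor(interpreter.get_input_details()[0]['index'], dummy)
    interpreter.invoke()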

    Add the model to the TensorFlow Lite demo.
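
    Copy detect.tflite into the demo's assets directory (renamed to facedetect.tflite here), together with a label file face.txt. Following the demo's label-list convention, where the first line is a "???" placeholder for the background class, face.txt contains just:

    ???
    face

    Then modify DetectorActivity.java so that it loads both the face model and the original COCO model: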

    /*
     * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
     *
     * Licensed under the Apache License, Version 2.0 (the "License");
     * you may not use this file except in compliance with the License.
     * You may obtain a copy of the License at
     *
     *       http://www.apache.org/licenses/LICENSE-2.0
     *
     * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     * See the License for the specific language governing permissions and
     * limitations under the License.
     */
    
    package org.tensorflow.demo;
    
    import android.graphics.Bitmap;
    import android.graphics.Bitmap.Config;
    import android.graphics.Canvas;
    import android.graphics.Color;
    import android.graphics.Matrix;
    import android.graphics.Paint;
    import android.graphics.Paint.Style;
    import android.graphics.RectF;
    import android.graphics.Typeface;
    import android.media.ImageReader.OnImageAvailableListener;
    import android.os.SystemClock;
    import android.util.Size;
    import android.util.TypedValue;
    import android.widget.Toast;
    import java.io.IOException;
    import java.util.LinkedList;
    import java.util.List;
    import java.util.Vector;
    import org.tensorflow.demo.OverlayView.DrawCallback;
    import org.tensorflow.demo.env.BorderedText;
    import org.tensorflow.demo.env.ImageUtils;
    import org.tensorflow.demo.env.Logger;
    import org.tensorflow.demo.tracking.MultiBoxTracker;
    import org.tensorflow.lite.demo.R; // Explicit import needed for internal Google builds.
    
    /**
     * An activity that uses a TensorFlowMultiBoxDetector and ObjectTracker to detect and then track
     * objects.
     */
    public class DetectorActivity extends CameraActivity implements OnImageAvailableListener {
      private static final Logger LOGGER = new Logger();
    
      // Configuration values for the prepackaged SSD face model.
      private static final int TF_OD_API_INPUT_SIZE = 300;
      private static final boolean TF_OD_API_IS_QUANTIZED = true;
      private static final String TF_OD_API_MODEL_FILE = "facedetect.tflite";
      private static final String TF_OD_API_LABELS_FILE = "file:///android_asset/face.txt";
    
    
      // Configuration values for the prepackaged SSD Normal model.
      private static final int TF_OD_API_INPUT_SIZE_N = 300;
      private static final boolean TF_OD_API_IS_QUANTIZED_N = true;
      private static final String TF_OD_API_MODEL_FILE_N = "detect.tflite";
      private static final String TF_OD_API_LABELS_FILE_N = "file:///android_asset/coco_labels_list.txt";
      
      // Which detection model to use: by default uses Tensorflow Object Detection API frozen
      // checkpoints.
      private enum DetectorMode {
        TF_OD_API;
      }
    
      private static final DetectorMode MODE = DetectorMode.TF_OD_API;
    
      // Minimum detection confidence to track a detection.
      private static final float MINIMUM_CONFIDENCE_TF_OD_API = 0.4f;
    
      private static final boolean MAINTAIN_ASPECT = false;
    
      private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480);
    
      private static final boolean SAVE_PREVIEW_BITMAP = false;
      private static final float TEXT_SIZE_DIP = 10;
    
      private Integer sensorOrientation;
    
      // face detector
      private Classifier detector;
      // object detector
      private Classifier detector_n;
    
      private long lastProcessingTimeMs;
      private Bitmap rgbFrameBitmap = null;
      private Bitmap croppedBitmap = null;
      private Bitmap cropCopyBitmap = null;
    
      private boolean computingDetection = false;
    
      private long timestamp = 0;
    
      private Matrix frameToCropTransform;
      private Matrix cropToFrameTransform;
    
      //tracker
      private MultiBoxTracker tracker;
    
      private byte[] luminanceCopy;
    
      private BorderedText borderedText;
      @Override
      public void onPreviewSizeChosen(final Size size, final int rotation) {
        final float textSizePx =
            TypedValue.applyDimension(
                TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics());
        borderedText = new BorderedText(textSizePx);
        borderedText.setTypeface(Typeface.MONOSPACE);
    
    
        tracker = new MultiBoxTracker(this);
    
        int cropSize = TF_OD_API_INPUT_SIZE;
    
        // face detector
        try {
          detector =
              TFLiteObjectDetectionAPIModel.create(
                  getAssets(),
                  TF_OD_API_MODEL_FILE,
                  TF_OD_API_LABELS_FILE,
                  TF_OD_API_INPUT_SIZE,
                  TF_OD_API_IS_QUANTIZED);
          cropSize = TF_OD_API_INPUT_SIZE;
        } catch (final IOException e) {
          LOGGER.e("Exception initializing classifier!", e);
          Toast toast =
              Toast.makeText(
                  getApplicationContext(), "Classifier could not be initialized", Toast.LENGTH_SHORT);
          toast.show();
          finish();
        }
    
        // Normal object detector
        try {
          detector_n =
                  TFLiteObjectDetectionAPIModel.create(
                          getAssets(),
                          TF_OD_API_MODEL_FILE_N,
                          TF_OD_API_LABELS_FILE_N,
                          TF_OD_API_INPUT_SIZE_N,
                          TF_OD_API_IS_QUANTIZED_N);
          cropSize = TF_OD_API_INPUT_SIZE_N;
        } catch (final IOException e) {
          LOGGER.e("Exception initializing classifier!", e);
          Toast toast =
                  Toast.makeText(
                          getApplicationContext(), "Classifier could not be initialized", Toast.LENGTH_SHORT);
          toast.show();
          finish();
        }
    
    
    
        previewWidth = size.getWidth();
        previewHeight = size.getHeight();
    
        sensorOrientation = rotation - getScreenOrientation();
        LOGGER.i("Camera orientation relative to screen canvas: %d", sensorOrientation);
    
        LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight);
        rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
        croppedBitmap = Bitmap.createBitmap(cropSize, cropSize, Config.ARGB_8888);
    
        frameToCropTransform =
            ImageUtils.getTransformationMatrix(
                previewWidth, previewHeight,
                cropSize, cropSize,
                sensorOrientation, MAINTAIN_ASPECT);
    
        cropToFrameTransform = new Matrix();
        frameToCropTransform.invert(cropToFrameTransform);
    
        trackingOverlay = (OverlayView) findViewById(R.id.tracking_overlay);
        trackingOverlay.addCallback(
            new DrawCallback() {
              @Override
              public void drawCallback(final Canvas canvas) {
                tracker.draw(canvas);
                if (isDebug()) {
                  //tracker.drawDebug(canvas);
                }
              }
            });
    
        addCallback(
            new DrawCallback() {
              @Override
              public void drawCallback(final Canvas canvas) {
                if (!isDebug()) {
                  return;
                }
                final Bitmap copy = cropCopyBitmap;
                if (copy == null) {
                  return;
                }
    
                final int backgroundColor = Color.argb(100, 0, 0, 0);
                canvas.drawColor(backgroundColor);
    
                final Matrix matrix = new Matrix();
                final float scaleFactor = 2;
                matrix.postScale(scaleFactor, scaleFactor);
                matrix.postTranslate(
                    canvas.getWidth() - copy.getWidth() * scaleFactor,
                    canvas.getHeight() - copy.getHeight() * scaleFactor);
                canvas.drawBitmap(copy, matrix, new Paint());
    
                final Vector<String> lines = new Vector<String>();
                if (detector_n != null) {
                  final String statString = detector_n.getStatString();
                  final String[] statLines = statString.split("\n");
                  for (final String line : statLines) {
                    lines.add(line);
                  }
                }
                lines.add("");
    
                lines.add("Frame: " + previewWidth + "x" + previewHeight);
                lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight());
                lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight());
                lines.add("Rotation: " + sensorOrientation);
                lines.add("Inference time: " + lastProcessingTimeMs + "ms");
    
                borderedText.drawLines(canvas, 10, canvas.getHeight() - 10, lines);
              }
            });
      }
    
      OverlayView trackingOverlay;
    
      @Override
      protected void processImage() {
        ++timestamp;
        final long currTimestamp = timestamp;
        byte[] originalLuminance = getLuminance();
        tracker.onFrame(
            previewWidth,
            previewHeight,
            getLuminanceStride(),
            sensorOrientation,
            originalLuminance,
            timestamp);
    
    
        trackingOverlay.postInvalidate();
    
        // No mutex needed as this method is not reentrant.
        if (computingDetection) {
          readyForNextImage();
          return;
        }
        computingDetection = true;
        LOGGER.i("Preparing image " + currTimestamp + " for detection in bg thread.");
    
        rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight);
    
        if (luminanceCopy == null) {
          luminanceCopy = new byte[originalLuminance.length];
        }
        System.arraycopy(originalLuminance, 0, luminanceCopy, 0, originalLuminance.length);
        readyForNextImage();
    
        final Canvas canvas = new Canvas(croppedBitmap);
        canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);
        // For examining the actual TF input.
        if (SAVE_PREVIEW_BITMAP) {
          ImageUtils.saveBitmap(croppedBitmap);
        }
    
        runInBackground(
            new Runnable() {
              @Override
              public void run() {
                LOGGER.i("Running detection on image " + currTimestamp);
                final long startTime = SystemClock.uptimeMillis();
                final List<Classifier.Recognition> results_n = detector_n.recognizeImage(croppedBitmap);
                final List<Classifier.Recognition> results = detector.recognizeImage(croppedBitmap);
                lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;
    
                cropCopyBitmap = Bitmap.createBitmap(croppedBitmap);
                final Canvas canvas = new Canvas(cropCopyBitmap);
                final Paint paint = new Paint();
                paint.setColor(Color.RED);
                paint.setStyle(Style.STROKE);
                paint.setStrokeWidth(2.0f);
    
                float minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API;
                switch (MODE) {
                  case TF_OD_API:
                    minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API;
                    break;
                }
    
                final List<Classifier.Recognition> mappedRecognitions =
                    new LinkedList<Classifier.Recognition>();
    
                final List<Classifier.Recognition> mappedRecognitions_n =
                    new LinkedList<Classifier.Recognition>();
    
                boolean faceflag = true;
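                // faceflag stays true only if no face clears the confidence
                // threshold; in that case we fall back to the generic COCO
                // detector below.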
    
                for (final Classifier.Recognition result : results) {
                  final RectF location = result.getLocation();
                  if (location != null && result.getConfidence() >= minimumConfidence) {
                    //canvas.drawRect(location, paint);
                    faceflag = false;
                    cropToFrameTransform.mapRect(location);
                    result.setLocation(location);
                    mappedRecognitions.add(result);
                  }
                }
    
                if(faceflag)
                {
                  for (final Classifier.Recognition result_n : results_n) {
                    final RectF location = result_n.getLocation();
                    if (location != null && result_n.getConfidence() >= minimumConfidence && result_n.getTitle().equals("oven") ) {
                      //canvas.drawRect(location, paint);
    
                      cropToFrameTransform.mapRect(location);
                      result_n.setLocation(location);
                      mappedRecognitions_n.add(result_n);
                    }
                  }
                }
    
    
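                // Hand both result sets (faces, plus any fallback detections)
                // to the tracker and refresh the overlay.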
                tracker.trackResults(mappedRecognitions, luminanceCopy, currTimestamp);
                tracker.trackResults(mappedRecognitions_n, luminanceCopy, currTimestamp);
                trackingOverlay.postInvalidate();
    
                requestRender();
                computingDetection = false;
              }
            });
      }
    
      @Override
      protected int getLayoutId() {
        return R.layout.camera_connection_fragment_tracking;
      }
    
      @Override
      protected Size getDesiredPreviewFrameSize() {
        return DESIRED_PREVIEW_SIZE;
      }
    
      @Override
      public void onSetDebug(final boolean debug) {
        detector.enableStatLogging(debug);
        detector_n.enableStatLogging(debug);
      }
    }