zoukankan      html  css  js  c++  java
  • tensorflow C++接口调用目标检测pb模型代码

    #include <cstdint>
    #include <iostream>
    #include <memory>
     
    #include "tensorflow/cc/ops/const_op.h"
    #include "tensorflow/cc/ops/image_ops.h"
    #include "tensorflow/cc/ops/standard_ops.h"
    #include "tensorflow/core/framework/graph.pb.h"
    #include "tensorflow/core/framework/tensor.h"
    #include "tensorflow/core/graph/default_device.h"
    #include "tensorflow/core/graph/graph_def_builder.h"
    #include "tensorflow/core/lib/core/errors.h"
    #include "tensorflow/core/lib/core/stringpiece.h"
    #include "tensorflow/core/lib/core/threadpool.h"
    #include "tensorflow/core/lib/io/path.h"
    #include "tensorflow/core/lib/strings/stringprintf.h"
    #include "tensorflow/core/platform/env.h"
    #include "tensorflow/core/platform/init_main.h"
    #include "tensorflow/core/platform/logging.h"
    #include "tensorflow/core/platform/types.h"
    #include "tensorflow/core/public/session.h"
    #include "tensorflow/core/util/command_line_flags.h"
     
    #include <opencv2/opencv.hpp>
    #include <cv.h>
    #include <highgui.h>
    #include <Eigen/Core>
    #include <Eigen/Dense>
     
    using namespace std;
    using namespace cv;
    using namespace tensorflow;
     
     
     
    // Copy an OpenCV image into a pre-allocated uint8 TensorFlow tensor.
    //
    // In Python, the array returned by cv2.imread only needs an np.reshape before it
    // can be fed to the network, because a tensor there is just an ndarray. In C++
    // the network input must be a tensorflow::Tensor, so an image read with OpenCV
    // (a cv::Mat) has to be copied into the tensor's buffer explicitly.
    //
    // img:           source image (main() converts it to RGB before calling). A
    //                1-channel image is expanded to 3 channels, a 4-channel image has
    //                its last channel dropped, and any depth is converted to 8 bit.
    // output_tensor: destination; must be DT_UINT8 and hold exactly
    //                input_rows * input_cols * 3 elements (main() uses shape
    //                {1, input_rows, input_cols, 3}). A size mismatch aborts via CHECK.
    // input_rows, input_cols: size the image is resized to before the copy.
    void CVMat_to_Tensor(Mat img,Tensor* output_tensor,int input_rows,int input_cols)
    {
        CHECK(output_tensor != nullptr) << "CVMat_to_Tensor: output_tensor is null";
        CHECK(!img.empty()) << "CVMat_to_Tensor: input image is empty";
        CHECK_GT(input_rows, 0);
        CHECK_GT(input_cols, 0);
        // tempMat below writes rows * cols * 3 bytes straight into the tensor
        // buffer, so a tensor of any other size would be overflowed or under-filled.
        CHECK_EQ(static_cast<int64_t>(output_tensor->NumElements()),
                 static_cast<int64_t>(input_rows) * input_cols * 3)
            << "CVMat_to_Tensor: tensor does not hold a " << input_rows << "x"
            << input_cols << "x3 image";

        // The tensor expects exactly 3 channels; normalize other layouts first.
        if (img.channels() == 1) {
            cvtColor(img, img, cv::COLOR_GRAY2RGB);
        } else if (img.channels() == 4) {
            cvtColor(img, img, cv::COLOR_RGBA2RGB);
        }
        CHECK_EQ(img.channels(), 3) << "CVMat_to_Tensor: unsupported channel count";

        // Resize to the network input size.
        cv::Mat resized;
        resize(img, resized, cv::Size(input_cols, input_rows));

        // Wrap the tensor's memory in a Mat header (no copy): writing into tempMat
        // writes directly into the tensor.
        uint8 *p = output_tensor->flat<uint8>().data();
        cv::Mat tempMat(input_rows, input_cols, CV_8UC3, p);

        // convertTo keeps the channel count and only changes the depth. Because
        // tempMat already has the destination size and type it is not reallocated,
        // so the pixels land in the tensor buffer; verify that explicitly.
        resized.convertTo(tempMat, CV_8U);
        CHECK(tempMat.data == p) << "CVMat_to_Tensor: Mat was reallocated; tensor not filled";
    }
     
    int main()
    {
        /*--------------------------------配置关键信息------------------------------*/
        string model_path="../model/coco.pb";
        string image_path="../test.jpg";
        int input_height = 1000;
        int input_width = 1000;
        string input_tensor_name="image_tensor";
        vector<string> out_put_nodes;  //注意,在object detection中输出的三个节点名称为以下三个
        out_put_nodes.push_back("detection_scores");  //detection_scores  detection_classes  detection_boxes
        out_put_nodes.push_back("detection_classes");
        out_put_nodes.push_back("detection_boxes");
     
        /*--------------------------------创建session------------------------------*/
        Session* session;
        Status status = NewSession(SessionOptions(), &session);//创建新会话Session
     
        /*--------------------------------从pb文件中读取模型--------------------------------*/
        GraphDef graphdef; //Graph Definition for current model
     
        Status status_load = ReadBinaryProto(Env::Default(), model_path, &graphdef); //从pb文件中读取图模型;
        if (!status_load.ok()) {
            cout << "ERROR: Loading model failed..." << model_path << std::endl;
            cout << status_load.ToString() << "
    ";
            return -1;
        }
        Status status_create = session->Create(graphdef); //将模型导入会话Session中;
        if (!status_create.ok()) {
            cout << "ERROR: Creating graph in session failed..." << status_create.ToString() << std::endl;
            return -1;
        }
        cout << "<----Successfully created session and load graph.------->"<< endl;
     
        /*---------------------------------载入测试图片-------------------------------------*/
        cout<<endl<<"<------------loading test_image-------------->"<<endl;
        Mat img;
        img = imread(image_path);
        cvtColor(img, img, CV_BGR2RGB);
        if(img.empty())
        {
            cout<<"can't open the image!!!!!!!"<<endl;
            return -1;
        }
     
        //创建一个tensor作为输入网络的接口
        Tensor resized_tensor(DT_UINT8, TensorShape({1,input_height,input_width,3})); //DT_FLOAT
     
        //将Opencv的Mat格式的图片存入tensor
        CVMat_to_Tensor(img,&resized_tensor,input_height,input_width);
     
        cout << resized_tensor.DebugString()<<endl;
     
        /*-----------------------------------用网络进行测试-----------------------------------------*/
        cout<<endl<<"<-------------Running the model with test_image--------------->"<<endl;
        //前向运行,输出结果一定是一个tensor的vector
        vector<tensorflow::Tensor> outputs;
     
        Status status_run = session->Run({{input_tensor_name, resized_tensor}}, {out_put_nodes}, {}, &outputs);
     
        if (!status_run.ok()) {
            cout << "ERROR: RUN failed..."  << std::endl;
            cout << status_run.ToString() << "
    ";
            return -1;
        }
     
        //把输出值给提取出
        cout << "Output tensor size:" << outputs.size() << std::endl;  //3
        for (int i = 0; i < outputs.size(); i++)
        {
            cout << outputs[i].DebugString()<<endl;   // [1, 50], [1, 50], [1, 50, 4]
        }
     
        cvtColor(img, img, CV_RGB2BGR);  // opencv读入的是BGR格式输入网络前转为RGB
        resize(img,img,cv::Size(1000,1000));  // 模型输入图像大小
        int pre_num = outputs[0].dim_size(1);  // 50  模型预测的目标数量
        auto tmap_pro = outputs[0].tensor<float, 2>();  //第一个是score输出shape为[1,50]
        auto tmap_clas = outputs[1].tensor<float, 2>();  //第二个是class输出shape为[1,50]
        auto tmap_coor = outputs[2].tensor<float, 3>();  //第三个是coordinate输出shape为[1,50,4]
        float probability = 0.5;  //自己设定的score阈值
        for (int pre_i = 0; pre_i < pre_num; pre_i++)
        {
            if (tmap_pro(0, pre_i) < probability)
            {
                break;
            }
            cout << "Class ID: " << tmap_clas(0, pre_i) << endl;
            cout << "Probability: " << tmap_pro(0, pre_i) << endl;
            string id = to_string(int(tmap_clas(0, pre_i)));
            int xmin = int(tmap_coor(0, pre_i, 1) * input_width);
            int ymin = int(tmap_coor(0, pre_i, 0) * input_height);
            int xmax = int(tmap_coor(0, pre_i, 3) * input_width);
            int ymax = int(tmap_coor(0, pre_i, 2) * input_height);
            cout << "Xmin is: " << xmin << endl;
            cout << "Ymin is: " << ymin << endl;
            cout << "Xmax is: " << xmax << endl;
            cout << "Ymax is: " << ymax << endl;
            rectangle(img, cvPoint(xmin, ymin), cvPoint(xmax, ymax), Scalar(255, 0, 0), 1, 1, 0);
            putText(img, id, cvPoint(xmin, ymin), FONT_HERSHEY_COMPLEX, 1.0, Scalar(255,0,0), 1);
        }
        imshow("1", img);
        cvWaitKey(0);
     
        return 0;
    }

    CMakeLists.txt内容如下

    cmake_minimum_required(VERSION 3.0.0)
    project(tensorflow_cpp)
     
    set(CMAKE_CXX_STANDARD 11)
     
    # Prefer OpenCV 3.x and fall back to OpenCV 2.4.3 or newer.
    find_package(OpenCV 3.0 QUIET)
    if(NOT OpenCV_FOUND)
        find_package(OpenCV 2.4.3 QUIET)
        if(NOT OpenCV_FOUND)
            message(FATAL_ERROR "OpenCV > 2.4.3 not found.")
        endif()
    endif()
     
    # TensorFlow headers; the bazel-genfiles entry suggests a copied TensorFlow
    # source build — adjust these paths to match the local install.
    set(TENSORFLOW_INCLUDES
            /usr/local/include/tf/
            /usr/local/include/tf/bazel-genfiles
            /usr/local/include/tf/tensorflow/
            /usr/local/include/tf/tensorflow/third_party)
     
    set(TENSORFLOW_LIBS
            /usr/local/lib/libtensorflow_cc.so
            /usr/local/lib/libtensorflow_framework.so)
     
    # The C++ source includes <opencv2/opencv.hpp> as well as the legacy <cv.h> and
    # <highgui.h>, so the include directories reported by find_package(OpenCV)
    # must be on the include path alongside TensorFlow and Eigen.
    include_directories(
            ${TENSORFLOW_INCLUDES}
            ${PROJECT_SOURCE_DIR}/third_party/eigen3
            ${OpenCV_INCLUDE_DIRS}
    )
    add_executable(predict predict.cpp)
    target_link_libraries(predict
            ${TENSORFLOW_LIBS}
            ${OpenCV_LIBS}
            )

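    目录结构如图所示(原文中的目录结构图未保留。根据代码中的相对路径推断:可执行文件从工程根目录下的某个子目录(如 build)运行,模型位于 ../model/coco.pb,测试图片位于 ../test.jpg,Eigen 头文件位于工程目录下的 third_party/eigen3。)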

  • 相关阅读:
    1024X768大图 (Wallpaper)
    (Mike Lynch)Application of linear weight neural networks to recognition of hand print characters
    瞬间模糊搜索1000万基本句型的语言算法
    单核与双核的竞争 INTEL P4 670对抗820
    FlashFTP工具的自动缓存服务器目录的功能
    LDAP over SSL (LDAPS) Certificate
    Restart the domain controller in Directory Services Restore Mode Remotely
    How do I install Active Directory on my Windows Server 2003 server?
    指针与指针变量(转)
    How to enable LDAP over SSL with a thirdparty certification authority
  • 原文地址:https://www.cnblogs.com/cnugis/p/11506767.html
Copyright © 2011-2022 走看看