zoukankan      html  css  js  c++  java
  • tensorflow c++ demo

    使用 Bazel 编译了 TensorFlow 1.13.1 之后,还差一个 demo 来测试。在网上找到一个例程,但代码不完整,我自己花了些功夫把它补全,分享出来供大家参考学习。

    #define COMPILER_MSVC
    #define NOMINMAX
    #define PLATFORM_WINDOWS   // 指定使用tensorflow/core/platform/windows/cpu_info.h
    
    #include <string.h>
    #include <time.h>

    #include <algorithm>
    #include <chrono>
    #include <iostream>
    #include <memory>
    #include <string>
    #include <utility>
    #include <vector>

    #include <opencv2/opencv.hpp>
    #include "tensorflow/core/platform/env.h"
    #include "tensorflow/core/public/session.h"

    using namespace tensorflow;
    using namespace cv;
    using std::cout;
    using std::endl;
    
    int main() {
        const std::string model_path = "frozen_inference_graph.pb";// tensorflow模型文件,注意不能含有中文
        const std::string image_path = "image1.jpg";    // 待inference的图片grace_hopper.jpg
    
                                                        // 设置输入图像
        cv::Mat img = cv::imread(image_path);
        //cv::cvtColor(img, img, cv::COLOR_BGR2RGB);
        int height = img.rows;
        int width = img.cols;
        int depth = img.channels();
    
        // 取图像数据,赋给tensorflow支持的Tensor变量中
        tensorflow::Tensor input_tensor(DT_UINT8, TensorShape({ 1, height, width, depth }));
        const uint8* source_data = img.data;
        auto input_tensor_mapped = input_tensor.tensor<uint8, 4>();
    
        for (int i = 0; i < height; i++) {
            const uint8* source_row = source_data + (i * width * depth);
            for (int j = 0; j < width; j++) {
                const uint8* source_pixel = source_row + (j * depth);
                for (int c = 0; c < depth; c++) {
                    const uint8* source_value = source_pixel + c;
                    input_tensor_mapped(0, i, j, c) = *source_value;
                }
            }
        }
    
        // 初始化tensorflow session
        Session* session;
        Status status = NewSession(SessionOptions(), &session);
        if (!status.ok()) {
            std::cerr << status.ToString() << endl;
            return -1;
        }
        else {
            cout << "Session created successfully" << endl;
        }
    
        // 读取二进制的模型文件到graph中
        tensorflow::GraphDef graph_def;
        status = ReadBinaryProto(Env::Default(), model_path, &graph_def);
        if (!status.ok()) {
            std::cerr << status.ToString() << endl;
            return -1;
        }
        else {
            cout << "Load graph protobuf successfully" << endl;
        }
    
        // 将graph加载到session
        status = session->Create(graph_def);
        if (!status.ok()) {
            std::cerr << status.ToString() << endl;
            return -1;
        }
        else {
            cout << "Add graph to session successfully" << endl;
        }
        // 输入inputs,“ x_input”是我在模型中定义的输入数据名称
        std::vector<std::pair<std::string, tensorflow::Tensor>> inputs = {
            { "image_tensor:0", input_tensor },
        };
    
        // 输出outputs
        std::vector<tensorflow::Tensor> outputs;
    
        //批处理识别
        double start = clock();
        std::vector<std::string> output_nodes;
        output_nodes.push_back("num_detections");
        output_nodes.push_back("detection_boxes");
        output_nodes.push_back("detection_scores");
        output_nodes.push_back("detection_classes");
        // 运行会话,最终结果保存在outputs中
        status = session->Run(inputs, { output_nodes }, {}, &outputs);
        if (!status.ok()) {
            std::cerr << status.ToString() << endl;
            return -1;
        }
        else {
            cout << "Run session successfully" << endl;
        }
        std::vector<float> vecfldata;
        std::vector<float> vecflprob;
        for (int i = 0; i < outputs.size(); i++)
        {
            Tensor t = outputs[i]; // 从节点取出第一个输出 "node:0"        
            cout << t.dtype() << std::endl;
            TensorShape shape = t.shape();
            int dim = shape.dims();
            cout << dim << std::endl;
            cout << shape.num_elements() << std::endl;
            std::vector<int> vecsize;
            for (int d = 0; d < shape.dims(); d++)
            {
                int size = shape.dim_size(d);
                cout << size << endl;
                vecsize.push_back(size);
            }        
            if (dim == 3)
            {
                auto tmap = t.tensor<float, 3>();//这里<float, 3>的3是根据dim=3来的
                for (int l = 0; l < vecsize[0]; l++)
                {
                    for (int m =0; m < vecsize[1]; m++)
                    {
                        for (int n =0; n<vecsize[2];n++)
                        {
                            vecfldata.push_back(tmap(l,m,n));
                        }
                    }
                }        
                
            }
            if (i==2)
            {
                auto tmap = t.tensor<float, 2>();
                for (int p = 0; p < shape.dim_size(1); p++)
                {
                    vecflprob.push_back(tmap( 0, p));
                }
            }
    
        }
        for (int k = 0; k < 2; k++)
        {
            int lty = height*vecfldata[4 * k];
            int ltx = width*vecfldata[4 * k + 1];
            int rby = height*vecfldata[4 * k + 2];
            int rbx = width*vecfldata[4 * k + 3];
            cv::rectangle(img, cv::Point(ltx, lty), cv::Point(rbx, rby), Scalar(255, 0, 0));
            cv::putText(img, std::to_string(vecflprob[k]), Point(ltx, lty), FONT_HERSHEY_SCRIPT_SIMPLEX, 1.0, Scalar(12, 255, 200), 1, 8);
        }
        std::cout << "outputs[0] num_detections" << outputs[0].DebugString() << std::endl;
        std::cout << "outputs[1] detection_boxes" << outputs[1].DebugString() << std::endl;
        std::cout << "outputs[2] detection_scores" << outputs[2].DebugString() << std::endl;
        std::cout << "outputs[3] detection_classes" << outputs[3].DebugString() << std::endl;
    
        double    finish = clock();
        double duration = (double)(finish - start) / CLOCKS_PER_SEC;
        cout << "spend time:" << duration << endl;
        cv::imshow("image", img);
        cv::waitKey();
        return 0;
    }

    运行结果:

  • 相关阅读:
    业务场景和业务用例场景的区别(作者:Arthur网友)
    svn 安装
    PHP has encountered an Access Violation at
    邀请大象一书的读者和广大网友写关于分析设计、建模方面的自愿者文章
    手机网页 复制信息方法 免费短信
    delphi Inno Setup 制作安装程序
    Special Folders
    Windows mobile上获取输入光标位置
    加壳程序无法准确读输入表的解决办法
    C++ PostMessage 模拟键盘鼠标
  • 原文地址:https://www.cnblogs.com/juluwangshier/p/13280965.html
Copyright © 2011-2022 走看看