zoukankan      html  css  js  c++  java
  • 使用 TensorFlow C++ API 加载 .pb 模型文件并预测图片

    用 TensorFlow Python API 创建并训练模型，导出 .pb 模型文件后，再用 C++ API 加载该模型进行预测。

      1 #include <iostream>
      2 #include <map>
      3 
      4 #include "tensorflow/cc/ops/image_ops.h"
      5 #include "tensorflow/cc/ops/standard_ops.h"
      6 #include "tensorflow/core/framework/graph.pb.h"
      7 #include "tensorflow/core/framework/tensor.h"
      8 #include "tensorflow/core/graph/default_device.h"
      9 #include "tensorflow/core/graph/graph_def_builder.h"
     10 #include "tensorflow/core/platform/logging.h"
     11 #include "tensorflow/core/platform/types.h"
     12 #include "tensorflow/core/public/session.h"
     13 
     14 using namespace std ;
     15 using namespace tensorflow;
     16 using tensorflow::Tensor;
     17 using tensorflow::Status;
     18 using tensorflow::string;
     19 using tensorflow::int32;
     20 
     21 
     22 //从文件名中读取数据
     23 Status ReadTensorFromImageFile(string file_name, const int input_height,
     24                                const int input_width,
     25                                vector<Tensor>* out_tensors) {
     26     auto root = Scope::NewRootScope();
     27     using namespace ops;
     28 
     29     auto file_reader = ops::ReadFile(root.WithOpName("file_reader"),file_name);
     30     const int wanted_channels = 1;
     31     Output image_reader;
     32     std::size_t found = file_name.find(".png");
     33     //判断文件格式
     34     if (found!=std::string::npos) {
     35         image_reader = DecodePng(root.WithOpName("png_reader"), file_reader,DecodePng::Channels(wanted_channels));
     36     } 
     37     else {
     38         image_reader = DecodeJpeg(root.WithOpName("jpeg_reader"), file_reader,DecodeJpeg::Channels(wanted_channels));
     39     }
     40     // 下面几步是读取图片并处理
     41     auto float_caster =Cast(root.WithOpName("float_caster"), image_reader, DT_FLOAT);
     42     auto dims_expander = ExpandDims(root, float_caster, 0);
     43     auto resized = ResizeBilinear(root, dims_expander,Const(root.WithOpName("resize"), {input_height, input_width}));
     44     // Div(root.WithOpName(output_name), Sub(root, resized, {input_mean}),{input_std});
     45     Transpose(root.WithOpName("transpose"),resized,{0,2,1,3});
     46 
     47     GraphDef graph;
     48     root.ToGraphDef(&graph);
     49 
     50     unique_ptr<Session> session(NewSession(SessionOptions()));
     51     session->Create(graph);
     52     session->Run({}, {"transpose"}, {}, out_tensors);//Run,获取图片数据保存到Tensor中
     53 
     54     return Status::OK();
     55 }
     56 
     57 int main(int argc, char* argv[]) {
     58 
     59     string graph_path = "aov_crnn.pb";
     60     GraphDef graph_def;
     61     //读取模型文件
     62     if (!ReadBinaryProto(Env::Default(), graph_path, &graph_def).ok()) {
     63         cout << "Read model .pb failed"<<endl;
     64         return -1;
     65     }
     66 
     67     //新建session
     68     unique_ptr<Session> session;
     69     SessionOptions sess_opt;
     70     sess_opt.config.mutable_gpu_options()->set_allow_growth(true);
     71     (&session)->reset(NewSession(sess_opt));
     72     if (!session->Create(graph_def).ok()) {
     73         cout<<"Create graph failed"<<endl;
     74         return -1;
     75     }
     76 
     77     //读取图像到inputs中
     78     int input_height = 40;
     79     int input_width = 240;
     80     vector<Tensor> inputs;
     81     // string image_path(argv[1]);
     82     string image_path("test.jpg");
     83     if (!ReadTensorFromImageFile(image_path, input_height, input_width,&inputs).ok()) {
     84         cout<<"Read image file failed"<<endl;
     85         return -1;
     86     }
     87 
     88     vector<Tensor> outputs;
     89     string input = "inputs_sq";
     90     string output = "results_sq";//graph中的输入节点和输出节点,需要预先知道
     91 
     92     pair<string,Tensor>img(input,inputs[0]);
     93     Status status = session->Run({img},{output}, {}, &outputs);//Run,得到运行结果,存到outputs中
     94     if (!status.ok()) {
     95         cout<<"Running model failed"<<endl;
     96         cout<<status.ToString()<<endl;
     97         return -1;
     98     }
     99 
    100 
    101     //得到模型运行结果
    102     Tensor t = outputs[0];        
    103     auto tmap = t.tensor<int64, 2>(); 
    104     int output_dim = t.shape().dim_size(1); 
    105 
    106 
    107     return 0;
    108 }
    g++ -g  tf_predict.cpp -o tf_predict -I /usr/include/eigen3 -I /usr/local/include/tf  -L/usr/local/lib/ `pkg-config --cflags --libs protobuf`  -ltensorflow_cc  -ltensorflow_framework

     也可以用 OpenCV C++ 库读取图片得到 cv::Mat，再把像素逐个复制到 Tensor 中：

     1 tensorflow::Tensor readTensor(string filename){
     2     tensorflow::Tensor input_tensor(DT_FLOAT,TensorShape({1,240,40,1}));
     3     Mat src=imread(filename,0);
     4     Mat dst;
     5     resize(src,dst,Size(240,40));//resize
     6     Mat dst_transpose=dst.t();//transpose
     7 
     8     auto tmap=input_tensor.tensor<float,4>();
     9 
    10     for(int i=0;i<240;i++){//Mat复制到Tensor
    11         for(int j=0;j<40;j++){
    12             tmap(0,i,j,0)=dst_transpose.at<uchar>(i,j);
    13         }
    14     }
    15 
    16     return input_tensor;
    17 }

     也可以用共享内存（指针引用）的方式转换：让 cv::Mat 直接指向 Tensor 的数据缓冲区，由 convertTo() 把像素写入 Tensor，省去逐元素拷贝：

    1             tensorflow::Tensor input_tensor(DT_FLOAT,TensorShape({1,height,width,3}));
    2         float *tensor_data_ptr = input_tensor.flat<float>().data();              
    3         cv::Mat fake_mat(dst.rows, dst.cols, CV_32FC(src.channels()), tensor_data_ptr); 
    4         dst.convertTo(fake_mat, CV_32FC3);
  • 相关阅读:
    nodejs MYSQL数据库执行多表查询
    【BZOJ3994】[SDOI2015]约数个数和 莫比乌斯反演
    【BZOJ2693】jzptab 莫比乌斯反演
    【BZOJ2154】Crash的数字表格 莫比乌斯反演
    【BZOJ2242】[SDOI2011]计算器 BSGS
    【BZOJ2005】[Noi2010]能量采集 欧拉函数
    【BZOJ1408】[Noi2002]Robot DP+数学
    【BZOJ2045】双亲数 莫比乌斯反演
    【BZOJ2186】[Sdoi2008]沙拉公主的困惑 线性筛素数
    【BZOJ4176】Lucas的数论 莫比乌斯反演
  • 原文地址:https://www.cnblogs.com/buyizhiyou/p/10412967.html
Copyright © 2011-2022 走看看