zoukankan      html  css  js  c++  java
  • 使用C++调用pytorch模型(Linux)

    前言

    模型转换思路通常为:

    • Pytorch -> ONNX -> TensorRT
    • Pytorch -> ONNX -> TVM
    • Pytorch -> 转换工具 -> caffe
    • Pytorch -> torchscript(C++版本Torch)

    我的模型是使用Pytorch1.0训练的,第三种方法应该是还不支持,没有对应层名字, 放弃. (以下是用方法3生成的网络结构图, 其中部分层名字和工具对应不上).

    因此本文使用第4中方法,详细步骤分两步, 具体如下(目前资料少,坑很多)


    1. pytorch模型转化为libtorch的torchscript模型 (.pth -> .pt)

    首先, 在python中, 把模型转化成.pt文件

    Pytorch官方提供的C++API名为libtorch,详细查看:
    - LIBRARY API
    - USING THE PYTORCH C++ FRONTEND

    import torch
    
    # An instance of your model.
    from my_infer import BaseLine
    
    model = BaseLine().model.cpu().eval()
    
    # An example input you would normally provide to your model's forward() method.
    example = torch.rand(1, 3, 256 , 128)
    
    # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
    traced_script_module = torch.jit.trace(model, example)
    traced_script_module.save("demo/model.pt")
    

    2. 使用libtorch调用torchscript模型

    此处有一个大坑, opencv和torch可以单独使用, 但如果链接libtorch库以后, cv::imread提示未定义的应用. 所以使用了opencv2的图片读取方式, 然后再转成cv::Mat.


    更新时间:2019/05/24
    在更换libtorch版本后, cv:imread不再报错, 具体原因说不上来, 应该是之前的版本链接库时候出现矛盾什么的...


    #include <iostream>                                                                                                
    #include "torch/script.h"
    #include "torch/torch.h"
    #include "opencv2/core.hpp"
    #include "opencv2/imgproc.hpp"
    #include "opencv2/highgui.hpp"
    #include <vector>
    
    int main()
    {
        //加载pytorch模型
        std::shared_ptr<torch::jit::script::Module> module = torch::jit::load("/home/zhuoshi/ZSZT/Geoffrey/opencvTest/m
        assert(module != nullptr);
    
        // 创建一个Tensor
        //std::vector<torch::jit::IValue> inputs;
        //inputs.emplace_back(torch::ones({1, 3, 256, 128}));
        //测试前向
        //at::Tensor output = module->forward(inputs).toTensor();
        //std::cout << output;
    
        // 转换为int8类型
        //vector<int16_t> feature(2048);
        //for (int i = 0;i<128;i++)
        //{
        // 转化成Float
        //int temp = output[0][i].item().toInt();
        //    if (temp != 0){
        //        temp = 1;
        //    }
        //    feature[i] = temp;
        //}
        //std::cout << feature;
    
        //读取图片
        IplImage* pmg = cvLoadImage("/home/zhuoshi/ZSZT/Geoffrey/opencvTest/test.jpg");
        cv::Mat image(pmg, true);
        //cv::Mat imageRGB = cv::cvtColor(image, imageRGB, cv::COLOR_BGR2RGB);
        cv::cvtColor(image, image, CV_BGR2RGB);
    
        //IplImage转换成Tensor
        cv::Mat img_float;
        image.convertTo(img_float, CV_32F, 1.0 / 255);
        cv::resize(img_float, img_float, cv::Size(256, 128));
        torch::Tensor tensor_image = torch::from_blob(img_float.data, {1, 3, 256, 128}, torch::kFloat32);
    
        //前向
        std::vector<torch::jit::IValue> input;
        input.emplace_back(tensor_image);
        at::Tensor output_image = module->forward(input).toTensor();
        //std::cout << output_image;
    
        //Tensor 转 array
        std::vector<float> feature(2048);
        for (int i=0; i<2048; i++){
        //    feature[i] = output_image[i]
            std::cout << output_image[0][i].item().toFloat();
        }
        return 0;
    }  
    

    对应的CMakeLists.txt内容:

    cmake_minimum_required(VERSION 2.8)                                                                                
    
    project(opencv_example_project)
    SET(CMAKE_C_COMPILER g++)
    add_definitions(--std=c++11)
    
    # 指定libTorch位置
    set(Torch_DIR /home/zhuoshi/ZSZT/Geoffrey/opencvTest/libtorch/share/cmake/Torch)
    find_package(Torch REQUIRED)
    
    find_package(OpenCV REQUIRED)
    
    message(STATUS "OpenCV library status:")
    message(STATUS "    version: ${OpenCV_VERSION}")
    message(STATUS "    libraries: ${OpenCV_LIBS}")
    message(STATUS "    include path: ${OpenCV_INCLUDE_DIRS}")
    message(STATUS "    torch lib : ${TORCH_LIBRARIES} ")
    
    include_directories(${OpenCV_INCLUDE_DIRS}
                        /home/zhuoshi/ZSZT/Geoffrey/opencvTest/libtorch/include
                        /home/zhuoshi/ZSZT/Geoffrey/opencvTest/libtorch/include/torch/csrc/api/include/
                        )
    
    add_executable(main main.cpp)
         
    # Link your application with OpenCV libraries
    target_link_libraries(main ${OpenCV_LIBS} ${TORCH_LIBRARIES} )
    

    运行结果如图:


    更新时间: 2019/05/25, 更换libtorch版本后, cv::read可用, 这是新版本

    #include <iostream>
    #include "torch/script.h"
    #include "torch/torch.h"
    #include "opencv2/core.hpp"
    #include "opencv2/imgproc.hpp"
    #include "opencv2/highgui.hpp"
    #include "opencv2/imgcodecs.hpp"
    #include <vector>
    
    int main()
    {
        /* 配置参数 */
        std::vector <float> mean_ = {0.485, 0.456, 0.406};
        std::vector <float> std_ = {0.229, 0.224, 0.225};
        char path[] = "../test.jpg";
    
        // 读取图片
        cv::Mat image = cv::imread(path);
        if (image.empty())
            fprintf(stderr, "Can not load image
    ");
    
        // 转换通道,
        cv::cvtColor(image, image, CV_BGR2RGB);
        cv::Mat img_float;
        image.convertTo(img_float, CV_32F, 1.0 / 255);
    
        // resize, 测试一个点数据
        cv::resize(img_float, img_float, cv::Size(256, 128));
        //std::cout << img_float.at<cv::Vec3f>(256, 128)[1] << std::endl;
    
        // 转换成tensor
        auto img_tensor = torch::from_blob(img_float.data, {1, 3, 256, 128}, torch::kFloat32);
        //img_tensor = img_tensor.permute({0,3,1,2});
        // tensor标准化
        for (int i = 0; i < 3; i++) {
            img_tensor[0][0] = img_tensor[0][0].sub_(mean_[i]).div_(std_[i]);
        }
    
        // 构造input
        //auto img_var = torch::autograd::make_variable(img_tensor, false); //tensor->variable会报错
        std::vector<torch::jit::IValue> inputs;
        inputs.emplace_back(img_tensor); //向容器中加入新的元素, 右值引用
    
        //加载pytorch模型
        std::shared_ptr<torch::jit::script::Module> module = torch::jit::load("../model/model_int.pt");
        assert(module != nullptr);
    
        //前向
        at::Tensor output_image = module->forward(inputs).toTensor();
        std::cout << output_image;
    
        return 0;
    }
    
    cv::Mat convertTo3Channels(cv::Mat binImg)
    {
        cv::Mat three_channel = cv::Mat::zeros(binImg.rows, binImg.cols, CV_8UC3);
        std::vector<cv::Mat> channels;
        for (int i=0;i<3;i++)
        {
            channels.push_back(binImg);
        }
        merge(channels, three_channel);
        return three_channel;
    }
    

    对应CMakelist.txt文件:

    cmake_minimum_required(VERSION 2.8)
    
    # Define project name
    project(opencv_example_project)
    
    SET(CMAKE_C_COMPILER g++)
    add_definitions(--std=c++11)
    
    # 指定libTorch位置
    set(Torch_DIR /home/geoffrey/CLionProjects/opencvTest/libtorch/share/cmake/Torch)
    find_package(Torch REQUIRED)
    
    message(STATUS "Torch library status:")
    message(STATUS "    version: ${TORCH_VERSION}")
    message(STATUS "    libraries: ${TORCH_LIBS}")
    message(STATUS "    include path: ${TORCH_INCLUDE_DIRS}")
    message(STATUS "    torch lib : ${TORCH_LIBRARIES} ")
    
    # 指定OpenCV位置
    #set(OpenCV_DIR /run/media/geoffrey/Timbersaw/Backup/other_package/opencv-4.0.0/build)
    # set(OpenCV_DIR /opt/opencv2)
    find_package(OpenCV  REQUIRED)
    message(STATUS "OpenCV library status:")
    message(STATUS "    version: ${OpenCV_VERSION}")
    message(STATUS "    libraries: ${OpenCV_LIBS}")
    message(STATUS "    include path: ${OpenCV_INCLUDE_DIRS}")
    message(STATUS "    opencv lib : ${OpenCV_LIBRARIES} ")
    
    # 包含头文件include
    include_directories(${OpenCV_INCLUDE_DIRS} ${TORCH_INCLUDE_DIRS})
    
    # 生成的目标文件(可执行文件)
    add_executable(main main.cpp)
    
    # 置需要的库文件lib
    # set(OpenCV_LIBS opencv_core  opencv_highgui opencv_imgcodecs opencv_imgproc)
    target_link_libraries(main  ${OpenCV_LIBS} ${TORCH_LIBRARIES}) #
    
    
    


    参考资料

    1. TorchDemo
    2. 利用Pytorch的C++前端(libtorch)读取预训练权重并进行预测
    3. C++部署pytorch模型(二)————使用libtorch调用torchscripts模型

  • 相关阅读:
    Rraspberry Pi 4B python3 安装opencv
    如何用arduion制作智能 垃圾桶
    MySQL(二)表结构的管理
    MySQL(一)基础操作
    vc++绘图基础
    网站签~
    (转)Oracle 知识日常积累
    利用反射判断bean属性不为空(null和空串)
    (转)Oracle 单字段拆分成多行
    svn 解决树冲突
  • 原文地址:https://www.cnblogs.com/geoffreyone/p/10827010.html
Copyright © 2011-2022 走看看