zoukankan      html  css  js  c++  java
  • 使用 libtorch 1.7.0 + CUDA 10.1 进行 UNet 模型部署

     

     

    #include <exception>
    #include <iostream>
    #include <memory>
    #include <string>

    #include <torch/script.h>
    #include <torch/torch.h>
    #include <opencv2/opencv.hpp>
    #include <opencv2/core/core.hpp>
    #include <opencv2/imgproc/imgproc.hpp>
    #include "opencv2/imgproc/types_c.h"
    
    using namespace std;
    using namespace cv;
    
    /// Converts an 8-bit BGR(A) OpenCV image into a normalized NCHW float tensor.
    ///
    /// The image is converted to RGB, resized by `scale` (width and height are
    /// truncated to int), and returned as a {1, 3, newH, newW} float tensor with
    /// values in [0, 1].
    ///
    /// @param image  Input image, e.g. as returned by cv::imread (8-bit, BGR or
    ///               BGRA channel order). The caller's image is not modified.
    /// @param scale  Resize factor applied to both width and height.
    /// @return       A float tensor of shape {1, 3, newH, newW} that owns its data.
    /// @throws cv::Exception if the image is empty, does not become 8-bit 3-channel
    ///         after colour conversion, or `scale` yields a zero-sized output.
    torch::Tensor unet_data_preprocess(const Mat &image, float scale=1) {
        CV_Assert(!image.empty());

        // Convert into a local buffer so the caller's image is left untouched.
        Mat rgb;
        cv::cvtColor(image, rgb, cv::COLOR_BGR2RGB);
        // from_blob below reinterprets the pixels as 3 bytes per pixel.
        CV_Assert(rgb.type() == CV_8UC3);

        const int newW = static_cast<int>(scale * rgb.cols);
        const int newH = static_cast<int>(scale * rgb.rows);
        CV_Assert(newW > 0 && newH > 0);

        Mat img_processed;
        cv::resize(rgb, img_processed, cv::Size(newW, newH));
        // from_blob needs a single dense HWC buffer.
        CV_Assert(img_processed.isContinuous());

        //cv::imshow("img_processed", img_processed);
        //cv::waitKey(0);

        // from_blob does NOT take ownership of img_processed's memory. The
        // kByte -> kFloat conversion allocates a new tensor, so the returned
        // tensor stays valid after img_processed is destroyed.
        return torch::from_blob(img_processed.data, {1, newH, newW, 3}, torch::kByte)
            .permute({0, 3, 1, 2})  // NHWC -> NCHW
            .to(torch::kFloat)
            .div(255.0);
    }
    
    int run_unet() {
        //Load model.
        torch::jit::script::Module unet_module;
    
        try {
            unet_module = torch::jit::load("G:\liu_projects\unet_cpp\Unet_package\model\traced_unet_model.pt");
        }
        catch (const c10::Error& e) {
            std::cerr << "error loading the model!";
            return -1;
        }
    
        torch::Device device(torch::kCUDA);
        unet_module.to(device);
        unet_module.eval();
        std::cout << "model loaded on cuda!
    ";
    
        //prepare image tensor.
        //std::vector<torch::jit::IValue> inputs;
        cv::Mat image;
        image = cv::imread("G:\liu_projects\unet_cpp\Unet_package\test_imgs\test.jpg");
        torch::Tensor img_tensor = unet_data_preprocess(image, 0.5);
        img_tensor = img_tensor.to(device);
    
        //forward.
        at::Tensor output = unet_module.forward({ img_tensor }).toTensor();
        at::Tensor probs = torch::sigmoid(output);
        probs = probs.squeeze(0).detach().permute({ 1, 2, 0 });
        cout << "probs size: " << probs.sizes() << endl;
        probs = probs > 0.5;
        probs = probs.mul(255).clamp(0, 255).to(torch::kU8);
        probs = probs.to(torch::kCPU);
    
        //cv::Mat resultImg(640, 959, CV_8UC1);
        //// copy the data from out_tensor to resultImg
        //std::memcpy((void*)resultImg.data, probs.data_ptr(), sizeof(torch::kU8) * probs.numel());
    
        cv::Mat resultImg(640, 959, CV_8UC1, (uchar*)probs.data_ptr());
    
        cv::imshow("resultImg", resultImg);
        cv::waitKey(0);
    
        return 0;
    }
    
    /// Entry point: runs the UNet demo and reports failure through the exit code.
    int main()
    {
        // c10::Error and cv::Exception both derive from std::exception; report
        // them instead of letting std::terminate abort without a message.
        try {
            return run_unet() == 0 ? 0 : 1;
        }
        catch (const std::exception& e) {
            std::cerr << "run_unet failed: " << e.what() << '\n';
            return 1;
        }
    }
    

      

    参考文章:

    https://www.cnblogs.com/yanghailin/p/12901586.html (libtorch 常用api函数示例(史上最全、最详细))

    https://pytorch.apachecn.org/docs/1.4/30.html (注意：导出 TorchScript 模型时要先把模型放在 CPU 上再保存！)

    https://blog.csdn.net/juluwangriyue/article/details/108360320 (libtorch tensor 转 cv::Mat)

  • 相关阅读:
    event 事件 键盘控制div移动
    event 事件 div鼠标跟随
    获取坐标封装 getPos
    event 事件 clientX 和clientY 配合scrollTop使用, div跟着鼠标走
    event 事件 冒泡
    event 事件 坐标兼容
    event事件基础 document
    DOM 多字符搜索
    DOM search table 模糊搜索
    Reverse a sentence
  • 原文地址:https://www.cnblogs.com/liutianrui1/p/13921213.html
Copyright © 2011-2022 走看看