#include <iostream> #include <memory> #include <string> #include <torch/script.h> #include <opencv2/opencv.hpp> #include <opencv2/core/core.hpp> #include <opencv2/imgproc/imgproc.hpp> #include "opencv2/imgproc/types_c.h" using namespace std; using namespace cv; torch::Tensor unet_data_preprocess(Mat &image, float scale=1) { cv::cvtColor(image, image, CV_BGR2RGB); int w = image.cols; int h = image.rows; int newW = int(scale * w); int newH = int(scale * h); Mat img_processed; cv::resize(image, img_processed, cv::Size(newW, newH)); //cv::imshow("img_processed", img_processed); //cv::waitKey(0); torch::Tensor imgtransform; imgtransform = torch::from_blob(img_processed.data, {1,newH,newW,3}, torch::kByte); imgtransform = imgtransform.permute({0,3,1,2 }); imgtransform = imgtransform.to(torch::kFloat); imgtransform = imgtransform.div(255.0); return imgtransform; } int run_unet() { //Load model. torch::jit::script::Module unet_module; try { unet_module = torch::jit::load("G:\liu_projects\unet_cpp\Unet_package\model\traced_unet_model.pt"); } catch (const c10::Error& e) { std::cerr << "error loading the model!"; return -1; } torch::Device device(torch::kCUDA); unet_module.to(device); unet_module.eval(); std::cout << "model loaded on cuda! "; //prepare image tensor. //std::vector<torch::jit::IValue> inputs; cv::Mat image; image = cv::imread("G:\liu_projects\unet_cpp\Unet_package\test_imgs\test.jpg"); torch::Tensor img_tensor = unet_data_preprocess(image, 0.5); img_tensor = img_tensor.to(device); //forward. 
at::Tensor output = unet_module.forward({ img_tensor }).toTensor(); at::Tensor probs = torch::sigmoid(output); probs = probs.squeeze(0).detach().permute({ 1, 2, 0 }); cout << "probs size: " << probs.sizes() << endl; probs = probs > 0.5; probs = probs.mul(255).clamp(0, 255).to(torch::kU8); probs = probs.to(torch::kCPU); //cv::Mat resultImg(640, 959, CV_8UC1); //// copy the data from out_tensor to resultImg //std::memcpy((void*)resultImg.data, probs.data_ptr(), sizeof(torch::kU8) * probs.numel()); cv::Mat resultImg(640, 959, CV_8UC1, (uchar*)probs.data_ptr()); cv::imshow("resultImg", resultImg); cv::waitKey(0); return 0; } int main() { run_unet(); return 0; }
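// 参考文章 (Reference articles):
// https://www.cnblogs.com/yanghailin/p/12901586.html (libtorch 常用api函数示例(史上最全、最详细) — examples of commonly used libtorch API functions)
// https://pytorch.apachecn.org/docs/1.4/30.html (记住:模型要保存成cpu的! — remember: the model must be saved as a CPU model!)
// https://blog.csdn.net/juluwangriyue/article/details/108360320 (libtorch tensor转mat — converting a libtorch tensor to a cv::Mat)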