#include <iostream>
#include <memory>
#include <string>
#include <torch/script.h>
#include <opencv2/opencv.hpp>
#include <opencv2/core/core.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include "opencv2/imgproc/types_c.h"
using namespace std;
using namespace cv;
// Turns an 8-bit BGR image into the normalized float tensor fed to the U-Net.
//
// Steps: BGR -> RGB (done IN PLACE on `image`, so the caller's Mat is modified),
// resize by `scale`, wrap the pixels as a {1, H, W, 3} byte tensor, reorder to
// {1, 3, H, W}, convert to float and divide by 255 so values lie in [0, 1].
//
// image : 8-bit BGR image; converted to RGB in place.
// scale : factor applied to both width and height (default 1 = no resize).
// returns: float tensor of shape {1, 3, int(scale * rows), int(scale * cols)}.
// throws : cv::Exception if the image is empty, is not 8-bit, or the scaled
//          size is not positive.
torch::Tensor unet_data_preprocess(Mat &image, float scale=1) {
    // Validate everything BEFORE the in-place colour conversion, so a rejected
    // call leaves the caller's image untouched.
    CV_Assert(!image.empty());
    // from_blob below reinterprets the pixel buffer as bytes; any other depth
    // would be silently misread instead of failing.
    CV_Assert(image.depth() == CV_8U);

    const int newW = static_cast<int>(scale * image.cols);
    const int newH = static_cast<int>(scale * image.rows);
    CV_Assert(newW > 0 && newH > 0);

    cv::cvtColor(image, image, cv::COLOR_BGR2RGB);

    Mat img_processed;
    cv::resize(image, img_processed, cv::Size(newW, newH));

    // from_blob does NOT copy: this tensor aliases img_processed's buffer.
    torch::Tensor imgtransform =
        torch::from_blob(img_processed.data, {1, newH, newW, 3}, torch::kByte);
    // NHWC -> NCHW (still a view on the same buffer).
    imgtransform = imgtransform.permute({0, 3, 1, 2});
    // The dtype conversion allocates fresh storage, so from here on the tensor
    // no longer depends on img_processed, which is destroyed on return.
    imgtransform = imgtransform.to(torch::kFloat);
    imgtransform = imgtransform.div(255.0);
    return imgtransform;
}
int run_unet() {
//Load model.
torch::jit::script::Module unet_module;
try {
unet_module = torch::jit::load("G:\liu_projects\unet_cpp\Unet_package\model\traced_unet_model.pt");
}
catch (const c10::Error& e) {
std::cerr << "error loading the model!";
return -1;
}
torch::Device device(torch::kCUDA);
unet_module.to(device);
unet_module.eval();
std::cout << "model loaded on cuda!
";
//prepare image tensor.
//std::vector<torch::jit::IValue> inputs;
cv::Mat image;
image = cv::imread("G:\liu_projects\unet_cpp\Unet_package\test_imgs\test.jpg");
torch::Tensor img_tensor = unet_data_preprocess(image, 0.5);
img_tensor = img_tensor.to(device);
//forward.
at::Tensor output = unet_module.forward({ img_tensor }).toTensor();
at::Tensor probs = torch::sigmoid(output);
probs = probs.squeeze(0).detach().permute({ 1, 2, 0 });
cout << "probs size: " << probs.sizes() << endl;
probs = probs > 0.5;
probs = probs.mul(255).clamp(0, 255).to(torch::kU8);
probs = probs.to(torch::kCPU);
//cv::Mat resultImg(640, 959, CV_8UC1);
//// copy the data from out_tensor to resultImg
//std::memcpy((void*)resultImg.data, probs.data_ptr(), sizeof(torch::kU8) * probs.numel());
cv::Mat resultImg(640, 959, CV_8UC1, (uchar*)probs.data_ptr());
cv::imshow("resultImg", resultImg);
cv::waitKey(0);
return 0;
}
// Program entry point: runs the U-Net demo and forwards its status
// (0 on success, non-zero on failure) as the process exit code.
int main()
{
    return run_unet();
}
参考文章:
https://www.cnblogs.com/yanghailin/p/12901586.html (libtorch 常用api函数示例(史上最全、最详细))
https://pytorch.apachecn.org/docs/1.4/30.html (记住:模型要保存成cpu的!)
https://blog.csdn.net/juluwangriyue/article/details/108360320 (libtorch tensor转mat)