1. 首先从 PyTorch 官网下载 libtorch（版本应与导出模型时使用的 PyTorch 版本一致），放到当前项目目录下。
2. 将 PyTorch 训练好的模型切换到 eval 模式（model.eval()）后，使用 torch.jit.trace 导出为 .pt 格式：
"""Export a trained ResNet50 classifier to TorchScript for the libtorch demo.

Loads a pickled model, runs one sample image through it as a sanity check,
traces it with ``torch.jit.trace`` and saves the traced module so that the
C++ program can load it with ``torch::jit::load``.
"""
import os
import warnings

import numpy as np
import torch
import torch.nn.functional as F
from skimage import io, transform, color

# Silences every warning, including TracerWarnings emitted by torch.jit.trace.
warnings.filterwarnings("ignore")

labels = ['cock', 'drawing', 'neutral', 'porn', 'sexy']
path = "test/n_1.jpg"
MODEL_PATH = "models/resnet50-epoch-0-accu-0.9213857428381079.pth"
TRACED_PATH = "models/traced_resnet_model.pt"
IMAGE_SIZE = 224
# Per-channel mean/std, identical to the values the C++ demo subtracts/divides
# by. NOTE(review): presumably this matches the training preprocessing —
# confirm; tracing itself does not depend on the input values.
MEAN = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
STD = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)


def load_input(image_path, size=IMAGE_SIZE):
    """Read an image and return a float32 tensor of shape (1, 3, size, size).

    Grayscale images are expanded to three channels and RGBA images are
    flattened to RGB before resizing and mean/std normalization.
    """
    im = io.imread(image_path)
    if im.ndim == 2:
        # Single-channel image: the original `im.shape[2]` raised IndexError.
        im = color.gray2rgb(im)
    elif im.shape[2] == 4:
        im = color.rgba2rgb(im)
    # resize() also converts integer images to floats in [0, 1].
    im = transform.resize(im, (size, size))
    chw = (np.transpose(im, (2, 0, 1)) - MEAN) / STD
    return torch.from_numpy(np.expand_dims(chw, 0)).float()


def main():
    """Trace the model on one sample image and save the TorchScript module."""
    inp = load_input(path)
    # torch.load unpickles arbitrary objects: only load trusted checkpoints.
    # Recent PyTorch releases may need weights_only=False for a full model.
    model = torch.load(MODEL_PATH, map_location='cpu')
    # Tracing records the ops actually executed; without eval() the exported
    # graph would keep train-mode BatchNorm (batch statistics) and Dropout.
    model.eval()
    with torch.no_grad():
        traced_script_module = torch.jit.trace(model, inp)
        output = model(inp)
        traced_output = traced_script_module(inp)

    probs = F.softmax(output, dim=1).numpy()[0]
    pred = int(np.argmax(probs))
    print("prediction: {} ({:.4f})".format(labels[pred], probs[pred]))
    # The traced graph should reproduce the eager result on the same input.
    diff = (traced_output - output).abs().max().item()
    print("max |traced - eager| = {:.3g}".format(diff))

    traced_script_module.save(TRACED_PATH)


if __name__ == "__main__":
    main()
3. 在 C++ 中使用 TorchScript（libtorch）加载 .pt 模型：
1 // One-stop header. 2 #include <torch/script.h> 3 4 // headers for opencv 5 #include <opencv2/highgui/highgui.hpp> 6 #include <opencv2/imgproc/imgproc.hpp> 7 #include <opencv2/opencv.hpp> 8 9 #include <cmath> 10 #include <iostream> 11 #include <memory> 12 #include <string> 13 #include <vector> 14 15 #define kIMAGE_SIZE 224 16 #define kCHANNELS 3 17 #define kTOP_K 1 //print top k predicted results 18 19 bool LoadImage(std::string file_name, cv::Mat &image) 20 { 21 image = cv::imread(file_name); // CV_8UC3 22 if (image.empty() || !image.data) 23 { 24 return false; 25 } 26 cv::cvtColor(image, image, CV_BGR2RGB); 27 // scale image to fit 28 cv::Size scale(kIMAGE_SIZE, kIMAGE_SIZE); 29 cv::resize(image, image, scale); 30 31 // convert [unsigned int] to [float] 32 image.convertTo(image, CV_32FC3,1.0/255); 33 34 return true; 35 } 36 37 bool LoadImageNetLabel(std::string file_name, 38 std::vector<std::string> &labels) 39 { 40 std::ifstream ifs(file_name); 41 if (!ifs) 42 { 43 return false; 44 } 45 std::string line; 46 while (std::getline(ifs, line)) 47 { 48 labels.push_back(line); 49 } 50 return true; 51 } 52 53 int main(int argc, const char *argv[]) 54 { 55 if (argc != 3) 56 { 57 std::cerr << "Usage:classifier <path-to-exported-script-module> <path-to-lable-file> " << std::endl; 58 return -1; 59 } 60 61 //load model 62 torch::jit::script::Module module = torch::jit::load(argv[1]); 63 // to GPU 64 // module->to(at::kCUDA); 65 std::cout << "== ResNet50 loaded! "; 66 67 //load labels(classes names) 68 std::vector<std::string> labels; 69 if (LoadImageNetLabel(argv[2], labels)) 70 { 71 std::cout << "== Label loaded! Let's try it "; 72 } 73 else 74 { 75 std::cerr << "Please check your label file path." 
<< std::endl; 76 return -1; 77 } 78 79 std::string file_name = ""; 80 cv::Mat image; 81 while (true) 82 { 83 std::cout << "== Input image path: [enter q to exit]" << std::endl; 84 std::cin >> file_name; 85 if (file_name == "Q" || file_name == "q") 86 { 87 break; 88 } 89 if (LoadImage(file_name, image)) 90 { 91 //read image tensor 92 auto input_tensor = torch::from_blob( 93 image.data, {1, kIMAGE_SIZE, kIMAGE_SIZE, kCHANNELS}); 94 input_tensor = input_tensor.permute({0, 3, 1, 2}); 95 input_tensor[0][0] = input_tensor[0][0].sub_(0.485).div_(0.229); 96 input_tensor[0][1] = input_tensor[0][1].sub_(0.456).div_(0.224); 97 input_tensor[0][2] = input_tensor[0][2].sub_(0.406).div_(0.225); 98 // to GPU 99 // input_tensor = input_tensor.to(at::kCUDA); 100 101 torch::Tensor out_tensor = module.forward({input_tensor}).toTensor(); 102 103 auto results = out_tensor.sort(-1, true); 104 auto softmaxs = std::get<0>(results)[0].softmax(0); 105 auto indexs = std::get<1>(results)[0]; 106 107 for (int i = 0; i < kTOP_K; ++i) 108 { 109 auto idx = indexs[i].item<int>(); 110 std::cout << " ============= Top-" << i + 1 << " =============" << std::endl; 111 std::cout << " Label: " << labels[idx] << std::endl; 112 std::cout << " With Probability: " 113 << softmaxs[i].item<float>() * 100.0f << "%" << std::endl; 114 } 115 } 116 else 117 { 118 std::cout << "Can't load the image, please check your path." << std::endl; 119 } 120 } 121 }
4. 编写 CMakeLists.txt 并编译（cmake 需要通过 -DCMAKE_PREFIX_PATH=/path/to/libtorch 找到 libtorch）：
# CMake 4 rejects minimum versions below 3.5; raise this if your libtorch's
# TorchConfig.cmake needs a newer CMake.
cmake_minimum_required(VERSION 3.5 FATAL_ERROR)
project(predict_demo)

# Default to an optimized build rather than hard-coding -O3 into
# CMAKE_CXX_FLAGS (the old SET(...) turned the flags into a ';' list).
if(NOT CMAKE_BUILD_TYPE)
  set(CMAKE_BUILD_TYPE Release)
endif()

# Cache entry: a -DOpenCV_DIR=... given on the command line takes precedence.
set(OpenCV_DIR /home/buyizhiyou/opencv-3.4.4/build CACHE PATH "Directory containing OpenCVConfig.cmake")
find_package(OpenCV REQUIRED)
# libtorch is found via CMAKE_PREFIX_PATH (or Torch_DIR), e.g.
#   cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
find_package(Torch REQUIRED)
# TORCH_CXX_FLAGS carries the _GLIBCXX_USE_CXX11_ABI setting libtorch was built with.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

# Add header search paths
include_directories(${OpenCV_INCLUDE_DIRS})

add_executable(resnet_demo resnet_demo.cpp)
target_link_libraries(resnet_demo ${TORCH_LIBRARIES} ${OpenCV_LIBS})
# libtorch 1.5+ requires C++14 and 2.1+ requires C++17; C++11 is too old.
set_property(TARGET resnet_demo PROPERTY CXX_STANDARD 17)
set_property(TARGET resnet_demo PROPERTY CXX_STANDARD_REQUIRED ON)
5. 运行：
# Usage: ./resnet_demo <traced-model.pt> <label-file>
# Paths are relative to the current working directory. The label file holds
# one class name per line; line i is the name printed for output index i, so
# it must follow the order of the export script's `labels` list
# (cock, drawing, neutral, porn, sexy). The program then prompts for image
# paths until "q" is entered.
./resnet_demo models/traced_resnet_model.pt labels.txt