zoukankan      html  css  js  c++  java
  • C++ 使用 TorchScript 加载训练好的 PyTorch 模型

    1. 首先从 PyTorch 官网下载 LibTorch，放到当前项目目录下。注意 LibTorch 的版本应与导出模型时使用的 PyTorch 版本一致，否则加载 .pt 文件时可能失败。

    2. 使用 torch.jit.trace 将训练好的 PyTorch 模型导出为 .pt 格式。导出前必须调用 model.eval()，否则 BatchNorm、Dropout 会以训练模式被记录进导出的计算图。

     1 import torch
     2 from skimage import io, transform, color
     3 import numpy as np
     4 import os
     5 import torch.nn.functional as F
     6 import warnings
     7 warnings.filterwarnings("ignore")
     8 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     9 
    10 labels = ['cock', 'drawing', 'neutral', 'porn', 'sexy']
    11 path = "test/n_1.jpg"
    12 im = io.imread(path)
    13 if im.shape[2] == 4:
    14     im = color.rgba2rgb(im)
    15 
    16 im = transform.resize(im, (224, 224))
    17 im = np.transpose(im, (2, 0, 1))
    18 dummy_input = np.expand_dims(im, 0)
    19 inp = torch.from_numpy(dummy_input)
    20 inp = inp.float()
    21 model = torch.load(
    22     "models/resnet50-epoch-0-accu-0.9213857428381079.pth", map_location='cpu')
    23 traced_script_module = torch.jit.trace(model, inp)
    24 output = model(inp)
    25 probs = F.softmax(output).detach().numpy()[0]
    26 pred = np.argmax(probs)
    27 
    28 traced_script_module.save("models/traced_resnet_model.pt")

    在 C++ 中通过 TorchScript（torch::jit::load）加载 .pt 模型

// One-stop header.
#include <torch/script.h>

// headers for opencv
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/opencv.hpp>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <exception>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
     14 
     15 #define kIMAGE_SIZE 224
     16 #define kCHANNELS 3
     17 #define kTOP_K 1 //print top k predicted results
     18 
     19 bool LoadImage(std::string file_name, cv::Mat &image)
     20 {
     21   image = cv::imread(file_name); // CV_8UC3
     22   if (image.empty() || !image.data)
     23   {
     24     return false;
     25   }
     26   cv::cvtColor(image, image, CV_BGR2RGB);
     27   // scale image to fit
     28   cv::Size scale(kIMAGE_SIZE, kIMAGE_SIZE);
     29   cv::resize(image, image, scale);
     30 
     31   // convert [unsigned int] to [float]
     32   image.convertTo(image, CV_32FC3,1.0/255);
     33 
     34   return true;
     35 }
     36 
     37 bool LoadImageNetLabel(std::string file_name,
     38                        std::vector<std::string> &labels)
     39 {
     40   std::ifstream ifs(file_name);
     41   if (!ifs)
     42   {
     43     return false;
     44   }
     45   std::string line;
     46   while (std::getline(ifs, line))
     47   {
     48     labels.push_back(line);
     49   }
     50   return true;
     51 }
     52 
     53 int main(int argc, const char *argv[])
     54 {
     55   if (argc != 3)
     56   {
     57     std::cerr << "Usage:classifier <path-to-exported-script-module>  <path-to-lable-file> " << std::endl;
     58     return -1;
     59   }
     60 
     61   //load model
     62   torch::jit::script::Module module = torch::jit::load(argv[1]);
     63   // to GPU
     64   // module->to(at::kCUDA);
     65   std::cout << "== ResNet50 loaded!
    ";
     66 
     67   //load labels(classes names)
     68   std::vector<std::string> labels;
     69   if (LoadImageNetLabel(argv[2], labels))
     70   {
     71     std::cout << "== Label loaded! Let's try it
    ";
     72   }
     73   else
     74   {
     75     std::cerr << "Please check your label file path." << std::endl;
     76     return -1;
     77   }
     78 
     79   std::string file_name = "";
     80   cv::Mat image;
     81   while (true)
     82   {
     83     std::cout << "== Input image path: [enter q to exit]" << std::endl;
     84     std::cin >> file_name;
     85     if (file_name == "Q" || file_name == "q")
     86     {
     87       break;
     88     }
     89     if (LoadImage(file_name, image))
     90     {
     91       //read image tensor
     92       auto input_tensor = torch::from_blob(
     93           image.data, {1, kIMAGE_SIZE, kIMAGE_SIZE, kCHANNELS});
     94       input_tensor = input_tensor.permute({0, 3, 1, 2});
     95       input_tensor[0][0] = input_tensor[0][0].sub_(0.485).div_(0.229);
     96       input_tensor[0][1] = input_tensor[0][1].sub_(0.456).div_(0.224);
     97       input_tensor[0][2] = input_tensor[0][2].sub_(0.406).div_(0.225);
     98       // to GPU
     99       // input_tensor = input_tensor.to(at::kCUDA);
    100 
    101       torch::Tensor out_tensor = module.forward({input_tensor}).toTensor();
    102 
    103       auto results = out_tensor.sort(-1, true);
    104       auto softmaxs = std::get<0>(results)[0].softmax(0);
    105       auto indexs = std::get<1>(results)[0];
    106 
    107       for (int i = 0; i < kTOP_K; ++i)
    108       {
    109         auto idx = indexs[i].item<int>();
    110         std::cout << "    ============= Top-" << i + 1 << " =============" << std::endl;
    111         std::cout << "    Label:  " << labels[idx] << std::endl;
    112         std::cout << "    With Probability:  "
    113                   << softmaxs[i].item<float>() * 100.0f << "%" << std::endl;
    114       }
    115     }
    116     else
    117     {
    118       std::cout << "Can't load the image, please check your path." << std::endl;
    119     }
    120   }
    121 }

    使用 CMakeLists.txt 编译（配置时需执行 cmake -DCMAKE_PREFIX_PATH=/绝对路径/libtorch ..，否则 find_package(Torch) 找不到 LibTorch）

     1 cmake_minimum_required(VERSION 2.8)
     2 project(predict_demo)
     3 SET(CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS} "-std=c++11 -O3")
     4 
     5 
     6 set(OpenCV_DIR  /home/buyizhiyou/opencv-3.4.4/build)
     7 find_package(OpenCV REQUIRED)
     8 find_package(Torch REQUIRED)
     9 
    10 
    11 # 添加头文件
    12 include_directories( ${OpenCV_INCLUDE_DIRS} )
    13 
    14 add_executable(resnet_demo resnet_demo.cpp)
    15 target_link_libraries(resnet_demo ${TORCH_LIBRARIES}  ${OpenCV_LIBS})
    16 set_property(TARGET resnet_demo PROPERTY CXX_STANDARD 11)

    运行

    ./resnet_demo   models/traced_resnet_model.pt  labels.txt
  • 相关阅读:
    小菜菜mysql练习50题解析——数据准备
    C语言(数据结构)——概述
    运行 jar
    Hive 语句
    java14 IO流缓冲区 input output
    java 14 IO流
    java 14 图片的读取和写入
    java 集合的基础2
    java 13 hashmao的entryset()
    java 13 集合的基础
  • 原文地址:https://www.cnblogs.com/buyizhiyou/p/11981288.html
Copyright © 2011-2022 走看看