zoukankan      html  css  js  c++  java
  • 使用C++调用并部署pytorch模型

    1.背景(Background)

     上图显示了目前深度学习模型在生产环境中的方法,本文仅探讨如何部署pytorch模型!

    至于为什么要用C++调用pytorch模型,其目的在于:使用C++及多线程可以加快模型预测速度

    关于模型训练有两种方法,一种是直接使用C++编写训练代码,可以做到搭建完整的网络模型,但是无法使用迁移学习,而迁移学习是目前训练样本几乎都会用到的方法,另一种是使用python代码训练好模型,并使用JIT技术,将python模型导出为C++可调用的模型,这里具体介绍第二种。(个人觉得还可以采用一种方式,即将pytorch模型作为一种Web Service以供各种客户端调用)

    官方对TorchScript的介绍如下(https://pytorch.org/docs/master/jit.html#creating-torchscript-code):

    TorchScript是一种从PyTorch代码创建可序列化和可优化模型的方法。用TorchScript编写的任何代码都可以从Python进程中保存并加载到没有Python依赖关系的进程中。
    我们提供了一些工具来增量地将模型从纯Python程序转换为能够独立于Python运行的TorchScript程序,例如,在一个独立的c++程序中。这使得使用熟悉的工具在PyTorch中培训模型,然后通过TorchScript将模型导出到生产环境中成为可能。在生产环境中,出于性能和多线程的原因,将模型作为Python程序运行不是一个好主意。

     首先,我们在官网下载适合于Windows的libtorch,因为稳定版出来了,所以可以直接拿来使用。有CPU版本的和GPU版本的,这里我都进行了测试,都是可以直接使用的,这里以CPU版本为例进行介绍:

     

    2.实验(Experiments)

    1.python环境下跑模型的推断代码 

    以ESRGAN的inference code(https://github.com/xinntao/ESRGAN)为例:

    环境:Windows10+Python3.5.2+Pytorch1.1

    Python packages: pip install numpy opencv-python

    直接run test,结果如下(我的版本有做一些改动,如增加FPS的计算等):

     2.将PyTorch模型转换为Torch Script

    第一个方法是tracing.该方法通过将样本输入到模型中一次来对该过程进行评估从而捕获模型结构.并记录该样本在模型中的flow.该方法适用于模型中很少使用控制flow的模型.

    第二个方法就是向模型添加显式注释,通知Torch Script编译器它可以直接解析和编译模型代码,受Torch Script语言强加的约束。

    • 利用Tracing将模型转换为Torch Script

    要通过tracing来将PyTorch模型转换为Torch脚本,必须将模型的实例以及样本输入传递给torch.jit.trace函数.

    这将生成一个torch.jit.ScriptModule对象,并在模块的forward方法中嵌入模型评估的跟踪:

    import torch
    import architecture as arch
    
    # An instance of your model.
    model = arch.RRDB_Net(3, 3, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', 
                            mode='CNA', res_scale=1, upsample_mode='upconv')
    
    model.load_state_dict(torch.load('./models/RRDB_ESRGAN_x4.pth'), strict=True)
    model.eval()
    
    # An example input you would normally provide to your model's forward() method.
    example = torch.rand(64, 3, 3, 3)
    
    # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
    traced_script_module = torch.jit.trace(model, example)
    output = traced_script_module(torch.ones(64, 3, 3, 3))
    traced_script_module.save("./models/RRDB_ESRGAN_x4_000.pt")
    
    # The traced ScriptModule can now be evaluated identically to a regular PyTorch module
    print(output)
    

    跟踪的ScriptModule可以与常规PyTorch模块进行相同的计算,结果如下(注意在最后,将ScriptModule序列化为一个文件.然后,C++就可以不依赖任何Python代码来执行该Script所对应的Pytorch模型.):

    (surper-resolution-pytorch) anpi-cn@anpi-cn:~/workspace_min/Super-Resolution/ESRGAN$ python model_jit_converter.py 
    tensor([[[[0.9618, 1.0375, 1.0242,  ..., 1.0049, 1.0399, 1.0255],
              [1.0199, 0.9996, 1.0096,  ..., 1.0269, 1.0140, 1.0267],
              [1.0290, 1.0154, 1.0161,  ..., 1.0201, 1.0077, 1.0298],
              ...,
              [1.0316, 1.0139, 1.0184,  ..., 1.0184, 1.0179, 1.0197],
              [1.0391, 1.0174, 1.0162,  ..., 1.0185, 1.0443, 1.0168],
              [1.0066, 1.0186, 0.9976,  ..., 1.0143, 1.0066, 1.0249]],
    
             [[1.0155, 1.0491, 1.0004,  ..., 0.9993, 0.9828, 0.9706],
              [0.9992, 1.0149, 1.0032,  ..., 0.9851, 0.9937, 0.9887],
              [0.9974, 1.0106, 1.0089,  ..., 1.0072, 1.0074, 1.0041],
              ...,
              [1.0130, 1.0036, 1.0059,  ..., 0.9979, 1.0065, 1.0133],
              [1.0066, 0.9955, 1.0034,  ..., 1.0030, 0.9875, 1.0011],
              [0.9788, 0.9983, 1.0113,  ..., 1.0106, 1.0381, 1.0248]],
    
             [[0.9570, 0.9789, 0.9720,  ..., 0.9920, 0.9740, 0.9940],
              [0.9522, 1.0182, 1.0109,  ..., 1.0181, 1.0060, 0.9842],
              [0.9872, 1.0062, 1.0112,  ..., 1.0172, 1.0072, 0.9803],
              ...,
              [1.0211, 1.0119, 1.0091,  ..., 1.0082, 1.0339, 1.0348],
              [0.9894, 1.0227, 1.0226,  ..., 0.9930, 1.0258, 1.0234],
              [0.9997, 0.9755, 0.9969,  ..., 1.0227, 1.0308, 1.0109]]],
    
    
            [[[0.9618, 1.0375, 1.0242,  ..., 1.0049, 1.0399, 1.0255],
              [1.0199, 0.9996, 1.0096,  ..., 1.0269, 1.0140, 1.0267],
              [1.0290, 1.0154, 1.0161,  ..., 1.0201, 1.0077, 1.0298],
              ...,
              [1.0316, 1.0139, 1.0184,  ..., 1.0184, 1.0179, 1.0197],
              [1.0391, 1.0174, 1.0162,  ..., 1.0185, 1.0443, 1.0168],
              [1.0066, 1.0186, 0.9976,  ..., 1.0143, 1.0066, 1.0249]],
    
             [[1.0155, 1.0491, 1.0004,  ..., 0.9993, 0.9828, 0.9706],
              [0.9992, 1.0149, 1.0032,  ..., 0.9851, 0.9937, 0.9887],
              [0.9974, 1.0106, 1.0089,  ..., 1.0072, 1.0074, 1.0041],
              ...,
              [1.0130, 1.0036, 1.0059,  ..., 0.9979, 1.0065, 1.0133],
              [1.0066, 0.9955, 1.0034,  ..., 1.0030, 0.9875, 1.0011],
              [0.9788, 0.9983, 1.0113,  ..., 1.0106, 1.0381, 1.0248]],
    
             [[0.9570, 0.9789, 0.9720,  ..., 0.9920, 0.9740, 0.9940],
              [0.9522, 1.0182, 1.0109,  ..., 1.0181, 1.0060, 0.9842],
              [0.9872, 1.0062, 1.0112,  ..., 1.0172, 1.0072, 0.9803],
              ...,
              [1.0211, 1.0119, 1.0091,  ..., 1.0082, 1.0339, 1.0348],
              [0.9894, 1.0227, 1.0226,  ..., 0.9930, 1.0258, 1.0234],
              [0.9997, 0.9755, 0.9969,  ..., 1.0227, 1.0308, 1.0109]]],
    
    
            [[[0.9618, 1.0375, 1.0242,  ..., 1.0049, 1.0399, 1.0255],
              [1.0199, 0.9996, 1.0096,  ..., 1.0269, 1.0140, 1.0267],
              [1.0290, 1.0154, 1.0161,  ..., 1.0201, 1.0077, 1.0298],
              ...,
              [1.0316, 1.0139, 1.0184,  ..., 1.0184, 1.0179, 1.0197],
              [1.0391, 1.0174, 1.0162,  ..., 1.0185, 1.0443, 1.0168],
              [1.0066, 1.0186, 0.9976,  ..., 1.0143, 1.0066, 1.0249]],
    
             [[1.0155, 1.0491, 1.0004,  ..., 0.9993, 0.9828, 0.9706],
              [0.9992, 1.0149, 1.0032,  ..., 0.9851, 0.9937, 0.9887],
              [0.9974, 1.0106, 1.0089,  ..., 1.0072, 1.0074, 1.0041],
              ...,
              [1.0130, 1.0036, 1.0059,  ..., 0.9979, 1.0065, 1.0133],
              [1.0066, 0.9955, 1.0034,  ..., 1.0030, 0.9875, 1.0011],
              [0.9788, 0.9983, 1.0113,  ..., 1.0106, 1.0381, 1.0248]],
    
             [[0.9570, 0.9789, 0.9720,  ..., 0.9920, 0.9740, 0.9940],
              [0.9522, 1.0182, 1.0109,  ..., 1.0181, 1.0060, 0.9842],
              [0.9872, 1.0062, 1.0112,  ..., 1.0172, 1.0072, 0.9803],
              ...,
              [1.0211, 1.0119, 1.0091,  ..., 1.0082, 1.0339, 1.0348],
              [0.9894, 1.0227, 1.0226,  ..., 0.9930, 1.0258, 1.0234],
              [0.9997, 0.9755, 0.9969,  ..., 1.0227, 1.0308, 1.0109]]],
    
    
            ...,
    
    
            [[[0.9618, 1.0375, 1.0242,  ..., 1.0049, 1.0399, 1.0255],
              [1.0199, 0.9996, 1.0096,  ..., 1.0269, 1.0140, 1.0267],
              [1.0290, 1.0154, 1.0161,  ..., 1.0201, 1.0077, 1.0298],
              ...,
              [1.0316, 1.0139, 1.0184,  ..., 1.0184, 1.0179, 1.0197],
              [1.0391, 1.0174, 1.0162,  ..., 1.0185, 1.0443, 1.0168],
              [1.0066, 1.0186, 0.9976,  ..., 1.0143, 1.0066, 1.0249]],
    
             [[1.0155, 1.0491, 1.0004,  ..., 0.9993, 0.9828, 0.9706],
              [0.9992, 1.0149, 1.0032,  ..., 0.9851, 0.9937, 0.9887],
              [0.9974, 1.0106, 1.0089,  ..., 1.0072, 1.0074, 1.0041],
              ...,
              [1.0130, 1.0036, 1.0059,  ..., 0.9979, 1.0065, 1.0133],
              [1.0066, 0.9955, 1.0034,  ..., 1.0030, 0.9875, 1.0011],
              [0.9788, 0.9983, 1.0113,  ..., 1.0106, 1.0381, 1.0248]],
    
             [[0.9570, 0.9789, 0.9720,  ..., 0.9920, 0.9740, 0.9940],
              [0.9522, 1.0182, 1.0109,  ..., 1.0181, 1.0060, 0.9842],
              [0.9872, 1.0062, 1.0112,  ..., 1.0172, 1.0072, 0.9803],
              ...,
              [1.0211, 1.0119, 1.0091,  ..., 1.0082, 1.0339, 1.0348],
              [0.9894, 1.0227, 1.0226,  ..., 0.9930, 1.0258, 1.0234],
              [0.9997, 0.9755, 0.9969,  ..., 1.0227, 1.0308, 1.0109]]],
    
    
            [[[0.9618, 1.0375, 1.0242,  ..., 1.0049, 1.0399, 1.0255],
              [1.0199, 0.9996, 1.0096,  ..., 1.0269, 1.0140, 1.0267],
              [1.0290, 1.0154, 1.0161,  ..., 1.0201, 1.0077, 1.0298],
              ...,
              [1.0316, 1.0139, 1.0184,  ..., 1.0184, 1.0179, 1.0197],
              [1.0391, 1.0174, 1.0162,  ..., 1.0185, 1.0443, 1.0168],
              [1.0066, 1.0186, 0.9976,  ..., 1.0143, 1.0066, 1.0249]],
    
             [[1.0155, 1.0491, 1.0004,  ..., 0.9993, 0.9828, 0.9706],
              [0.9992, 1.0149, 1.0032,  ..., 0.9851, 0.9937, 0.9887],
              [0.9974, 1.0106, 1.0089,  ..., 1.0072, 1.0074, 1.0041],
              ...,
              [1.0130, 1.0036, 1.0059,  ..., 0.9979, 1.0065, 1.0133],
              [1.0066, 0.9955, 1.0034,  ..., 1.0030, 0.9875, 1.0011],
              [0.9788, 0.9983, 1.0113,  ..., 1.0106, 1.0381, 1.0248]],
    
             [[0.9570, 0.9789, 0.9720,  ..., 0.9920, 0.9740, 0.9940],
              [0.9522, 1.0182, 1.0109,  ..., 1.0181, 1.0060, 0.9842],
              [0.9872, 1.0062, 1.0112,  ..., 1.0172, 1.0072, 0.9803],
              ...,
              [1.0211, 1.0119, 1.0091,  ..., 1.0082, 1.0339, 1.0348],
              [0.9894, 1.0227, 1.0226,  ..., 0.9930, 1.0258, 1.0234],
              [0.9997, 0.9755, 0.9969,  ..., 1.0227, 1.0308, 1.0109]]],
    
    
            [[[0.9618, 1.0375, 1.0242,  ..., 1.0049, 1.0399, 1.0255],
              [1.0199, 0.9996, 1.0096,  ..., 1.0269, 1.0140, 1.0267],
              [1.0290, 1.0154, 1.0161,  ..., 1.0201, 1.0077, 1.0298],
              ...,
              [1.0316, 1.0139, 1.0184,  ..., 1.0184, 1.0179, 1.0197],
              [1.0391, 1.0174, 1.0162,  ..., 1.0185, 1.0443, 1.0168],
              [1.0066, 1.0186, 0.9976,  ..., 1.0143, 1.0066, 1.0249]],
    
             [[1.0155, 1.0491, 1.0004,  ..., 0.9993, 0.9828, 0.9706],
              [0.9992, 1.0149, 1.0032,  ..., 0.9851, 0.9937, 0.9887],
              [0.9974, 1.0106, 1.0089,  ..., 1.0072, 1.0074, 1.0041],
              ...,
              [1.0130, 1.0036, 1.0059,  ..., 0.9979, 1.0065, 1.0133],
              [1.0066, 0.9955, 1.0034,  ..., 1.0030, 0.9875, 1.0011],
              [0.9788, 0.9983, 1.0113,  ..., 1.0106, 1.0381, 1.0248]],
    
             [[0.9570, 0.9789, 0.9720,  ..., 0.9920, 0.9740, 0.9940],
              [0.9522, 1.0182, 1.0109,  ..., 1.0181, 1.0060, 0.9842],
              [0.9872, 1.0062, 1.0112,  ..., 1.0172, 1.0072, 0.9803],
              ...,
              [1.0211, 1.0119, 1.0091,  ..., 1.0082, 1.0339, 1.0348],
              [0.9894, 1.0227, 1.0226,  ..., 0.9930, 1.0258, 1.0234],
              [0.9997, 0.9755, 0.9969,  ..., 1.0227, 1.0308, 1.0109]]]],
           grad_fn=<MkldnnConvolutionBackward>)
    

     3.在C++中加载你的Script Module

    要在C ++中加载序列化的PyTorch模型,您的应用程序必须依赖于PyTorch C ++ API - 也称为LibTorch。LibTorch发行版包含一组共享库,头文件和CMake构建配置文件。虽然CMake不是依赖LibTorch的要求,但它是推荐的方法,并且将来会得到很好的支持。在本教程中,我们将使用CMake和LibTorch构建一个最小的C ++应用程序,它只需加载并执行序列化的PyTorch模型。

    加载模块的代码:

    #include <torch/script.h> // One-stop header.
    #include <iostream>
    #include <memory>
    
    int main(int argc, const char* argv[]) {
      if (argc != 2) {
        std::cerr << "usage: example-app <path-to-exported-script-module>
    ";
        return -1;
      }
    
      // Deserialize the ScriptModule from a file using torch::jit::load().
      std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]);
    
      assert(module != nullptr);
      std::cout << "ok
    ";
    }
    

    <torch / script.h>头文件包含运行该示例所需的LibTorch库中的所有相关包含。我们的应用程序接受序列化PyTorch ScriptModule的文件路径作为其唯一的命令行参数,然后使用torch :: jit :: load()函数继续反序列化模块,该函数将此文件路径作为输入。作为回报,我们收到一个指向torch :: jit :: script :: Module的共享指针,相当于C ++中的torch.jit.ScriptModule。目前,我们只验证此指针不为null。我们将研究如何在接下来执行它。

    LibTorch和构建应用程序

    假设我们将上面的代码保存到名为example-app.cpp的文件中。构建它的最小CMakeLists.txt如下:

    cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
    project(custom_ops)
    
    find_package(Torch REQUIRED)
    
    add_executable(example-app example-app.cpp)
    target_link_libraries(example-app "${TORCH_LIBRARIES}")
    set_property(TARGET example-app PROPERTY CXX_STANDARD 11)
    

    构建应用程序时,假设我们的示例目录布局如下:

    example-app/
      CMakeLists.txt
      example-app.cpp
    

    现在可以运行以下命令从example-app/文件夹中构建应用程序:

    cmake -DCMAKE_PREFIX_PATH=/home/anpi-cn/workspace_min/libtorch
    make
    

    如果一切顺利,它将看起来像这样:

    (surper-resolution-pytorch) anpi-cn@anpi-cn:~/workspace_min/Super-Resolution/ESRGAN/example-app$ cmake -DCMAKE_PREFIX_PATH=/home/anpi-cn/workspace_min/libtorch
    -- The C compiler identification is GNU 5.4.0
    -- The CXX compiler identification is GNU 5.4.0
    -- Check for working C compiler: /usr/bin/cc
    -- Check for working C compiler: /usr/bin/cc -- works
    -- Detecting C compiler ABI info
    -- Detecting C compiler ABI info - done
    -- Detecting C compile features
    -- Detecting C compile features - done
    -- Check for working CXX compiler: /usr/bin/c++
    -- Check for working CXX compiler: /usr/bin/c++ -- works
    -- Detecting CXX compiler ABI info
    -- Detecting CXX compiler ABI info - done
    -- Detecting CXX compile features
    -- Detecting CXX compile features - done
    -- Looking for pthread.h
    -- Looking for pthread.h - found
    -- Looking for pthread_create
    -- Looking for pthread_create - not found
    -- Looking for pthread_create in pthreads
    -- Looking for pthread_create in pthreads - not found
    -- Looking for pthread_create in pthread
    -- Looking for pthread_create in pthread - found
    -- Found Threads: TRUE  
    -- Found CUDA: /usr/local/cuda (found version "9.0") 
    -- Caffe2: CUDA detected: 9.0
    -- Caffe2: CUDA nvcc is: /usr/local/cuda/bin/nvcc
    -- Caffe2: CUDA toolkit directory: /usr/local/cuda
    -- Caffe2: Header version is: 9.0
    -- Found CUDNN: /usr/include  
    -- Found cuDNN: v7.4.1  (include: /usr/include, library: /usr/lib/x86_64-linux-gnu/libcudnn.so)
    -- Autodetected CUDA architecture(s):  6.1
    -- Added CUDA NVCC flags for: -gencode;arch=compute_61,code=sm_61
    -- Found torch: /home/anpi-cn/workspace_min/libtorch/lib/libtorch.so  
    -- Configuring done
    CMake Warning at CMakeLists.txt:6 (add_executable):
      Cannot generate a safe runtime search path for target example-app because
      there is a cycle in the constraint graph:
    
        dir 0 is [/home/anpi-cn/workspace_min/libtorch/lib]
        dir 1 is [/usr/local/cuda/lib64/stubs]
        dir 2 is [/home/anpi-cn/.conda/envs/surper-resolution-pytorch/lib]
          dir 3 must precede it due to runtime library [libcudart.so.9.0]
        dir 3 is [/usr/local/cuda/lib64]
          dir 2 must precede it due to runtime library [libnvrtc.so.9.0]
    
      Some of these libraries may not be found correctly.
    
    
    -- Generating done
    -- Build files have been written to: /home/anpi-cn/workspace_min/Super-Resolution/ESRGAN/example-app
    (surper-resolution-pytorch) anpi-cn@anpi-cn:~/workspace_min/Super-Resolution/ESRGAN/example-app$ make
    Scanning dependencies of target example-app
    [ 50%] Building CXX object CMakeFiles/example-app.dir/example-app.cpp.o
    [100%] Linking CXX executable example-app
    [100%] Built target example-app
    (surper-resolution-pytorch) anpi-cn@anpi-cn:~/workspace_min/Super-Resolution/ESRGAN/example-app$ ./example-app ../models/RRDB_ESRGAN_x4_000.pt 
    ok
    

    4.在C++代码中执行Script Module

    在C ++中成功加载了我们的序列化模型后,添加以下代码到C ++应用程序的main()函数中:

    // Create a vector of inputs.
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::ones({64, 3, 3, 3}));
    
    // Execute the model and turn its output into a tensor.
    auto output = module->forward(inputs).toTensor();
    
    std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '
    ';
    

     前两行设置了我们模型的输入。我们创建了一个torch :: jit :: IValue的向量并添加一个输入。要创建输入张量,我们使用torch :: ones(),相当于C ++ API中的torch.ones。然后我们运行script::Moduleforward方法,将它传递给我们创建的输入向量。作为回报,我们得到一个新的IValue,我们通过调用toTensor()将其转换为张量。

    在最后一行中,我们打印输出的前五个条目。由于在前面的Python中为本次的模型提供了相同的输入,因此理想情况下应该看到相同的输出。重新编译上面的应用程序并使用相同的序列化模型运行它来尝试。通过比较,发现C++的输出与Python的输出是一样的,表明实验成功啦!

    参考文章:

    C++调用Python

    https://pytorch.org/tutorials/advanced/cpp_export.html

    PyTorch 1.0 中文官方教程:使用 PyTorch C++ 前端

    利用Pytorch的C++前端(libtorch)读取预训练权重并进行预测

    https://github.com/ahkarami/Deep-Learning-in-Production

    https://zhuanlan.zhihu.com/p/52806730

  • 相关阅读:
    CURL POST提交json类型字符串数据和伪造IP和来源
    windows下nginx的配置
    常用JS兼容问题工具
    无限级分类--Array写法
    JS获取对象指定属性在样式中的信息
    解决IE和Firefox获取来源网址Referer的JS方法
    异步轮询函数
    响应式布局--特殊设备检测
    jQuery Validate校验
    [LeetCode#124]Binary Tree Maximum Path Sum
  • 原文地址:https://www.cnblogs.com/carsonzhu/p/11197048.html
Copyright © 2011-2022 走看看